diff --git a/include/npcomp/Conversion/Passes.td b/include/npcomp/Conversion/Passes.td
index 532e4f8cf..b43d20760 100644
--- a/include/npcomp/Conversion/Passes.td
+++ b/include/npcomp/Conversion/Passes.td
@@ -38,6 +38,21 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
   let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// TCFToLinalg
+//===----------------------------------------------------------------------===//
+
+def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
+  let summary = "Convert TCF to Linalg";
+  let description = [{
+    The intention is for this pass to convert mainly to linalg named ops.
+
+    Because linalg is at the "TCP" layer of abstraction, this pass has to
+    concern itself with generating guards for error cases.
+  }];
+  let constructor = "mlir::NPCOMP::createConvertTCFToLinalgPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // TCFToStd
 //===----------------------------------------------------------------------===//
diff --git a/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h b/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h
new file mode 100644
index 000000000..1bbfc588c
--- /dev/null
+++ b/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h
@@ -0,0 +1,21 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
+#define NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace NPCOMP {
+std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass();
+}
+} // namespace mlir
+
+#endif // NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
diff --git a/include/npcomp/Dialect/TCP/IR/TCPOps.td b/include/npcomp/Dialect/TCP/IR/TCPOps.td
index 4be0150c3..81f90f0fa 100644
--- a/include/npcomp/Dialect/TCP/IR/TCPOps.td
+++ b/include/npcomp/Dialect/TCP/IR/TCPOps.td
@@ -20,33 +20,6 @@ class TCP_Op<string mnemonic, list<OpTrait> traits = []>
     : Op<TCP_Dialect, mnemonic, traits> {
 }
 
-// TODO: Generalize this op appropriately and add more verification.
-// For example, should we have a single primitive that does multidimensional
-// contractions? + batching as well in the same op? In fact, if we want to
-// get really general, we can include convolution as well; matmul is the 1x1
-// image and 1x1 kernel special case.
-// It still lowers trivially into linalg.generic even with such generalization
-// -- the main question is what transforms we want to do at the TCP level that
-// would be affected by those design choices.
-def TCP_MatmulOp : TCP_Op<"matmul"> {
-  let summary = "Performs a matrix multiplication";
-  let description = [{
-    Performs a matrix multiplication.
-
-    The tensors have dimensions:
-    - lhs: [M, K]
-    - rhs: [K, N]
-    - result: [M, N]
-
-    If the `K` dimension mismatches between operands, this op has
-    undefined behavior.
-  }];
-  let arguments = (ins 2DTensorOf<[F32]>:$lhs, 2DTensorOf<[F32]>:$rhs);
-  let results = (outs 2DTensorOf<[F32]>:$result);
-
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
-}
-
 def TCP_BroadcastToOp : TCP_Op<"broadcast_to"> {
   let summary = "Broadcasts an operand to a given shape.";
   let description = [{
@@ -60,4 +33,25 @@ It is undefined behavior if such a broadcast is not legal.
   let assemblyFormat = "$operand `,` $shape attr-dict `:` functional-type(operands, results)";
 }
 
+def TCP_SplattedOp : TCP_Op<"splatted"> {
+  let summary = "Creates a tensor filled with a particular scalar value.";
+  let description = [{
+    Creates a tensor of shape `shape` with all elements filled with `splatVal`.
+
+    This op is somewhat redundant with tcp.broadcast_to. However,
+    tcp.broadcast_to handles degenerate "size-1" broadcasting which structurally
+    cannot happen with this op. So to avoid losing that information, we keep
+    this op separate.
+
+    NOTE: The name "splatted" separates it from std.splat, which currently
+    only handles statically shaped memrefs.
+
+    TODO: Improve std.splat to take dynamic shapes.
+  }];
+  let arguments = (ins AnyType:$splatVal, Shape_ExtentTensorType:$shape);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)";
+}
+
 #endif // TCP_OPS
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 2b24c8a96..7bea29748 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(ATenToTCF)
 add_subdirectory(BasicpyToStd)
 add_subdirectory(NumpyToTCF)
+add_subdirectory(TCFToLinalg)
 add_subdirectory(TCFToStd)
 add_subdirectory(TCFToTCP)
 
diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp
index 5262075c5..01e4d0f2a 100644
--- a/lib/Conversion/Passes.cpp
+++ b/lib/Conversion/Passes.cpp
@@ -11,6 +11,7 @@
 #include "npcomp/Conversion/ATenToTCF/Passes.h"
 #include "npcomp/Conversion/BasicpyToStd/Passes.h"
 #include "npcomp/Conversion/NumpyToTCF/Passes.h"
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
 #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
 #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
 
diff --git a/lib/Conversion/TCFToLinalg/CMakeLists.txt b/lib/Conversion/TCFToLinalg/CMakeLists.txt
new file mode 100644
index 000000000..9b26f9519
--- /dev/null
+++ b/lib/Conversion/TCFToLinalg/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_npcomp_conversion_library(NPCOMPTCFToLinalg
+  TCFToLinalg.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCFToLinalg
+
+  DEPENDS
+  NPCOMPConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+  MLIRShape
+  NPCOMPTCFDialect
+)
diff --git a/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp
new file mode 100644
index 000000000..35c9f54fa
--- /dev/null
+++ b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp
@@ -0,0 +1,103 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "npcomp/Dialect/TCF/IR/TCFOps.h"
+#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
+#include "npcomp/Dialect/TCP/IR/TCPOps.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+
+static SmallVector<Value, 6> bypassResultShapes(Operation *op,
+                                                OpBuilder &builder) {
+
+  if (auto matmul = dyn_cast<tcf::MatmulOp>(op)) {
+    auto lhsRows = builder.create<DimOp>(op->getLoc(), matmul.lhs(), 0);
+    auto rhsCols = builder.create<DimOp>(op->getLoc(), matmul.rhs(), 1);
+    auto shape = builder.create<TensorFromElementsOp>(
+        op->getLoc(), ValueRange({lhsRows, rhsCols}));
+    return {shape};
+  }
+
+  // No shape transfer function.
+  return {};
+}
+
+namespace {
+class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tcf::MatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    // Create the constraints, and the assuming region.
+    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
+    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
+    Value matchingK =
+        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
+    Value witness = rewriter.create<shape::CstrRequireOp>(
+        op.getLoc(), matchingK, "mismatching contracting dimension for matmul");
+    auto assuming = rewriter.create<shape::AssumingOp>(
+        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
+
+    // Build the region body.
+    rewriter.createBlock(&assuming.doRegion());
+    // Create the init tensor for the matmul.
+    // TODO: Expand supported data types.
+    Value c0 =
+        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getF32FloatAttr(0.0));
+    Value shape = bypassResultShapes(op, rewriter)[0];
+    Value initTensor =
+        rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);
+
+    // Create the matmul.
+    auto matmul = rewriter.create<linalg::MatmulOp>(
+        op.getLoc(), TypeRange(op.getType()), op.getOperands(), ValueRange(),
+        ValueRange(initTensor));
+    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul.getResult(0));
+
+    // Finally, replace with the results of the shape.assuming
+    rewriter.replaceOp(op, assuming.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertTCFToLinalg : public ConvertTCFToLinalgBase<ConvertTCFToLinalg> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, shape::ShapeDialect, tcp::TCPDialect>();
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    (void)applyPatternsAndFoldGreedily(module, getPatterns());
+  }
+
+  FrozenRewritePatternList getPatterns() {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    patterns.insert<ConvertMatmul>(context);
+    return std::move(patterns);
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::createConvertTCFToLinalgPass() {
+  return std::make_unique<ConvertTCFToLinalg>();
+}
diff --git a/lib/Conversion/TCFToTCP/TCFToTCP.cpp b/lib/Conversion/TCFToTCP/TCFToTCP.cpp
index 94572c781..dba2dd0c3 100644
--- a/lib/Conversion/TCFToTCP/TCFToTCP.cpp
+++ b/lib/Conversion/TCFToTCP/TCFToTCP.cpp
@@ -21,35 +21,6 @@
 using namespace mlir;
 using namespace mlir::NPCOMP;
 
-namespace {
-class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tcf::MatmulOp op,
-                                PatternRewriter &rewriter) const override {
-    // Create the constraints, and the assuming region.
-    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
-    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
-    Value matchingK =
-        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
-    Value witness = rewriter.create<shape::CstrRequireOp>(
-        op.getLoc(), matchingK, "mismatching contracting dimension for matmul");
-    auto assuming = rewriter.create<shape::AssumingOp>(
-        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
-
-    // Build the region body.
-    rewriter.createBlock(&assuming.doRegion());
-    Value matmul = rewriter.create<tcp::MatmulOp>(op.getLoc(), op.getType(),
-                                                  op.lhs(), op.rhs());
-    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul);
-
-    // Finally, replace with the results of the shape.assuming
-    rewriter.replaceOp(op, assuming.getResults());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {
 public:
@@ -63,9 +34,10 @@ public:
   }
 
   FrozenRewritePatternList getPatterns() {
-    MLIRContext *context = &getContext();
+    // NOTE: We are keeping this pass around, even though it currently does
+    // nothing, in order to avoid having to reintroduce the same
+    // boilerplate.
     OwningRewritePatternList patterns;
-    patterns.insert<ConvertMatmul>(context);
     return std::move(patterns);
   }
 };
diff --git a/lib/Dialect/TCP/Transforms/Bufferize.cpp b/lib/Dialect/TCP/Transforms/Bufferize.cpp
index 8fabfbcd9..058458711 100644
--- a/lib/Dialect/TCP/Transforms/Bufferize.cpp
+++ b/lib/Dialect/TCP/Transforms/Bufferize.cpp
@@ -32,12 +32,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
     return {broadcastTo.shape()};
   }
 
-  if (auto matmul = dyn_cast<tcp::MatmulOp>(op)) {
-    auto lhsRows = builder.create<DimOp>(op.getLoc(), matmul.lhs(), 0);
-    auto rhsCols = builder.create<DimOp>(op.getLoc(), matmul.rhs(), 1);
-    auto shape = builder.create<TensorFromElementsOp>(
-        op.getLoc(), ValueRange({lhsRows, rhsCols}));
-    return {shape};
+  if (auto splatted = dyn_cast<tcp::SplattedOp>(op)) {
+    return {splatted.shape()};
   }
 
   // No shape transfer function.
@@ -144,20 +140,17 @@ public:
 } // namespace
 
 namespace {
-class BufferizeMatmulOp : public OpConversionPattern<tcp::MatmulOp> {
+class BufferizeSplattedOp : public OpConversionPattern<tcp::SplattedOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tcp::MatmulOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tcp::SplattedOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
     if (failed(resultsOrFailure))
       return failure();
     auto results = *resultsOrFailure;
-    auto c0 =
-        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getF32FloatAttr(0.0));
-    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], c0);
-    rewriter.create<linalg::MatmulOp>(op.getLoc(), operands, results);
+    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.splatVal());
     rewriter.replaceOp(op, results);
     return success();
   }
@@ -190,8 +183,8 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
     patterns.insert<BufferizeBroadcastToOp>(typeConverter, context);
     target.addIllegalOp<tcp::BroadcastToOp>();
-    patterns.insert<BufferizeMatmulOp>(typeConverter, context);
-    target.addIllegalOp<tcp::MatmulOp>();
+    patterns.insert<BufferizeSplattedOp>(typeConverter, context);
+    target.addIllegalOp<tcp::SplattedOp>();
 
     target.addLegalDialect<linalg::LinalgDialect>();
     target.addLegalDialect<StandardOpsDialect>();
 
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 137a34c32..6b998e417 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -40,6 +40,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
 #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
 #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
 #include "npcomp/Dialect/Refback/IR/RefbackOps.h"
@@ -305,6 +306,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
   //
   // TCP does not. So we need to reify the broadcasting and error checking.
   pm.addPass(createConvertTCFToStdPass());
+  pm.addPass(createConvertTCFToLinalgPass());
   pm.addPass(createConvertTCFToTCPPass());
 
   createRefBackendLoweringPipeline(pm, options);
diff --git a/test/Conversion/TCFToLinalg/basic.mlir b/test/Conversion/TCFToLinalg/basic.mlir
new file mode 100644
index 000000000..e72dd0b49
--- /dev/null
+++ b/test/Conversion/TCFToLinalg/basic.mlir
@@ -0,0 +1,25 @@
+// RUN: npcomp-opt <%s -convert-tcf-to-linalg | FileCheck %s --dump-input=fail
+
+// CHECK-LABEL: func @tcf_matmul(
+// CHECK-SAME:    %[[LHS:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:    %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index
+// CHECK: %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "mismatching contracting dimension for matmul"
+// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?xf32>) {
+// CHECK:   %[[LHSROWS:.*]] = dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK:   %[[RHSCOLS:.*]] = dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK:   %[[SHAPE:.*]] = tensor_from_elements %[[LHSROWS]], %[[RHSCOLS]] : tensor<2xindex>
+// CHECK:   %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK:   %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:   shape.assuming_yield %[[MATMUL]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
+func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
diff --git a/test/Conversion/TCFToTCP/basic.mlir b/test/Conversion/TCFToTCP/basic.mlir
index d9c884db2..308c90154 100644
--- a/test/Conversion/TCFToTCP/basic.mlir
+++ b/test/Conversion/TCFToTCP/basic.mlir
@@ -1,21 +1,10 @@
-// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s
 
-// CHECK-LABEL: func @tcf_matmul(
-// CHECK-SAME:    %[[LHS:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:    %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index
-// CHECK: %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "{{.*}}"
-// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?xf32>) {
-// CHECK:   %[[MATMUL:.*]] = tcp.matmul %[[LHS]], %[[RHS]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:   shape.assuming_yield %[[MATMUL]] : tensor<?x?xf32>
-// CHECK: }
-// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
-// CHECK: }
-func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+// NOTE: We are keeping this pass around, even though it currently does
+// nothing, in order to avoid having to reintroduce the same
+// boilerplate.
+
+// CHECK: @f
+func @f() {
+  return
 }
diff --git a/test/Dialect/TCP/bufferize.mlir b/test/Dialect/TCP/bufferize.mlir
index 96dea0d51..16fd050b2 100644
--- a/test/Dialect/TCP/bufferize.mlir
+++ b/test/Dialect/TCP/bufferize.mlir
@@ -14,24 +14,14 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func @tcp_matmul(
-// CHECK-SAME:    %[[LHS_TENSOR:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:    %[[RHS_TENSOR:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[LHS:.*]] = tensor_to_memref %[[LHS_TENSOR]] : memref<?x?xf32>
-// CHECK: %[[RHS:.*]] = tensor_to_memref %[[RHS_TENSOR]] : memref<?x?xf32>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[LHS_ROWS:.*]] = dim %[[LHS_TENSOR]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[RHS_COLS:.*]] = dim %[[RHS_TENSOR]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[LHS_ROWS]], %[[RHS_COLS]] : tensor<2xindex>
+// CHECK-LABEL: func @tcp_splatted(
+// CHECK-SAME:    %[[SPLAT_VAL:.*]]: f32,
+// CHECK-SAME:    %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
 // CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
-// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
-// CHECK: linalg.fill(%[[RESULT]], %[[C0F32]]) : memref<?x?xf32>, f32
-// CHECK: linalg.matmul ins(%[[LHS]], %[[RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[RESULT]] : memref<?x?xf32>)
+// CHECK: linalg.fill(%[[RESULT]], %[[SPLAT_VAL]]) : memref<?x?xf32>, f32
 // CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref<?x?xf32>
 // CHECK: return %[[RESULT_TENSOR]] : tensor<?x?xf32>
-// CHECK: }
-func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
diff --git a/test/Dialect/TCP/ops.mlir b/test/Dialect/TCP/ops.mlir
index 80f07bde3..ca08c173e 100644
--- a/test/Dialect/TCP/ops.mlir
+++ b/test/Dialect/TCP/ops.mlir
@@ -1,8 +1,15 @@
 // RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
 
-// CHECK-LABEL: func @matmul
-func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK: tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  %0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: @broadcast_to
+func @broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: tcp.broadcast_to
+  %0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @splatted
+func @splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: tcp.splatted
+  %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
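Usage sketch (illustrative only, not part of the patch): the new pass composes with the ordinary MLIR pass-manager API, mirroring what createTCFRefBackendLoweringPipeline now does internally. The helper name lowerTCFToLinalg is invented for this example, and the mlir/IR/Module.h include reflects where ModuleOp lived in MLIR at the time of this patch.

    // Minimal sketch: run just the TCF->Linalg conversion over a module.
    // Assumes NPCOMPTCFToLinalg (and its dependencies) are linked in.
    #include "mlir/IR/Module.h"
    #include "mlir/Pass/PassManager.h"
    #include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"

    // After this runs, each tcf.matmul has been rewritten into a
    // shape.assuming region guarding a linalg.matmul whose init tensor is
    // produced by tcp.splatted.
    static mlir::LogicalResult lowerTCFToLinalg(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addPass(mlir::NPCOMP::createConvertTCFToLinalgPass());
      return pm.run(module);
    }

The same conversion is exercised from the command line as npcomp-opt -convert-tcf-to-linalg, as in the new test/Conversion/TCFToLinalg/basic.mlir above.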