From 1c7c362e296c4a83499cc683f1856228c99f85dc Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Mon, 9 Nov 2020 15:49:22 -0800 Subject: [PATCH] [TCP] Replace tcp.matmul with linalg.matmul. This involved adding a `tcp.splatted` op to splat a dynamically sized init tensor. See rationale in TCPOps.td docs. One interesting observation is that when lowering tcf.matmul to linalg.matmul, we need to both 1) create the error checks and 2) calculate a shape transfer function to create the init tensors. Previously, 2) was deferred to bufferizing tcp.matmul later. I'm not sure if this is a conflation of concerns or not. For now, it's not a big burden. --- include/npcomp/Conversion/Passes.td | 15 +++ .../Conversion/TCFToLinalg/TCFToLinalg.h | 21 ++++ include/npcomp/Dialect/TCP/IR/TCPOps.td | 48 ++++---- lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/Passes.cpp | 1 + lib/Conversion/TCFToLinalg/CMakeLists.txt | 19 ++++ lib/Conversion/TCFToLinalg/TCFToLinalg.cpp | 103 ++++++++++++++++++ lib/Conversion/TCFToTCP/TCFToTCP.cpp | 34 +----- lib/Dialect/TCP/Transforms/Bufferize.cpp | 21 ++-- lib/RefBackend/RefBackend.cpp | 2 + test/Conversion/TCFToLinalg/basic.mlir | 25 +++++ test/Conversion/TCFToTCP/basic.mlir | 27 ++--- test/Dialect/TCP/bufferize.mlir | 22 +--- test/Dialect/TCP/ops.mlir | 15 ++- 14 files changed, 243 insertions(+), 111 deletions(-) create mode 100644 include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h create mode 100644 lib/Conversion/TCFToLinalg/CMakeLists.txt create mode 100644 lib/Conversion/TCFToLinalg/TCFToLinalg.cpp create mode 100644 test/Conversion/TCFToLinalg/basic.mlir diff --git a/include/npcomp/Conversion/Passes.td b/include/npcomp/Conversion/Passes.td index 532e4f8cf..b43d20760 100644 --- a/include/npcomp/Conversion/Passes.td +++ b/include/npcomp/Conversion/Passes.td @@ -38,6 +38,21 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> { let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()"; } +//===----------------------------------------------------------------------===// +// TCFToTCP +//===----------------------------------------------------------------------===// + +def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> { + let summary = "Convert TCF to Linalg"; + let description = [{ + The intention is for this pass to convert mainly to linalg named ops. + + Because linalg is at the "TCP" layer of abstraction, this pass has to + concern itself with generating guards for error cases. + }]; + let constructor = "mlir::NPCOMP::createConvertTCFToLinalgPass()"; +} + //===----------------------------------------------------------------------===// // TCFToStd //===----------------------------------------------------------------------===// diff --git a/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h b/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h new file mode 100644 index 000000000..1bbfc588c --- /dev/null +++ b/include/npcomp/Conversion/TCFToLinalg/TCFToLinalg.h @@ -0,0 +1,21 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H +#define NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace NPCOMP { +std::unique_ptr> createConvertTCFToLinalgPass(); +} +} // namespace mlir + +#endif // NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H diff --git a/include/npcomp/Dialect/TCP/IR/TCPOps.td b/include/npcomp/Dialect/TCP/IR/TCPOps.td index 4be0150c3..81f90f0fa 100644 --- a/include/npcomp/Dialect/TCP/IR/TCPOps.td +++ b/include/npcomp/Dialect/TCP/IR/TCPOps.td @@ -20,33 +20,6 @@ class TCP_Op traits = []> : Op { } -// TODO: Generalize this op appropriately and add more verification. -// For example, should we have a single primitive that does multidimensional -// contractions? + batching as well in the same op? In fact, if we want to -// get really general, we can include convolution as well; matmul is the 1x1 -// image and 1x1 kernel special case. -// It still lowers trivially into linalg.generic even with such generalization -// -- the main question is what transforms we want to do at the TCP level that -// would be affected by those design choices. -def TCP_MatmulOp : TCP_Op<"matmul"> { - let summary = "Performs a matrix multiplication"; - let description = [{ - Performs a matrix multiplication. - - The tensors have dimensions: - - lhs: [M, K] - - rhs: [K, N] - - result: [M, N] - - If the `K` dimension mismatches between operands, this op has - undefined behavior. - }]; - let arguments = (ins 2DTensorOf<[F32]>:$lhs, 2DTensorOf<[F32]>:$rhs); - let results = (outs 2DTensorOf<[F32]>:$result); - - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)"; -} - def TCP_BroadcastToOp : TCP_Op<"broadcast_to"> { let summary = "Broadcasts an operand to a given shape."; let description = [{ @@ -60,4 +33,25 @@ It is undefined behavior if such a broadcast is not legal. let assemblyFormat = "$operand `,` $shape attr-dict `:` functional-type(operands, results)"; } +def TCP_SplattedOp : TCP_Op<"splatted"> { + let summary = "Creates a tensor filled with a particular scalar value."; + let description = [{ + Creates a tensor of shape `shape` with all elements filled with `splatVal`. + + This op is somewhat redundant with tcp.broadcast_to. However, + tcp.broadcast_to handles degenerate "size-1" broadcasting which structurally + cannot happen with this op. So to avoid losing that information, we keep + this op separate. + + NOTE: The name "splatted" separates it from std.splat, which currently + only handles statically shaped memrefs. + + TODO: Improve std.splat to take dynamic shapes. + }]; + let arguments = (ins AnyType:$splatVal, Shape_ExtentTensorType:$shape); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)"; +} + #endif // TCP_OPS diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 2b24c8a96..7bea29748 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(ATenToTCF) add_subdirectory(BasicpyToStd) add_subdirectory(NumpyToTCF) +add_subdirectory(TCFToLinalg) add_subdirectory(TCFToStd) add_subdirectory(TCFToTCP) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 5262075c5..01e4d0f2a 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -11,6 +11,7 @@ #include "npcomp/Conversion/ATenToTCF/Passes.h" #include "npcomp/Conversion/BasicpyToStd/Passes.h" #include "npcomp/Conversion/NumpyToTCF/Passes.h" +#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h" #include "npcomp/Conversion/TCFToStd/TCFToStd.h" #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h" diff --git a/lib/Conversion/TCFToLinalg/CMakeLists.txt b/lib/Conversion/TCFToLinalg/CMakeLists.txt new file mode 100644 index 000000000..9b26f9519 --- /dev/null +++ b/lib/Conversion/TCFToLinalg/CMakeLists.txt @@ -0,0 +1,19 @@ +add_npcomp_conversion_library(NPCOMPTCFToLinalg + TCFToLinalg.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCFToLinalg + + DEPENDS + NPCOMPConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + MLIRShape + NPCOMPTCFDialect +) diff --git a/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp new file mode 100644 index 000000000..35c9f54fa --- /dev/null +++ b/lib/Conversion/TCFToLinalg/TCFToLinalg.cpp @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "npcomp/Dialect/TCF/IR/TCFOps.h" +#include "npcomp/Dialect/TCP/IR/TCPDialect.h" +#include "npcomp/Dialect/TCP/IR/TCPOps.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +static SmallVector bypassResultShapes(Operation *op, + OpBuilder &builder) { + + if (auto matmul = dyn_cast(op)) { + auto lhsRows = builder.create(op->getLoc(), matmul.lhs(), 0); + auto rhsCols = builder.create(op->getLoc(), matmul.rhs(), 1); + auto shape = builder.create( + op->getLoc(), ValueRange({lhsRows, rhsCols})); + return {shape}; + } + + // No shape transfer function. + return {}; +} + +namespace { +class ConvertMatmul : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tcf::MatmulOp op, + PatternRewriter &rewriter) const override { + // Create the constraints, and the assuming region. + Value lhsK = rewriter.create(op.getLoc(), op.lhs(), 1); + Value rhsK = rewriter.create(op.getLoc(), op.rhs(), 0); + Value matchingK = + rewriter.create(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK); + Value witness = rewriter.create( + op.getLoc(), matchingK, "mismatching contracting dimension for matmul"); + auto assuming = rewriter.create( + op.getLoc(), ArrayRef{op.getType()}, witness); + + // Build the region body. + rewriter.createBlock(&assuming.doRegion()); + // Create the init tensor for the matmul. + // TODO: Expand supported data types. + Value c0 = + rewriter.create(op.getLoc(), rewriter.getF32FloatAttr(0.0)); + Value shape = bypassResultShapes(op, rewriter)[0]; + Value initTensor = + rewriter.create(op.getLoc(), op.getType(), c0, shape); + + // Create the matmul. + auto matmul = rewriter.create( + op.getLoc(), TypeRange(op.getType()), op.getOperands(), ValueRange(), + ValueRange(initTensor)); + rewriter.create(op.getLoc(), matmul.getResult(0)); + + // Finally, replace with the results of the shape.assuming + rewriter.replaceOp(op, assuming.getResults()); + return success(); + } +}; +} // namespace + +namespace { +class ConvertTCFToLinalg : public ConvertTCFToLinalgBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + (void)applyPatternsAndFoldGreedily(module, getPatterns()); + } + + FrozenRewritePatternList getPatterns() { + MLIRContext *context = &getContext(); + OwningRewritePatternList patterns; + patterns.insert(context); + return std::move(patterns); + } +}; +} // namespace + +std::unique_ptr> +mlir::NPCOMP::createConvertTCFToLinalgPass() { + return std::make_unique(); +} diff --git a/lib/Conversion/TCFToTCP/TCFToTCP.cpp b/lib/Conversion/TCFToTCP/TCFToTCP.cpp index 94572c781..dba2dd0c3 100644 --- a/lib/Conversion/TCFToTCP/TCFToTCP.cpp +++ b/lib/Conversion/TCFToTCP/TCFToTCP.cpp @@ -21,35 +21,6 @@ using namespace mlir; using namespace mlir::NPCOMP; -namespace { -class ConvertMatmul : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tcf::MatmulOp op, - PatternRewriter &rewriter) const override { - // Create the constraints, and the assuming region. - Value lhsK = rewriter.create(op.getLoc(), op.lhs(), 1); - Value rhsK = rewriter.create(op.getLoc(), op.rhs(), 0); - Value matchingK = - rewriter.create(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK); - Value witness = rewriter.create( - op.getLoc(), matchingK, "mismatching contracting dimension for matmul"); - auto assuming = rewriter.create( - op.getLoc(), ArrayRef{op.getType()}, witness); - - // Build the region body. - rewriter.createBlock(&assuming.doRegion()); - Value matmul = rewriter.create(op.getLoc(), op.getType(), - op.lhs(), op.rhs()); - rewriter.create(op.getLoc(), matmul); - - // Finally, replace with the results of the shape.assuming - rewriter.replaceOp(op, assuming.getResults()); - return success(); - } -}; -} // namespace - namespace { class ConvertTCFToTCP : public ConvertTCFToTCPBase { public: @@ -63,9 +34,10 @@ public: } FrozenRewritePatternList getPatterns() { - MLIRContext *context = &getContext(); + // NOTE: We are keeping this pass around, even though it currently does + // nothing, in order to avoid having to reintroduce the same + // boilerplate. OwningRewritePatternList patterns; - patterns.insert(context); return std::move(patterns); } }; diff --git a/lib/Dialect/TCP/Transforms/Bufferize.cpp b/lib/Dialect/TCP/Transforms/Bufferize.cpp index 8fabfbcd9..058458711 100644 --- a/lib/Dialect/TCP/Transforms/Bufferize.cpp +++ b/lib/Dialect/TCP/Transforms/Bufferize.cpp @@ -32,12 +32,8 @@ static SmallVector bypassResultShapes(Operation &op) { return {broadcastTo.shape()}; } - if (auto matmul = dyn_cast(op)) { - auto lhsRows = builder.create(op.getLoc(), matmul.lhs(), 0); - auto rhsCols = builder.create(op.getLoc(), matmul.rhs(), 1); - auto shape = builder.create( - op.getLoc(), ValueRange({lhsRows, rhsCols})); - return {shape}; + if (auto splatted = dyn_cast(op)) { + return {splatted.shape()}; } // No shape transfer function. @@ -144,20 +140,17 @@ public: } // namespace namespace { -class BufferizeMatmulOp : public OpConversionPattern { +class BufferizeSplattedOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tcp::MatmulOp op, ArrayRef operands, + matchAndRewrite(tcp::SplattedOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc()); if (failed(resultsOrFailure)) return failure(); auto results = *resultsOrFailure; - auto c0 = - rewriter.create(op.getLoc(), rewriter.getF32FloatAttr(0.0)); - rewriter.create(op.getLoc(), results[0], c0); - rewriter.create(op.getLoc(), operands, results); + rewriter.create(op.getLoc(), results[0], op.splatVal()); rewriter.replaceOp(op, results); return success(); } @@ -190,8 +183,8 @@ class TCPBufferizePass : public TCPBufferizeBase { patterns.insert(typeConverter, context); target.addIllegalOp(); - patterns.insert(typeConverter, context); - target.addIllegalOp(); + patterns.insert(typeConverter, context); + target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 137a34c32..6b998e417 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -40,6 +40,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h" #include "npcomp/Conversion/TCFToStd/TCFToStd.h" #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h" #include "npcomp/Dialect/Refback/IR/RefbackOps.h" @@ -305,6 +306,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline( // // TCP does not. So we need to reify the broadcasting and error checking. pm.addPass(createConvertTCFToStdPass()); + pm.addPass(createConvertTCFToLinalgPass()); pm.addPass(createConvertTCFToTCPPass()); createRefBackendLoweringPipeline(pm, options); diff --git a/test/Conversion/TCFToLinalg/basic.mlir b/test/Conversion/TCFToLinalg/basic.mlir new file mode 100644 index 000000000..e72dd0b49 --- /dev/null +++ b/test/Conversion/TCFToLinalg/basic.mlir @@ -0,0 +1,25 @@ +// RUN: npcomp-opt <%s -convert-tcf-to-linalg | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @tcf_matmul( +// CHECK-SAME: %[[LHS:.*]]: tensor, +// CHECK-SAME: %[[RHS:.*]]: tensor) -> tensor { +// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor +// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor +// CHECK: %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index +// CHECK: %[[WINESS:.*]] = shape.cstr_require %[[KEQUAL]], "mismatching contracting dimension for matmul" +// CHECK: %[[RET:.*]] = shape.assuming %[[WINESS]] -> (tensor) { +// CHECK: %[[LHSROWS:.*]] = dim %[[LHS]], %[[C0]] : tensor +// CHECK: %[[RHSCOLS:.*]] = dim %[[RHS]], %[[C1]] : tensor +// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[LHSROWS]], %[[RHSCOLS]] : tensor<2xindex> +// CHECK: %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<2xindex>) -> tensor +// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor, tensor) init(%[[INIT_TENSOR]] : tensor) -> tensor +// CHECK: shape.assuming_yield %[[MATMUL]] : tensor +// CHECK: } +// CHECK: return %[[RET:.*]] : tensor +func @tcf_matmul(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tcf.matmul %arg0, %arg1 : (tensor, tensor) -> tensor + return %0 : tensor +} diff --git a/test/Conversion/TCFToTCP/basic.mlir b/test/Conversion/TCFToTCP/basic.mlir index d9c884db2..308c90154 100644 --- a/test/Conversion/TCFToTCP/basic.mlir +++ b/test/Conversion/TCFToTCP/basic.mlir @@ -1,21 +1,10 @@ -// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail +// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s -// CHECK-LABEL: func @tcf_matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor, -// CHECK-SAME: %[[RHS:.*]]: tensor) -> tensor { -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor -// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor -// CHECK: %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index -// CHECK: %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "{{.*}}" -// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { -// CHECK: %[[MATMUL:.*]] = tcp.matmul %[[LHS]], %[[RHS]] : (tensor, tensor) -> tensor -// CHECK: shape.assuming_yield %[[MATMUL]] : tensor -// CHECK: } -// CHECK: return %[[RET:.*]] : tensor -// CHECK: } -func @tcf_matmul(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = tcf.matmul %arg0, %arg1 : (tensor, tensor) -> tensor - return %0 : tensor +// NOTE: We are keeping this pass around, even though it currently does +// nothing, in order to avoid having to reintroduce the same +// boilerplate. + +// CHECK: @f +func @f() { + return } diff --git a/test/Dialect/TCP/bufferize.mlir b/test/Dialect/TCP/bufferize.mlir index 96dea0d51..16fd050b2 100644 --- a/test/Dialect/TCP/bufferize.mlir +++ b/test/Dialect/TCP/bufferize.mlir @@ -14,24 +14,14 @@ func @tcp_broadcast_to(%arg0: tensor, %arg1: tensor) -> tensor } -// CHECK-LABEL: func @tcp_matmul( -// CHECK-SAME: %[[LHS_TENSOR:.*]]: tensor, -// CHECK-SAME: %[[RHS_TENSOR:.*]]: tensor) -> tensor { -// CHECK: %[[LHS:.*]] = tensor_to_memref %[[LHS_TENSOR]] : memref -// CHECK: %[[RHS:.*]] = tensor_to_memref %[[RHS_TENSOR]] : memref -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[LHS_ROWS:.*]] = dim %[[LHS_TENSOR]], %[[C0]] : tensor -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[RHS_COLS:.*]] = dim %[[RHS_TENSOR]], %[[C1]] : tensor -// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[LHS_ROWS]], %[[RHS_COLS]] : tensor<2xindex> +// CHECK-LABEL: func @tcp_splatted( +// CHECK-SAME: %[[SPLAT_VAL:.*]]: f32, +// CHECK-SAME: %[[SHAPE:.*]]: tensor) -> tensor { // CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref -// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32 -// CHECK: linalg.fill(%[[RESULT]], %[[C0F32]]) : memref, f32 -// CHECK: linalg.matmul ins(%[[LHS]], %[[RHS]] : memref, memref) outs(%[[RESULT]] : memref) +// CHECK: linalg.fill(%[[RESULT]], %[[SPLAT_VAL]]) : memref, f32 // CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref // CHECK: return %[[RESULT_TENSOR]] : tensor -// CHECK: } -func @tcp_matmul(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = tcp.matmul %arg0, %arg1 : (tensor, tensor) -> tensor +func @tcp_splatted(%arg0: f32, %arg1: tensor) -> tensor { + %0 = tcp.splatted %arg0, %arg1 : (f32, tensor) -> tensor return %0 : tensor } diff --git a/test/Dialect/TCP/ops.mlir b/test/Dialect/TCP/ops.mlir index 80f07bde3..ca08c173e 100644 --- a/test/Dialect/TCP/ops.mlir +++ b/test/Dialect/TCP/ops.mlir @@ -1,8 +1,15 @@ // RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s -// CHECK-LABEL: func @matmul -func @matmul(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: tcp.matmul %arg0, %arg1 : (tensor, tensor) -> tensor - %0 = tcp.matmul %arg0, %arg1 : (tensor, tensor) -> tensor +// CHECK-LABEL: @broadcast_to +func @broadcast_to(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: tcp.broadcast_to + %0 = tcp.broadcast_to %arg0, %arg1 : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @splatted +func @splatted(%arg0: f32, %arg1: tensor) -> tensor { + // CHECK: tcp.splatted + %0 = tcp.splatted %arg0, %arg1 : (f32, tensor) -> tensor return %0 : tensor }