[TCP] Replace tcp.matmul with linalg.matmul.

This involved adding a `tcp.splatted` op to splat a dynamically sized
init tensor. See rationale in TCPOps.td docs.

One interesting observation is that when lowering tcf.matmul to
linalg.matmul, we need to both 1) create the error checks and 2)
calculate a shape transfer function to create the init tensors.
Previously, 2) was deferred to bufferizing tcp.matmul later. I'm not
sure if this is a conflation of concerns or not. For now, it's not a big
burden.
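
Concretely, for a tcf.matmul on two tensor<?x?xf32> function arguments %lhs and %rhs, the new lowering produces IR shaped roughly as follows. This is a simplified sketch of what the new tcf-to-linalg test at the end of this change checks; SSA names and the hoisted constants are illustrative. The cmpi / shape.cstr_require / shape.assuming structure is point 1) above, and the dim / tensor_from_elements computation feeding tcp.splatted is point 2):

  %cst0 = constant 0.000000e+00 : f32
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  // 1) Error check: the contracting dimensions must agree.
  %lhs_k = dim %lhs, %c1 : tensor<?x?xf32>
  %rhs_k = dim %rhs, %c0 : tensor<?x?xf32>
  %k_eq = cmpi "eq", %lhs_k, %rhs_k : index
  %witness = shape.cstr_require %k_eq, "mismatching contracting dimension for matmul"
  %result = shape.assuming %witness -> (tensor<?x?xf32>) {
    // 2) Shape transfer function: the result shape is [M, N].
    %m = dim %lhs, %c0 : tensor<?x?xf32>
    %n = dim %rhs, %c1 : tensor<?x?xf32>
    %shape = tensor_from_elements %m, %n : tensor<2xindex>
    // Dynamically sized init tensor, filled with 0.0 via the new tcp.splatted op.
    %init = tcp.splatted %cst0, %shape : (f32, tensor<2xindex>) -> tensor<?x?xf32>
    %mm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                        init(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
    shape.assuming_yield %mm : tensor<?x?xf32>
  }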
pull/113/head
Sean Silva 2020-11-09 15:49:22 -08:00
parent 0427aacb0b
commit 1c7c362e29
14 changed files with 243 additions and 111 deletions


@@ -38,6 +38,21 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
   let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// TCFToLinalg
+//===----------------------------------------------------------------------===//
+
+def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
+  let summary = "Convert TCF to Linalg";
+  let description = [{
+    The intention is for this pass to convert mainly to linalg named ops.
+    Because linalg is at the "TCP" layer of abstraction, this pass has to
+    concern itself with generating guards for error cases.
+  }];
+  let constructor = "mlir::NPCOMP::createConvertTCFToLinalgPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // TCFToStd
 //===----------------------------------------------------------------------===//


@@ -0,0 +1,21 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
+#define NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace NPCOMP {
+std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass();
+}
+} // namespace mlir
+
+#endif // NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H


@@ -20,33 +20,6 @@ class TCP_Op<string mnemonic, list<OpTrait> traits = []>
     : Op<TCP_Dialect, mnemonic, traits> {
 }
 
-// TODO: Generalize this op appropriately and add more verification.
-// For example, should we have a single primitive that does multidimensional
-// contractions? + batching as well in the same op? In fact, if we want to
-// get really general, we can include convolution as well; matmul is the 1x1
-// image and 1x1 kernel special case.
-// It still lowers trivially into linalg.generic even with such generalization
-// -- the main question is what transforms we want to do at the TCP level that
-// would be affected by those design choices.
-def TCP_MatmulOp : TCP_Op<"matmul"> {
-  let summary = "Performs a matrix multiplication";
-  let description = [{
-    Performs a matrix multiplication.
-
-    The tensors have dimensions:
-    - lhs: [M, K]
-    - rhs: [K, N]
-    - result: [M, N]
-
-    If the `K` dimension mismatches between operands, this op has
-    undefined behavior.
-  }];
-  let arguments = (ins 2DTensorOf<[F32]>:$lhs, 2DTensorOf<[F32]>:$rhs);
-  let results = (outs 2DTensorOf<[F32]>:$result);
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
-}
-
 def TCP_BroadcastToOp : TCP_Op<"broadcast_to"> {
   let summary = "Broadcasts an operand to a given shape.";
   let description = [{
@@ -60,4 +33,25 @@ It is undefined behavior if such a broadcast is not legal.
   let assemblyFormat = "$operand `,` $shape attr-dict `:` functional-type(operands, results)";
 }
 
+def TCP_SplattedOp : TCP_Op<"splatted"> {
+  let summary = "Creates a tensor filled with a particular scalar value.";
+  let description = [{
+    Creates a tensor of shape `shape` with all elements filled with `splatVal`.
+
+    This op is somewhat redundant with tcp.broadcast_to. However,
+    tcp.broadcast_to handles degenerate "size-1" broadcasting which structurally
+    cannot happen with this op. So to avoid losing that information, we keep
+    this op separate.
+
+    NOTE: The name "splatted" distinguishes it from std.splat, which currently
+    only handles statically shaped memrefs.
+
+    TODO: Improve std.splat to take dynamic shapes.
+  }];
+  let arguments = (ins AnyType:$splatVal, Shape_ExtentTensorType:$shape);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)";
+}
+
 #endif // TCP_OPS
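
As a usage sketch (not part of the diff hunks; operand names are illustrative), the distinction drawn above looks like this in IR: tcp.broadcast_to may expand size-1 dimensions of a tensor operand out to the requested extents, whereas tcp.splatted always fills the requested shape from a scalar, so no size-1 broadcasting can be hiding in it:

  // broadcast_to: %t may have size-1 dimensions that get expanded to %shape.
  %b = tcp.broadcast_to %t, %shape : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
  // splatted: every element of the result is the scalar %fill; no size-1
  // broadcasting is involved, which is why the two ops are kept separate.
  %s = tcp.splatted %fill, %shape : (f32, tensor<?xindex>) -> tensor<?x?xf32>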


@@ -1,6 +1,7 @@
 add_subdirectory(ATenToTCF)
 add_subdirectory(BasicpyToStd)
 add_subdirectory(NumpyToTCF)
+add_subdirectory(TCFToLinalg)
 add_subdirectory(TCFToStd)
 add_subdirectory(TCFToTCP)


@@ -11,6 +11,7 @@
 #include "npcomp/Conversion/ATenToTCF/Passes.h"
 #include "npcomp/Conversion/BasicpyToStd/Passes.h"
 #include "npcomp/Conversion/NumpyToTCF/Passes.h"
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
 #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
 #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"


@@ -0,0 +1,19 @@
+add_npcomp_conversion_library(NPCOMPTCFToLinalg
+  TCFToLinalg.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCFToLinalg
+
+  DEPENDS
+  NPCOMPConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+  MLIRShape
+  NPCOMPTCFDialect
+  )


@@ -0,0 +1,103 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "npcomp/Dialect/TCF/IR/TCFOps.h"
+#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
+#include "npcomp/Dialect/TCP/IR/TCPOps.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+
+static SmallVector<Value, 6> bypassResultShapes(Operation *op,
+                                                OpBuilder &builder) {
+  if (auto matmul = dyn_cast<tcf::MatmulOp>(op)) {
+    auto lhsRows = builder.create<DimOp>(op->getLoc(), matmul.lhs(), 0);
+    auto rhsCols = builder.create<DimOp>(op->getLoc(), matmul.rhs(), 1);
+    auto shape = builder.create<TensorFromElementsOp>(
+        op->getLoc(), ValueRange({lhsRows, rhsCols}));
+    return {shape};
+  }
+
+  // No shape transfer function.
+  return {};
+}
+
+namespace {
+class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tcf::MatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    // Create the constraints, and the assuming region.
+    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
+    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
+    Value matchingK =
+        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
+    Value witness = rewriter.create<shape::CstrRequireOp>(
+        op.getLoc(), matchingK, "mismatching contracting dimension for matmul");
+    auto assuming = rewriter.create<shape::AssumingOp>(
+        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
+
+    // Build the region body.
+    rewriter.createBlock(&assuming.doRegion());
+    // Create the init tensor for the matmul.
+    // TODO: Expand supported data types.
+    Value c0 =
+        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getF32FloatAttr(0.0));
+    Value shape = bypassResultShapes(op, rewriter)[0];
+    Value initTensor =
+        rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);
+
+    // Create the matmul.
+    auto matmul = rewriter.create<linalg::MatmulOp>(
+        op.getLoc(), TypeRange(op.getType()), op.getOperands(), ValueRange(),
+        ValueRange(initTensor));
+    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul.getResult(0));
+
+    // Finally, replace with the results of the shape.assuming
+    rewriter.replaceOp(op, assuming.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertTCFToLinalg : public ConvertTCFToLinalgBase<ConvertTCFToLinalg> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    (void)applyPatternsAndFoldGreedily(module, getPatterns());
+  }
+
+  FrozenRewritePatternList getPatterns() {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    patterns.insert<ConvertMatmul>(context);
+    return std::move(patterns);
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::createConvertTCFToLinalgPass() {
+  return std::make_unique<ConvertTCFToLinalg>();
+}


@@ -21,35 +21,6 @@
 using namespace mlir;
 using namespace mlir::NPCOMP;
 
-namespace {
-class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tcf::MatmulOp op,
-                                PatternRewriter &rewriter) const override {
-    // Create the constraints, and the assuming region.
-    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
-    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
-    Value matchingK =
-        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
-    Value witness = rewriter.create<shape::CstrRequireOp>(
-        op.getLoc(), matchingK, "mismatching contracting dimension for matmul");
-    auto assuming = rewriter.create<shape::AssumingOp>(
-        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
-
-    // Build the region body.
-    rewriter.createBlock(&assuming.doRegion());
-    Value matmul = rewriter.create<tcp::MatmulOp>(op.getLoc(), op.getType(),
-                                                  op.lhs(), op.rhs());
-    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul);
-
-    // Finally, replace with the results of the shape.assuming
-    rewriter.replaceOp(op, assuming.getResults());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {
 public:
@@ -63,9 +34,10 @@ public:
   }
 
   FrozenRewritePatternList getPatterns() {
-    MLIRContext *context = &getContext();
+    // NOTE: We are keeping this pass around, even though it currently does
+    // nothing, in order to avoid having to reintroduce the same
+    // boilerplate.
     OwningRewritePatternList patterns;
-    patterns.insert<ConvertMatmul>(context);
     return std::move(patterns);
   }
 };


@@ -32,12 +32,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
     return {broadcastTo.shape()};
   }
 
-  if (auto matmul = dyn_cast<tcp::MatmulOp>(op)) {
-    auto lhsRows = builder.create<DimOp>(op.getLoc(), matmul.lhs(), 0);
-    auto rhsCols = builder.create<DimOp>(op.getLoc(), matmul.rhs(), 1);
-    auto shape = builder.create<TensorFromElementsOp>(
-        op.getLoc(), ValueRange({lhsRows, rhsCols}));
-    return {shape};
-  }
+  if (auto splatted = dyn_cast<tcp::SplattedOp>(op)) {
+    return {splatted.shape()};
+  }
 
   // No shape transfer function.
@@ -144,20 +140,17 @@
 } // namespace
 
 namespace {
-class BufferizeMatmulOp : public OpConversionPattern<tcp::MatmulOp> {
+class BufferizeSplattedOp : public OpConversionPattern<tcp::SplattedOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tcp::MatmulOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tcp::SplattedOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
     if (failed(resultsOrFailure))
       return failure();
     auto results = *resultsOrFailure;
-    auto c0 =
-        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getF32FloatAttr(0.0));
-    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], c0);
-    rewriter.create<linalg::MatmulOp>(op.getLoc(), operands, results);
+    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.splatVal());
     rewriter.replaceOp(op, results);
     return success();
   }
@@ -190,8 +183,8 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
   patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
   target.addIllegalOp<tcp::BroadcastToOp>();
-  patterns.insert<BufferizeMatmulOp>(typeConverter, context);
-  target.addIllegalOp<tcp::MatmulOp>();
+  patterns.insert<BufferizeSplattedOp>(typeConverter, context);
+  target.addIllegalOp<tcp::SplattedOp>();
   target.addLegalDialect<linalg::LinalgDialect>();
   target.addLegalDialect<StandardOpsDialect>();


@@ -40,6 +40,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
 #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
 #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
 #include "npcomp/Dialect/Refback/IR/RefbackOps.h"
@@ -305,6 +306,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
   //
   // TCP does not. So we need to reify the broadcasting and error checking.
   pm.addPass(createConvertTCFToStdPass());
+  pm.addPass(createConvertTCFToLinalgPass());
   pm.addPass(createConvertTCFToTCPPass());
 
   createRefBackendLoweringPipeline(pm, options);


@@ -0,0 +1,25 @@
+// RUN: npcomp-opt <%s -convert-tcf-to-linalg | FileCheck %s --dump-input=fail
+
+// CHECK-LABEL: func @tcf_matmul(
+// CHECK-SAME:      %[[LHS:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[C0F32:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK:           %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index
+// CHECK:           %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "mismatching contracting dimension for matmul"
+// CHECK:           %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?xf32>) {
+// CHECK:             %[[LHSROWS:.*]] = dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK:             %[[RHSCOLS:.*]] = dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK:             %[[SHAPE:.*]] = tensor_from_elements %[[LHSROWS]], %[[RHSCOLS]] : tensor<2xindex>
+// CHECK:             %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK:             %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:             shape.assuming_yield %[[MATMUL]] : tensor<?x?xf32>
+// CHECK:           }
+// CHECK:           return %[[RET]] : tensor<?x?xf32>
+func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}


@@ -1,21 +1,10 @@
-// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s
 
-// CHECK-LABEL: func @tcf_matmul(
-// CHECK-SAME:      %[[LHS:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:      %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK:           %[[C1:.*]] = constant 1 : index
-// CHECK:           %[[C0:.*]] = constant 0 : index
-// CHECK:           %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
-// CHECK:           %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
-// CHECK:           %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index
-// CHECK:           %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "{{.*}}"
-// CHECK:           %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?xf32>) {
-// CHECK:             %[[MATMUL:.*]] = tcp.matmul %[[LHS]], %[[RHS]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:             shape.assuming_yield %[[MATMUL]] : tensor<?x?xf32>
-// CHECK:           }
-// CHECK:           return %[[RET:.*]] : tensor<?x?xf32>
-// CHECK:         }
-func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+// NOTE: We are keeping this pass around, even though it currently does
+// nothing, in order to avoid having to reintroduce the same
+// boilerplate.
+
+// CHECK: @f
+func @f() {
+  return
 }


@@ -14,24 +14,14 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func @tcp_matmul(
-// CHECK-SAME:      %[[LHS_TENSOR:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:      %[[RHS_TENSOR:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK:           %[[LHS:.*]] = tensor_to_memref %[[LHS_TENSOR]] : memref<?x?xf32>
-// CHECK:           %[[RHS:.*]] = tensor_to_memref %[[RHS_TENSOR]] : memref<?x?xf32>
-// CHECK:           %[[C0:.*]] = constant 0 : index
-// CHECK:           %[[LHS_ROWS:.*]] = dim %[[LHS_TENSOR]], %[[C0]] : tensor<?x?xf32>
-// CHECK:           %[[C1:.*]] = constant 1 : index
-// CHECK:           %[[RHS_COLS:.*]] = dim %[[RHS_TENSOR]], %[[C1]] : tensor<?x?xf32>
-// CHECK:           %[[SHAPE:.*]] = tensor_from_elements %[[LHS_ROWS]], %[[RHS_COLS]] : tensor<2xindex>
-// CHECK:           %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
-// CHECK:           %[[C0F32:.*]] = constant 0.000000e+00 : f32
-// CHECK:           linalg.fill(%[[RESULT]], %[[C0F32]]) : memref<?x?xf32>, f32
-// CHECK:           linalg.matmul ins(%[[LHS]], %[[RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[RESULT]] : memref<?x?xf32>)
-// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref<?x?xf32>
-// CHECK:           return %[[RESULT_TENSOR]] : tensor<?x?xf32>
-// CHECK:         }
-func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: func @tcp_splatted(
+// CHECK-SAME:      %[[SPLAT_VAL:.*]]: f32,
+// CHECK-SAME:      %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
+// CHECK:           %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
+// CHECK:           linalg.fill(%[[RESULT]], %[[SPLAT_VAL]]) : memref<?x?xf32>, f32
+// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref<?x?xf32>
+// CHECK:           return %[[RESULT_TENSOR]] : tensor<?x?xf32>
+func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }


@@ -1,8 +1,15 @@
 // RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
 
-// CHECK-LABEL: func @matmul
-func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK: tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  %0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: @broadcast_to
+func @broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: tcp.broadcast_to
+  %0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @splatted
+func @splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: tcp.splatted
+  %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }