mirror of https://github.com/llvm/torch-mlir
[TCP] Replace tcp.matmul with linalg.matmul.
This involved adding a `tcp.splatted` op to splat a dynamically sized init tensor. See rationale in TCPOps.td docs.

One interesting observation is that when lowering tcf.matmul to linalg.matmul, we need to both 1) create the error checks and 2) calculate a shape transfer function to create the init tensors. Previously, 2) was deferred to bufferizing tcp.matmul later. I'm not sure if this is a conflation of concerns or not. For now, it's not a big burden.
parent 0427aacb0b
commit 1c7c362e29
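The end-to-end effect is easiest to see in the new TCFToLinalg test added below (run with npcomp-opt -convert-tcf-to-linalg): tcf.matmul becomes a shape.assuming region guarding the contracting-dimension check, with tcp.splatted producing the init tensor for linalg.matmul. The following is a sketch assembled from that test's CHECK lines; the SSA value names are illustrative:

  func @tcf_matmul(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c0f32 = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    // Error check: the contracting dimensions must agree.
    %lhsK = dim %lhs, %c1 : tensor<?x?xf32>
    %rhsK = dim %rhs, %c0 : tensor<?x?xf32>
    %kEqual = cmpi "eq", %lhsK, %rhsK : index
    %witness = shape.cstr_require %kEqual, "mismatching contracting dimension for matmul"
    %ret = shape.assuming %witness -> (tensor<?x?xf32>) {
      // Shape transfer function: result is [M, N].
      %lhsRows = dim %lhs, %c0 : tensor<?x?xf32>
      %rhsCols = dim %rhs, %c1 : tensor<?x?xf32>
      %shape = tensor_from_elements %lhsRows, %rhsCols : tensor<2xindex>
      // Dynamically shaped init tensor, filled with 0.0.
      %init = tcp.splatted %c0f32, %shape : (f32, tensor<2xindex>) -> tensor<?x?xf32>
      %matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) init(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
      shape.assuming_yield %matmul : tensor<?x?xf32>
    }
    return %ret : tensor<?x?xf32>
  }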
@@ -38,6 +38,21 @@ def ConvertNumpyToTCF : Pass<"convert-numpy-to-tcf", "FuncOp"> {
   let constructor = "mlir::NPCOMP::createConvertNumpyToTCFPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// TCFToLinalg
+//===----------------------------------------------------------------------===//
+
+def ConvertTCFToLinalg : Pass<"convert-tcf-to-linalg", "ModuleOp"> {
+  let summary = "Convert TCF to Linalg";
+  let description = [{
+    The intention is for this pass to convert mainly to linalg named ops.
+
+    Because linalg is at the "TCP" layer of abstraction, this pass has to
+    concern itself with generating guards for error cases.
+  }];
+  let constructor = "mlir::NPCOMP::createConvertTCFToLinalgPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // TCFToStd
 //===----------------------------------------------------------------------===//
@@ -0,0 +1,21 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
+#define NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace NPCOMP {
+std::unique_ptr<OperationPass<ModuleOp>> createConvertTCFToLinalgPass();
+}
+} // namespace mlir
+
+#endif // NPCOMP_CONVERSION_TCFTOLINALG_TCFTOLINALG_H
@@ -20,33 +20,6 @@ class TCP_Op<string mnemonic, list<OpTrait> traits = []>
     : Op<TCP_Dialect, mnemonic, traits> {
 }
 
-// TODO: Generalize this op appropriately and add more verification.
-// For example, should we have a single primitive that does multidimensional
-// contractions? + batching as well in the same op? In fact, if we want to
-// get really general, we can include convolution as well; matmul is the 1x1
-// image and 1x1 kernel special case.
-// It still lowers trivially into linalg.generic even with such generalization
-// -- the main question is what transforms we want to do at the TCP level that
-// would be affected by those design choices.
-def TCP_MatmulOp : TCP_Op<"matmul"> {
-  let summary = "Performs a matrix multiplication";
-  let description = [{
-    Performs a matrix multiplication.
-
-    The tensors have dimensions:
-    - lhs: [M, K]
-    - rhs: [K, N]
-    - result: [M, N]
-
-    If the `K` dimension mismatches between operands, this op has
-    undefined behavior.
-  }];
-  let arguments = (ins 2DTensorOf<[F32]>:$lhs, 2DTensorOf<[F32]>:$rhs);
-  let results = (outs 2DTensorOf<[F32]>:$result);
-
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` functional-type(operands, results)";
-}
-
 def TCP_BroadcastToOp : TCP_Op<"broadcast_to"> {
   let summary = "Broadcasts an operand to a given shape.";
   let description = [{
@@ -60,4 +33,25 @@ It is undefined behavior if such a broadcast is not legal.
   let assemblyFormat = "$operand `,` $shape attr-dict `:` functional-type(operands, results)";
 }
 
+def TCP_SplattedOp : TCP_Op<"splatted"> {
+  let summary = "Creates a tensor filled with a particular scalar value.";
+  let description = [{
+    Creates a tensor of shape `shape` with all elements filled with `splatVal`.
+
+    This op is somewhat redundant with tcp.broadcast_to. However,
+    tcp.broadcast_to handles degenerate "size-1" broadcasting which structurally
+    cannot happen with this op. So to avoid losing that information, we keep
+    this op separate.
+
+    NOTE: The name "splatted" separates it from std.splat, which currently
+    only handles statically shaped memrefs.
+
+    TODO: Improve std.splat to take dynamic shapes.
+  }];
+  let arguments = (ins AnyType:$splatVal, Shape_ExtentTensorType:$shape);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = "$splatVal `,` $shape attr-dict `:` functional-type(operands, results)";
+}
+
 #endif // TCP_OPS
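For reference, the textual form of the new op, copied from the tcp.splatted round-trip test updated at the end of this change:

  func @splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
    // Fill a dynamically shaped tensor with the scalar %arg0.
    %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }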
@@ -1,6 +1,7 @@
 add_subdirectory(ATenToTCF)
 add_subdirectory(BasicpyToStd)
 add_subdirectory(NumpyToTCF)
+add_subdirectory(TCFToLinalg)
 add_subdirectory(TCFToStd)
 add_subdirectory(TCFToTCP)
@@ -11,6 +11,7 @@
 #include "npcomp/Conversion/ATenToTCF/Passes.h"
 #include "npcomp/Conversion/BasicpyToStd/Passes.h"
 #include "npcomp/Conversion/NumpyToTCF/Passes.h"
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
 #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
 #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
@@ -0,0 +1,19 @@
+add_npcomp_conversion_library(NPCOMPTCFToLinalg
+  TCFToLinalg.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCFToLinalg
+
+  DEPENDS
+  NPCOMPConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MLIRTransforms
+  MLIRShape
+  NPCOMPTCFDialect
+)
@@ -0,0 +1,103 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "npcomp/Dialect/TCF/IR/TCFOps.h"
+#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
+#include "npcomp/Dialect/TCP/IR/TCPOps.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+
+static SmallVector<Value, 6> bypassResultShapes(Operation *op,
+                                                OpBuilder &builder) {
+  if (auto matmul = dyn_cast<tcf::MatmulOp>(op)) {
+    auto lhsRows = builder.create<DimOp>(op->getLoc(), matmul.lhs(), 0);
+    auto rhsCols = builder.create<DimOp>(op->getLoc(), matmul.rhs(), 1);
+    auto shape = builder.create<TensorFromElementsOp>(
+        op->getLoc(), ValueRange({lhsRows, rhsCols}));
+    return {shape};
+  }
+
+  // No shape transfer function.
+  return {};
+}
+
+namespace {
+class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tcf::MatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    // Create the constraints, and the assuming region.
+    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
+    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
+    Value matchingK =
+        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
+    Value witness = rewriter.create<shape::CstrRequireOp>(
+        op.getLoc(), matchingK, "mismatching contracting dimension for matmul");
+    auto assuming = rewriter.create<shape::AssumingOp>(
+        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
+
+    // Build the region body.
+    rewriter.createBlock(&assuming.doRegion());
+    // Create the init tensor for the matmul.
+    // TODO: Expand supported data types.
+    Value c0 =
+        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getF32FloatAttr(0.0));
+    Value shape = bypassResultShapes(op, rewriter)[0];
+    Value initTensor =
+        rewriter.create<tcp::SplattedOp>(op.getLoc(), op.getType(), c0, shape);
+
+    // Create the matmul.
+    auto matmul = rewriter.create<linalg::MatmulOp>(
+        op.getLoc(), TypeRange(op.getType()), op.getOperands(), ValueRange(),
+        ValueRange(initTensor));
+    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul.getResult(0));
+
+    // Finally, replace with the results of the shape.assuming.
+    rewriter.replaceOp(op, assuming.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertTCFToLinalg : public ConvertTCFToLinalgBase<ConvertTCFToLinalg> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<shape::ShapeDialect, tcp::TCPDialect>();
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    (void)applyPatternsAndFoldGreedily(module, getPatterns());
+  }
+
+  FrozenRewritePatternList getPatterns() {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    patterns.insert<ConvertMatmul>(context);
+    return std::move(patterns);
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::createConvertTCFToLinalgPass() {
+  return std::make_unique<ConvertTCFToLinalg>();
+}
@@ -21,35 +21,6 @@
 using namespace mlir;
 using namespace mlir::NPCOMP;
 
-namespace {
-class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tcf::MatmulOp op,
-                                PatternRewriter &rewriter) const override {
-    // Create the constraints, and the assuming region.
-    Value lhsK = rewriter.create<DimOp>(op.getLoc(), op.lhs(), 1);
-    Value rhsK = rewriter.create<DimOp>(op.getLoc(), op.rhs(), 0);
-    Value matchingK =
-        rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::eq, lhsK, rhsK);
-    Value witness = rewriter.create<shape::CstrRequireOp>(
-        op.getLoc(), matchingK, "mismatching contracting dimension for matmul");
-    auto assuming = rewriter.create<shape::AssumingOp>(
-        op.getLoc(), ArrayRef<Type>{op.getType()}, witness);
-
-    // Build the region body.
-    rewriter.createBlock(&assuming.doRegion());
-    Value matmul = rewriter.create<tcp::MatmulOp>(op.getLoc(), op.getType(),
-                                                  op.lhs(), op.rhs());
-    rewriter.create<shape::AssumingYieldOp>(op.getLoc(), matmul);
-
-    // Finally, replace with the results of the shape.assuming
-    rewriter.replaceOp(op, assuming.getResults());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {
 public:
@@ -63,9 +34,10 @@ public:
   }
 
   FrozenRewritePatternList getPatterns() {
-    MLIRContext *context = &getContext();
+    // NOTE: We are keeping this pass around, even though it currently does
+    // nothing, in order to avoid having to reintroduce the same
+    // boilerplate.
     OwningRewritePatternList patterns;
-    patterns.insert<ConvertMatmul>(context);
     return std::move(patterns);
   }
 };
@@ -32,12 +32,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
     return {broadcastTo.shape()};
   }
 
-  if (auto matmul = dyn_cast<tcp::MatmulOp>(op)) {
-    auto lhsRows = builder.create<DimOp>(op.getLoc(), matmul.lhs(), 0);
-    auto rhsCols = builder.create<DimOp>(op.getLoc(), matmul.rhs(), 1);
-    auto shape = builder.create<TensorFromElementsOp>(
-        op.getLoc(), ValueRange({lhsRows, rhsCols}));
-    return {shape};
+  if (auto splatted = dyn_cast<tcp::SplattedOp>(op)) {
+    return {splatted.shape()};
   }
 
   // No shape transfer function.
@@ -144,20 +140,17 @@ public:
 } // namespace
 
 namespace {
-class BufferizeMatmulOp : public OpConversionPattern<tcp::MatmulOp> {
+class BufferizeSplattedOp : public OpConversionPattern<tcp::SplattedOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tcp::MatmulOp op, ArrayRef<Value> operands,
+  matchAndRewrite(tcp::SplattedOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
     if (failed(resultsOrFailure))
       return failure();
     auto results = *resultsOrFailure;
-    auto c0 =
-        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getF32FloatAttr(0.0));
-    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], c0);
-    rewriter.create<linalg::MatmulOp>(op.getLoc(), operands, results);
+    rewriter.create<linalg::FillOp>(op.getLoc(), results[0], op.splatVal());
     rewriter.replaceOp(op, results);
     return success();
   }
@@ -190,8 +183,8 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
 
   patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
   target.addIllegalOp<tcp::BroadcastToOp>();
-  patterns.insert<BufferizeMatmulOp>(typeConverter, context);
-  target.addIllegalOp<tcp::MatmulOp>();
+  patterns.insert<BufferizeSplattedOp>(typeConverter, context);
+  target.addIllegalOp<tcp::SplattedOp>();
 
   target.addLegalDialect<linalg::LinalgDialect>();
   target.addLegalDialect<StandardOpsDialect>();
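In contrast to the old BufferizeMatmulOp, bufferizing tcp.splatted is just an allocation plus a fill. A sketch of the bufferized form, based on the CHECK lines of the updated bufferize test further down (value names are illustrative):

  func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
    // Allocate the result buffer from the extent tensor, fill it, and load back.
    %memref = refback.alloc_memref %arg1 : memref<?x?xf32>
    linalg.fill(%memref, %arg0) : memref<?x?xf32>, f32
    %result = tensor_load %memref : memref<?x?xf32>
    return %result : tensor<?x?xf32>
  }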
@@ -40,6 +40,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "npcomp/Conversion/TCFToLinalg/TCFToLinalg.h"
 #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
 #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
 #include "npcomp/Dialect/Refback/IR/RefbackOps.h"
@@ -305,6 +306,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
   //
   // TCP does not. So we need to reify the broadcasting and error checking.
   pm.addPass(createConvertTCFToStdPass());
+  pm.addPass(createConvertTCFToLinalgPass());
   pm.addPass(createConvertTCFToTCPPass());
 
   createRefBackendLoweringPipeline(pm, options);
@@ -0,0 +1,25 @@
+// RUN: npcomp-opt <%s -convert-tcf-to-linalg | FileCheck %s --dump-input=fail
+
+// CHECK-LABEL: func @tcf_matmul(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index
+// CHECK: %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "mismatching contracting dimension for matmul"
+// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?xf32>) {
+// CHECK: %[[LHSROWS:.*]] = dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[RHSCOLS:.*]] = dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[LHSROWS]], %[[RHSCOLS]] : tensor<2xindex>
+// CHECK: %[[INIT_TENSOR:.*]] = tcp.splatted %[[C0F32]], %[[SHAPE]] : (f32, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: shape.assuming_yield %[[MATMUL]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
+func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
@@ -1,21 +1,10 @@
-// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s
 
-// CHECK-LABEL: func @tcf_matmul(
-// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[LHSK:.*]] = dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[RHSK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[KEQUAL:.*]] = cmpi "eq", %[[LHSK]], %[[RHSK]] : index
-// CHECK: %[[WITNESS:.*]] = shape.cstr_require %[[KEQUAL]], "{{.*}}"
-// CHECK: %[[RET:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?x?xf32>) {
-// CHECK: %[[MATMUL:.*]] = tcp.matmul %[[LHS]], %[[RHS]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: shape.assuming_yield %[[MATMUL]] : tensor<?x?xf32>
-// CHECK: }
-// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
-// CHECK: }
-func @tcf_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = tcf.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+// NOTE: We are keeping this pass around, even though it currently does
+// nothing, in order to avoid having to reintroduce the same
+// boilerplate.
+
+// CHECK: @f
+func @f() {
+  return
 }
@@ -14,24 +14,14 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func @tcp_matmul(
-// CHECK-SAME: %[[LHS_TENSOR:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[RHS_TENSOR:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[LHS:.*]] = tensor_to_memref %[[LHS_TENSOR]] : memref<?x?xf32>
-// CHECK: %[[RHS:.*]] = tensor_to_memref %[[RHS_TENSOR]] : memref<?x?xf32>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[LHS_ROWS:.*]] = dim %[[LHS_TENSOR]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[RHS_COLS:.*]] = dim %[[RHS_TENSOR]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[LHS_ROWS]], %[[RHS_COLS]] : tensor<2xindex>
+// CHECK-LABEL: func @tcp_splatted(
+// CHECK-SAME: %[[SPLAT_VAL:.*]]: f32,
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?x?xf32> {
 // CHECK: %[[RESULT:.*]] = refback.alloc_memref %[[SHAPE]] : memref<?x?xf32>
-// CHECK: %[[C0F32:.*]] = constant 0.000000e+00 : f32
-// CHECK: linalg.fill(%[[RESULT]], %[[C0F32]]) : memref<?x?xf32>, f32
-// CHECK: linalg.matmul ins(%[[LHS]], %[[RHS]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[RESULT]] : memref<?x?xf32>)
+// CHECK: linalg.fill(%[[RESULT]], %[[SPLAT_VAL]]) : memref<?x?xf32>, f32
 // CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT]] : memref<?x?xf32>
 // CHECK: return %[[RESULT_TENSOR]] : tensor<?x?xf32>
-// CHECK: }
-func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+func @tcp_splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -1,8 +1,15 @@
 // RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
 
-// CHECK-LABEL: func @matmul
-func @matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK: tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-  %0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-LABEL: @broadcast_to
+func @broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: tcp.broadcast_to
+  %0 = tcp.broadcast_to %arg0, %arg1 : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @splatted
+func @splatted(%arg0: f32, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: tcp.splatted
+  %0 = tcp.splatted %arg0, %arg1 : (f32, tensor<?xindex>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }