mirror of https://github.com/llvm/torch-mlir
[RefE2E] Refactor how tcf.add is lowered.
It was previously going through this awkward route that prematurely created linalg.generic ops, which was an annoying layering problem since we can't compute a shape transfer function for linalg.generic in the general case. Now we pass it through the same path as tcp.matmul, with the shape transfer function being defined for tcp.add. This also removed the need for TCPToLinalg (now deleted). The equivalent of that is happening in lower-shaped-results-to-memref. One interesting outcome of this: we're basically using linalg as a "Buffer TCP". We might want to look into using named structured ops for more of TCP, but that would be a big velocity hit since then any change to the ODS / verification for those ops would be a change to the upstream structured op ODS generator. After we have more experience defining this manually, we should re-evaluate rebasing TCP on generated named linalg ops.pull/52/head
parent
d8675f8ad2
commit
dc8afc9271
|
@ -20,15 +20,6 @@ def ConvertTCFToTCP : Pass<"convert-tcf-to-tcp", "ModuleOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
let constructor = "mlir::NPCOMP::createConvertTCFToTCPPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// TCPToLinalg
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def ConvertTCPToLinalg : Pass<"convert-tcp-to-linalg", "ModuleOp"> {
|
|
||||||
let summary = "Convert TCP to Linalg";
|
|
||||||
let constructor = "mlir::NPCOMP::createConvertTCPToLinalgPass()";
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Basicpy conversions
|
// Basicpy conversions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
//===------------------------------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H
|
|
||||||
#define NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H
|
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
namespace NPCOMP {
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTCPToLinalgPass();
|
|
||||||
}
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
||||||
#endif // NPCOMP_CONVERSION_TCPTOLINALG_CONVERTTCPTOLINALG_H
|
|
|
@ -1,7 +1,6 @@
|
||||||
add_subdirectory(BasicpyToStd)
|
add_subdirectory(BasicpyToStd)
|
||||||
add_subdirectory(NumpyToTCF)
|
add_subdirectory(NumpyToTCF)
|
||||||
add_subdirectory(TCFToTCP)
|
add_subdirectory(TCFToTCP)
|
||||||
add_subdirectory(TCPToLinalg)
|
|
||||||
|
|
||||||
if(NPCOMP_ENABLE_IREE)
|
if(NPCOMP_ENABLE_IREE)
|
||||||
add_subdirectory(BasicpyToIREEVM)
|
add_subdirectory(BasicpyToIREEVM)
|
||||||
|
@ -20,5 +19,4 @@ add_mlir_library(NPCOMPConversionPasses
|
||||||
NPCOMPBasicpyToSTD
|
NPCOMPBasicpyToSTD
|
||||||
NPCOMPNumpyToTCF
|
NPCOMPNumpyToTCF
|
||||||
NPCOMPTCFToTCP
|
NPCOMPTCFToTCP
|
||||||
NPCOMPTCPToLinalg
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
|
||||||
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
|
#include "npcomp/Conversion/NumpyToTCF/Passes.h"
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Pass registration
|
// Pass registration
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
add_mlir_conversion_library(NPCOMPTCPToLinalg
|
|
||||||
TCPToLinalg.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TCPToLinalg
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
NPCOMPConversionPassIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRPass
|
|
||||||
MLIRTransforms
|
|
||||||
MLIRShape
|
|
||||||
)
|
|
|
@ -1,88 +0,0 @@
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace NPCOMP;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ConvertAdd : public OpRewritePattern<tcp::AddOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(tcp::AddOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
size_t rank = op.getType().cast<RankedTensorType>().getRank();
|
|
||||||
SmallVector<StringRef, 6> iterators(rank, getParallelIteratorTypeName());
|
|
||||||
SmallVector<AffineMap, 3> accesses(/*args in + args out*/ 3,
|
|
||||||
rewriter.getMultiDimIdentityMap(rank));
|
|
||||||
auto genericOp = rewriter.create<linalg::GenericOp>(
|
|
||||||
op.getLoc(), llvm::makeArrayRef({op.getType()}),
|
|
||||||
ValueRange({op.lhs(), op.rhs()}),
|
|
||||||
/*args_in=*/2,
|
|
||||||
/*args_out=*/1,
|
|
||||||
/*indexing_maps=*/accesses,
|
|
||||||
/*iterator_types=*/iterators,
|
|
||||||
/*function_ref=*/nullptr);
|
|
||||||
|
|
||||||
Region ®ion = genericOp.region();
|
|
||||||
Block *block = rewriter.createBlock(®ion, region.begin());
|
|
||||||
for (auto operandType : op.getOperandTypes()) {
|
|
||||||
block->addArgument(operandType.cast<RankedTensorType>().getElementType());
|
|
||||||
}
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
Value bodyValue = rewriter.create<AddFOp>(
|
|
||||||
op.getLoc(), block->getArgument(0), block->getArgument(1));
|
|
||||||
rewriter.create<linalg::YieldOp>(op.getLoc(), bodyValue);
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, genericOp.getResults());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ConvertTCPToLinalg : public ConvertTCPToLinalgBase<ConvertTCPToLinalg> {
|
|
||||||
public:
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
||||||
registry.insert<linalg::LinalgDialect>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
ModuleOp module = getOperation();
|
|
||||||
MLIRContext *context = &getContext();
|
|
||||||
|
|
||||||
ConversionTarget target(*context);
|
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
|
||||||
|
|
||||||
patterns.insert<ConvertAdd>(context);
|
|
||||||
target.addIllegalOp<tcp::AddOp>();
|
|
||||||
|
|
||||||
target.addLegalDialect<linalg::LinalgDialect>();
|
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(module, target, patterns))) {
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
|
||||||
mlir::NPCOMP::createConvertTCPToLinalgPass() {
|
|
||||||
return std::make_unique<ConvertTCPToLinalg>();
|
|
||||||
}
|
|
|
@ -16,61 +16,18 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
static bool isSimpleElementwiseLinalgGeneric(linalg::GenericOp op) {
|
|
||||||
// Only handle generic ops where all operands and results are tensors.
|
|
||||||
if (!llvm::all_of(op.getOperandTypes(),
|
|
||||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.getResultTypes(),
|
|
||||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Loosen restrictions on indexing maps.
|
|
||||||
// This will require more principled handling of shape reification
|
|
||||||
// earlier in the compilation stack, as in general output shapes of a
|
|
||||||
// linalg.generic cannot be inferred easily.
|
|
||||||
// See:
|
|
||||||
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
|
||||||
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
|
||||||
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
|
||||||
})) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
|
||||||
return str.cast<StringAttr>().getValue() ==
|
|
||||||
getParallelIteratorTypeName();
|
|
||||||
})) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Don't just open-code all shape transfer functions here.
|
// TODO: Don't just open-code all shape transfer functions here.
|
||||||
// Note: for now, we can't just rely on an OpInterface, since OpInterfaces
|
|
||||||
// cannot be "externally applied". E.g. we can't change the definition of
|
|
||||||
// linalg::GenericOp.
|
|
||||||
static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
||||||
OpBuilder builder(&op);
|
OpBuilder builder(&op);
|
||||||
if (auto linalgGeneric = dyn_cast<linalg::GenericOp>(op)) {
|
|
||||||
// TODO: Avoid this excessive restriction.
|
|
||||||
// This will require more principled handling of the lowering to
|
|
||||||
// linalg.generic -- it should generally happen after this pass, becaue in
|
|
||||||
// general output shapes of a linalg.generic cannot be inferred easily. See:
|
|
||||||
// https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866
|
|
||||||
if (!isSimpleElementwiseLinalgGeneric(linalgGeneric))
|
|
||||||
return {};
|
|
||||||
// All shapes of all operands and results are the same for now. So
|
|
||||||
// arbitrarily pick the first operand.
|
|
||||||
return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto broadcastTo = dyn_cast<tcp::BroadcastToOp>(op)) {
|
if (auto broadcastTo = dyn_cast<tcp::BroadcastToOp>(op)) {
|
||||||
return {broadcastTo.shape()};
|
return {broadcastTo.shape()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto add = dyn_cast<tcp::AddOp>(op)) {
|
||||||
|
return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
|
||||||
|
}
|
||||||
|
|
||||||
if (auto matmul = dyn_cast<tcp::MatmulOp>(op)) {
|
if (auto matmul = dyn_cast<tcp::MatmulOp>(op)) {
|
||||||
auto lhsRows = builder.create<DimOp>(op.getLoc(), matmul.lhs(), 0);
|
auto lhsRows = builder.create<DimOp>(op.getLoc(), matmul.lhs(), 0);
|
||||||
auto rhsCols = builder.create<DimOp>(op.getLoc(), matmul.rhs(), 1);
|
auto rhsCols = builder.create<DimOp>(op.getLoc(), matmul.rhs(), 1);
|
||||||
|
|
|
@ -27,7 +27,6 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
@ -126,14 +125,6 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
||||||
// TCP does not. So we need to reify the broadcasting and error checking.
|
// TCP does not. So we need to reify the broadcasting and error checking.
|
||||||
pm.addPass(createConvertTCFToTCPPass());
|
pm.addPass(createConvertTCFToTCPPass());
|
||||||
|
|
||||||
// Convert tcp ops to Linalg where possible, as we want generic linalg
|
|
||||||
// tensor->memref to do most of the mechanical work of rewriting ops in
|
|
||||||
// terms of tensors to ops in terms of memrefs (since it is easy on that
|
|
||||||
// representation).
|
|
||||||
// TODO: Does this make sense? Should we instead go to an "TCP on buffers" and
|
|
||||||
// only lower to linalg at the buffer level?
|
|
||||||
pm.addPass(createConvertTCPToLinalgPass());
|
|
||||||
|
|
||||||
// For operations with a shape transfer function, explicitly bypass their
|
// For operations with a shape transfer function, explicitly bypass their
|
||||||
// shape computations with tcp.shaped_results ops.
|
// shape computations with tcp.shaped_results ops.
|
||||||
//
|
//
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
#include "mlir/Pass/PassRegistry.h"
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
|
|
||||||
|
@ -134,57 +133,45 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class LowerLinalgGenericTensorToMemRef
|
class LowerTcpAddOp : public OpConversionPattern<tcp::AddOp> {
|
||||||
: public OpConversionPattern<linalg::GenericOp> {
|
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
|
matchAndRewrite(tcp::AddOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
// TODO: Replace this with more generic code operating on named
|
|
||||||
// structured ops too.
|
|
||||||
|
|
||||||
// These checks mirror those in BypassShapes.
|
|
||||||
if (!llvm::all_of(op.getOperandTypes(),
|
|
||||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "all operands must be tensors");
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.getResultTypes(),
|
|
||||||
[](Type type) { return type.isa<RankedTensorType>(); })) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "all results must be tensors");
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.indexing_maps(), [](Attribute map) {
|
|
||||||
return map.cast<AffineMapAttr>().getValue().isIdentity();
|
|
||||||
})) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "all indexing maps must be identity maps");
|
|
||||||
}
|
|
||||||
if (!llvm::all_of(op.iterator_types(), [](Attribute str) {
|
|
||||||
return str.cast<StringAttr>().getValue() ==
|
|
||||||
getParallelIteratorTypeName();
|
|
||||||
})) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "all iterator types must be 'parallel'");
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 6> memrefs(operands.begin(), operands.end());
|
|
||||||
|
|
||||||
auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
|
auto resultsOrFailure = allocateResults(op, rewriter, op.getLoc());
|
||||||
if (failed(resultsOrFailure))
|
if (failed(resultsOrFailure))
|
||||||
return failure();
|
return failure();
|
||||||
auto results = *resultsOrFailure;
|
auto results = *resultsOrFailure;
|
||||||
memrefs.append(results.begin(), results.end());
|
|
||||||
|
|
||||||
auto newGeneric = rewriter.create<linalg::GenericOp>(
|
SmallVector<Value, 6> args;
|
||||||
op.getLoc(), llvm::None, ValueRange(memrefs), op.getAttrs());
|
args.append(operands.begin(), operands.end());
|
||||||
newGeneric.region().getBlocks().clear();
|
args.append(results.begin(), results.end());
|
||||||
BlockAndValueMapping mapper;
|
|
||||||
op.region().cloneInto(&newGeneric.region(), mapper);
|
size_t rank = op.getType().cast<RankedTensorType>().getRank();
|
||||||
for (auto memref : results) {
|
SmallVector<StringRef, 6> iterators(rank, getParallelIteratorTypeName());
|
||||||
newGeneric.region().front().addArgument(
|
// TODO: Generalize this to other elementwise ops.
|
||||||
memref.getType().cast<MemRefType>().getElementType());
|
// All we need to do is to have a mapping of tcp.foo to scalar.foo.
|
||||||
}
|
// TODO: Should we just use linalg named ops for most of TCP?
|
||||||
|
// Doing so would make tcp very consistent, but also it would, at this early
|
||||||
|
// stage, make most non-trivial changes also require co-design with the
|
||||||
|
// linalg ODS generator, which would be a very slow process.
|
||||||
|
auto argsIn = operands.size();
|
||||||
|
auto argsOut = results.size();
|
||||||
|
SmallVector<AffineMap, 3> accesses(argsIn + argsOut,
|
||||||
|
rewriter.getMultiDimIdentityMap(rank));
|
||||||
|
rewriter.create<linalg::GenericOp>(
|
||||||
|
op.getLoc(), /*resultTypes=*/llvm::None,
|
||||||
|
/*args=*/args,
|
||||||
|
/*args_in=*/argsIn,
|
||||||
|
/*args_out=*/argsOut,
|
||||||
|
/*indexing_maps=*/accesses,
|
||||||
|
/*iterator_types=*/iterators,
|
||||||
|
/*bodyBuilder=*/
|
||||||
|
[](OpBuilder &builder, Location loc, ValueRange regionArgs) {
|
||||||
|
auto add = builder.create<AddFOp>(loc, regionArgs[0], regionArgs[1]);
|
||||||
|
builder.create<linalg::YieldOp>(loc, ValueRange({add}));
|
||||||
|
});
|
||||||
rewriter.replaceOp(op, results);
|
rewriter.replaceOp(op, results);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -295,23 +282,10 @@ class LowerShapedResultsToMemref
|
||||||
target.addLegalOp<tcp::MemrefToTensorOp>();
|
target.addLegalOp<tcp::MemrefToTensorOp>();
|
||||||
target.addLegalOp<tcp::TensorToMemrefOp>();
|
target.addLegalOp<tcp::TensorToMemrefOp>();
|
||||||
|
|
||||||
patterns.insert<LowerLinalgGenericTensorToMemRef>(typeConverter, context);
|
|
||||||
target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
|
|
||||||
if (llvm::any_of(op.getOperandTypes(), [](Type type) {
|
|
||||||
return type.isa<RankedTensorType>();
|
|
||||||
})) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (llvm::any_of(op.getResultTypes(), [](Type type) {
|
|
||||||
return type.isa<RankedTensorType>();
|
|
||||||
})) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
|
|
||||||
patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
|
patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
|
||||||
target.addIllegalOp<tcp::BroadcastToOp>();
|
target.addIllegalOp<tcp::BroadcastToOp>();
|
||||||
|
patterns.insert<LowerTcpAddOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<tcp::AddOp>();
|
||||||
patterns.insert<LowerTcpMatmulOp>(typeConverter, context);
|
patterns.insert<LowerTcpMatmulOp>(typeConverter, context);
|
||||||
target.addIllegalOp<tcp::MatmulOp>();
|
target.addIllegalOp<tcp::MatmulOp>();
|
||||||
|
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
// RUN: npcomp-opt <%s -convert-tcp-to-linalg | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @f
|
|
||||||
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|
||||||
// CHECK: linalg.generic
|
|
||||||
%0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
|
||||||
return %0 : tensor<?xf32>
|
|
||||||
}
|
|
|
@ -1,22 +1,37 @@
|
||||||
// RUN: npcomp-opt -bypass-shapes <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt -bypass-shapes <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
#map0 = affine_map<(d0) -> (d0)>
|
|
||||||
// CHECK-LABEL: func @linalg_generic
|
|
||||||
func @linalg_generic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|
||||||
// This is an elementwise linalg op, so output shape is equal to input shape.
|
|
||||||
// CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0
|
|
||||||
// CHECK: tcp.shaped_results %[[SHAPE]]
|
|
||||||
%0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0, %arg1 {
|
|
||||||
^bb0(%arg2: f32, %arg3: f32):
|
|
||||||
%8 = addf %arg2, %arg3 : f32
|
|
||||||
linalg.yield %8 : f32
|
|
||||||
}: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
|
|
||||||
return %0 : tensor<?xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tcp_broadcast_to
|
// CHECK-LABEL: func @tcp_broadcast_to
|
||||||
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
|
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) {
|
||||||
// CHECK: %0 = tcp.shaped_results %arg1
|
// CHECK: %0 = tcp.shaped_results %arg1
|
||||||
%0 = "tcp.broadcast_to"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
|
%0 = "tcp.broadcast_to"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xindex>) -> tensor<?x?xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tcp_add(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %[[LHS]]
|
||||||
|
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[LHSSHAPE]]
|
||||||
|
// CHECK: return %[[RET:.*]] : tensor<?xf32>
|
||||||
|
// CHECK: }
|
||||||
|
func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tcp_matmul(
|
||||||
|
// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
|
||||||
|
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[LHSCOLS:.*]] = dim %[[LHS]], %[[C0]]
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
|
// CHECK: %[[RHSROWS:.*]] = dim %[[RHS]], %[[C1]]
|
||||||
|
// CHECK: %[[RESULTSHAPE:.*]] = tensor_from_elements %[[LHSCOLS]], %[[RHSROWS]]
|
||||||
|
// CHECK: %[[RET:.*]] = tcp.shaped_results %[[RESULTSHAPE]] {
|
||||||
|
// CHECK: return %[[RET:.*]] : tensor<?x?xf32>
|
||||||
|
// CHECK: }
|
||||||
|
func @tcp_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
|
%0 = tcp.matmul %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
return %0 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -1,26 +1,5 @@
|
||||||
// RUN: npcomp-opt -lower-shaped-results-to-memref <%s -split-input-file | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt -lower-shaped-results-to-memref <%s -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
#map0 = affine_map<(d0) -> (d0)>
|
|
||||||
// CHECK-LABEL: func @linalg_generic
|
|
||||||
func @linalg_generic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xindex>) -> tensor<?xf32> {
|
|
||||||
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
|
||||||
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
|
|
||||||
// CHECK: %[[DST:.*]] = tcp.alloc_memref %arg2 : memref<?xf32>
|
|
||||||
// CHECK: linalg.generic {{.*}} %[[LHS]], %[[RHS]], %[[DST]]
|
|
||||||
// CHECK-NOT: tcp.shaped_results
|
|
||||||
%0 = tcp.shaped_results %arg2 {
|
|
||||||
%0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0, %arg1 {
|
|
||||||
^bb0(%arg3: f32, %arg4: f32):
|
|
||||||
%8 = addf %arg3, %arg4 : f32
|
|
||||||
linalg.yield %8 : f32
|
|
||||||
} : tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
|
|
||||||
tcp.yield %0 : tensor<?xf32>
|
|
||||||
} : tensor<?xindex> -> tensor<?xf32>
|
|
||||||
return %0 : tensor<?xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tcp_broadcast_to
|
// CHECK-LABEL: func @tcp_broadcast_to
|
||||||
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
|
||||||
// Check for two nested loops, but don't look at more detail for now.
|
// Check for two nested loops, but don't look at more detail for now.
|
||||||
|
@ -36,6 +15,31 @@ func @tcp_broadcast_to(%arg0: tensor<?xf32>, %arg1: tensor<?xindex>) -> tensor<?
|
||||||
return %0 : tensor<?x?xf32>
|
return %0 : tensor<?x?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @tcp_add(
|
||||||
|
// CHECK-SAME: %arg0: tensor<?xf32>,
|
||||||
|
// CHECK-SAME: %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
// CHECK: %[[LHSSHAPE:.*]] = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
|
||||||
|
// CHECK: %[[LHS:.*]] = tcp.tensor_to_memref %arg0 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
// CHECK: %[[RHS:.*]] = tcp.tensor_to_memref %arg1 : tensor<?xf32> -> memref<?xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = tcp.alloc_memref %[[LHSSHAPE]] : memref<?xf32>
|
||||||
|
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, {{.*}}} %[[LHS]], %[[RHS]], %[[RESULT]] {
|
||||||
|
// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
|
||||||
|
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_6]], %[[VAL_7]] : f32
|
||||||
|
// CHECK: linalg.yield %[[VAL_9]] : f32
|
||||||
|
// CHECK: }: memref<?xf32>, memref<?xf32>, memref<?xf32>
|
||||||
|
// CHECK: %[[RET:.*]] = tcp.memref_to_tensor %[[RESULT]] : memref<?xf32> -> tensor<?xf32>
|
||||||
|
// CHECK: return %[[RET]] : tensor<?xf32>
|
||||||
|
// CHECK: }
|
||||||
|
func @tcp_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
|
||||||
|
%1 = tcp.shaped_results %0 {
|
||||||
|
%2 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
tcp.yield %2 : tensor<?xf32>
|
||||||
|
} : tensor<?xindex> -> tensor<?xf32>
|
||||||
|
return %1 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @tcp_matmul(
|
// CHECK-LABEL: func @tcp_matmul(
|
||||||
|
|
Loading…
Reference in New Issue