// torch-mlir/lib/Conversion/TCFToTCP/TCFToTCP.cpp
// (Source listing: 185 lines, 6.8 KiB, C++.)

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/TCF/IR/TCFOps.h"
// Blame annotation: 2020-09-09 23:12:52 +08:00
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
/// Builds the type used for shape extent tensors: a rank-1 tensor of `index`
/// elements whose single dimension has a dynamic (unknown) size.
static RankedTensorType getExtentTensorType(Builder &builder) {
  const int64_t extentTensorShape[] = {ShapedType::kDynamicSize};
  Type extentElementType = builder.getIndexType();
  return RankedTensorType::get(extentTensorShape, extentElementType);
}
// Non-templated version of the body of ConvertBinaryElementwise to keep things
// simple.
static LogicalResult
matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
Value lhs = op->getOperand(0);
Value rhs = op->getOperand(1);
Location loc = op->getLoc();
Value result = op->getResult(0);
auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType || !rhsType)
return rewriter.notifyMatchFailure(op, "requires ranked tensors");
Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
// Create the constraints, and the assuming region.
Value witness =
rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
auto assuming = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{result.getType()}, witness);
// Start building the region body.
rewriter.createBlock(&assuming.doRegion());
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
loc, getExtentTensorType(rewriter), lhsShape, rhsShape,
/*error=*/nullptr);
// TODO: It's annoying to do the dynamic broadcast above then
// do the static transfer function here. Would be nice if they could
// somehow be unified.
SmallVector<int64_t, 6> broadcastedStaticShape;
OpTrait::util::getBroadcastedShape(lhsType.getShape(), rhsType.getShape(),
broadcastedStaticShape);
auto resultType =
RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
loc, resultType, lhs, broadcastedShape);
Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
loc, resultType, rhs, broadcastedShape);
Value binaryOpResult;
if (isa<tcf::AddOp>(op)) {
binaryOpResult = rewriter.create<tcp::AddOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
} else if (isa<tcf::MaxOp>(op)) {
binaryOpResult = rewriter.create<tcp::MaxOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
} else if (isa<tcf::MulOp>(op)) {
binaryOpResult = rewriter.create<tcp::MulOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
} else {
op->dump();
llvm::report_fatal_error(
"unhandled op (see dump above): TCF->TCP binary elementwise");
}
rewriter.create<shape::AssumingYieldOp>(loc, binaryOpResult);
Totally rework RefE2E tensor to memref flow. (#42) This now gets the overall "RefE2E" compilation stack to a point that I'm fairly happy with. We simplify it by mostly embracing the "descriptor" view of the world. The overall flow is best understood by reading through the createE2ELoweringPipeline function in lib/E2E/E2E.cpp That function creates a pass pipeline that lowers from "TCF" (which is ~numpy level of abstraction) down to LLVM IR. A brief high-level summary of what happens there: 1. TCF to TCP conversion. This involves reifying error handling in the form of shape constraints. See test/Conversion/TCFToTCP/basic.mlir 2. Lowering shape constraints. This converts shape constraints into eager error-handling code. See test/E2E/lower-shape-constraints.mlir This pass will soon go upstream. Because this lowers to std.assert, some later passes like LowerToNpcomprtABI and LowerToLLVM are updated to properly plumb this through e2e. See test/npcomp-run-mlir/invalid-broadcast.mlir for an execution test that properly aborts in case of an error. 3. Lowering tensors to memrefs. This is done via a series of passes rather than an single mega conversion. Unlike the previous code that mixed in the npcomprt ABI stuff here, it's now a very clean "pure memref" conversion. See test/E2E/lower-*-to-memref.mlir and lib/E2E/TensorToMemref/ Most of the changes are concentrated here. 4. As part of the above, we use the upstream ConvertShapeToStandard for lowering shapes. 5. We lower linalg to loops and lower loops to CFG using upstream passes. 6. Rewrite the "ABI" boundaries of the program to npcomprt data structures (LowerToNpcomprtABI). This mainly affects ABI boundaries and how global tensor constants are represented. One of the major improvements in this commit is that now it's a very clean rewrite that just replaces memrefs on ABI boundaries with !npcomprt.tensor (before there was a get_extent function that is not needed). See test/E2E/lower-to-npcomprt-abi.mlir 7. 
Lower to LLVM with upstream mlir patterns + some patterns for the npcomprt lowerings. One aspect here that is still a remnant of a non-descriptor-based tensor to memref flow is the BypassShapes + LowerShapedResultsToMemref. BypassShapes wraps the "tensor compute" ops in a tcp.shaped_results (basically a "tie_shape" kind of op), and then LowerShapedResultsToMemref uses those annotations to allocate output buffers while lowering the "tensor compute ops". Note that there are very few "tensor compute" ops currently supported (tcp.add + tcp.broadcast_to), so we just hardcode them in both passes. Realistically, I expect this to go away as we fully embrace the descriptor-based approach for simplicity, so don't look too deep into it.
2020-09-17 08:31:40 +08:00
// Finally, replace with the results of the shape.assuming
rewriter.replaceOp(op, assuming.getResults());
return success();
}
// Commit note (2020-09-17 08:31:40 +08:00):
// Totally rework RefE2E tensor to memref flow. (#42)
//
// This now gets the overall "RefE2E" compilation stack to a point that I'm
// fairly happy with. We simplify it by mostly embracing the "descriptor" view
// of the world.
//
// The overall flow is best understood by reading through the
// createE2ELoweringPipeline function in lib/E2E/E2E.cpp. That function
// creates a pass pipeline that lowers from "TCF" (which is ~numpy level of
// abstraction) down to LLVM IR.
//
// A brief high-level summary of what happens there:
//
// 1. TCF to TCP conversion. This involves reifying error handling in the form
//    of shape constraints. See test/Conversion/TCFToTCP/basic.mlir
//
// 2. Lowering shape constraints. This converts shape constraints into eager
//    error-handling code. See test/E2E/lower-shape-constraints.mlir
//    This pass will soon go upstream.
//    Because this lowers to std.assert, some later passes like
//    LowerToNpcomprtABI and LowerToLLVM are updated to properly plumb this
//    through e2e.
//    See test/npcomp-run-mlir/invalid-broadcast.mlir for an execution test
//    that properly aborts in case of an error.
//
// 3. Lowering tensors to memrefs. This is done via a series of passes rather
//    than a single mega conversion. Unlike the previous code that mixed in
//    the npcomprt ABI stuff here, it's now a very clean "pure memref"
//    conversion.
//    See test/E2E/lower-*-to-memref.mlir and lib/E2E/TensorToMemref/
//    Most of the changes are concentrated here.
//
// 4. As part of the above, we use the upstream ConvertShapeToStandard for
//    lowering shapes.
//
// 5. We lower linalg to loops and lower loops to CFG using upstream passes.
//
// 6. Rewrite the "ABI" boundaries of the program to npcomprt data structures
//    (LowerToNpcomprtABI). This mainly affects ABI boundaries and how global
//    tensor constants are represented. One of the major improvements in this
//    commit is that now it's a very clean rewrite that just replaces memrefs
//    on ABI boundaries with !npcomprt.tensor (before there was a get_extent
//    function that is not needed).
//    See test/E2E/lower-to-npcomprt-abi.mlir
//
// 7. Lower to LLVM with upstream mlir patterns + some patterns for the
//    npcomprt lowerings.
//
// One aspect here that is still a remnant of a non-descriptor-based tensor to
// memref flow is the BypassShapes + LowerShapedResultsToMemref. BypassShapes
// wraps the "tensor compute" ops in a tcp.shaped_results (basically a
// "tie_shape" kind of op), and then LowerShapedResultsToMemref uses those
// annotations to allocate output buffers while lowering the "tensor compute
// ops". Note that there are very few "tensor compute" ops currently supported
// (tcp.add + tcp.broadcast_to), so we just hardcode them in both passes.
// Realistically, I expect this to go away as we fully embrace the
// descriptor-based approach for simplicity, so don't look too deep into it.
namespace {
/// Rewrite pattern that lowers the binary elementwise TCF op `SourceOp` to its
/// TCP counterpart. All of the work is done by the shared, non-templated
/// matchAndRewriteBinaryElementwise helper; this template only anchors the
/// pattern on a concrete op type.
template <typename SourceOp>
struct ConvertBinaryElementwise : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Operation *genericOp = op.getOperation();
    return matchAndRewriteBinaryElementwise(genericOp, rewriter);
  }
};
} // namespace
/// Lowers a TCF unary elementwise op (tcf.exp or tcf.tanh) to the matching TCP
/// op, replacing it in place with a new op built from the same operand.
/// Reaching this helper with any other op kind is a programming error: the op
/// is dumped and report_fatal_error aborts.
static LogicalResult
matchAndRewriteUnaryElementwise(Operation *op, PatternRewriter &rewriter) {
  if (isa<tcf::ExpOp>(op)) {
    rewriter.replaceOpWithNewOp<tcp::ExpOp>(op, op->getOperand(0));
    return success();
  }
  if (isa<tcf::TanhOp>(op)) {
    rewriter.replaceOpWithNewOp<tcp::TanhOp>(op, op->getOperand(0));
    return success();
  }
  // Only the op kinds handled above are registered with this helper.
  op->dump();
  llvm::report_fatal_error(
      "unhandled op (see dump above): TCF->TCP unary elementwise");
  return success();
}
namespace {
/// Rewrite pattern that lowers the unary elementwise TCF op `SourceOp` to TCP.
/// A thin, type-anchored wrapper around matchAndRewriteUnaryElementwise.
template <typename SourceOp>
struct ConvertUnaryElementwise : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Operation *genericOp = op.getOperation();
    return matchAndRewriteUnaryElementwise(genericOp, rewriter);
  }
};
} // namespace
namespace {
/// Lowers tcf.matmul to tcp.matmul.
///
/// The contracting dimensions (dim 1 of lhs, dim 0 of rhs) are compared at
/// runtime with std.dim + std.cmpi. The comparison feeds a shape.cstr_require
/// witness, and the tcp.matmul is built inside a shape.assuming region guarded
/// by that witness.
class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tcf::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    auto resultType = op.getType();

    // Constraint: the contracting dimensions must agree.
    Value lhsContractingDim = rewriter.create<DimOp>(loc, lhs, 1);
    Value rhsContractingDim = rewriter.create<DimOp>(loc, rhs, 0);
    Value dimsMatch =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsContractingDim,
                                rhsContractingDim);
    Value witness = rewriter.create<shape::CstrRequireOp>(
        loc, dimsMatch, "mismatching contracting dimension for matmul");

    // Guard the actual computation with that witness.
    auto assuming = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, witness);
    rewriter.createBlock(&assuming.doRegion());
    Value product = rewriter.create<tcp::MatmulOp>(loc, resultType, lhs, rhs);
    rewriter.create<shape::AssumingYieldOp>(loc, product);

    // Swap the original op out for the results of the assuming region.
    rewriter.replaceOp(op, assuming.getResults());
    return success();
  }
};
} // namespace
namespace {
/// Pass converting TCF ops to TCP ops. The patterns it applies reify
/// broadcast-compatibility and matmul contracting-dimension requirements as
/// shape dialect constraints (shape.cstr_* + shape.assuming).
class ConvertTCFToTCP : public ConvertTCFToTCPBase<ConvertTCFToTCP> {
public:
  /// Declares every dialect whose ops the patterns create, so they are loaded
  /// before the pass runs: shape (constraints / assuming regions), TCP (the
  /// conversion targets) and std (dim / cmpi emitted by the matmul lowering).
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<shape::ShapeDialect, tcp::TCPDialect,
                    StandardOpsDialect>();
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = &getContext();
    OwningRewritePatternList patterns;
    // Unary elementwise ops are replaced 1:1 by TCP ops.
    patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
                    ConvertUnaryElementwise<tcf::TanhOp>>(context);
    // Binary elementwise ops get explicit broadcasting plus a
    // broadcastability constraint.
    patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
                    ConvertBinaryElementwise<tcf::MaxOp>,
                    ConvertBinaryElementwise<tcf::MulOp>>(context);
    patterns.insert<ConvertMatmul>(context);
    // The LogicalResult of the greedy driver is intentionally discarded; this
    // pass does not signal failure based on it.
    (void)applyPatternsAndFoldGreedily(module, patterns);
  }
};
} // namespace
/// Creates a new instance of the TCF-to-TCP conversion pass, returned through
/// its OperationPass<ModuleOp> base.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createConvertTCFToTCPPass() {
  using ModulePassPtr = std::unique_ptr<OperationPass<ModuleOp>>;
  return ModulePassPtr(std::make_unique<ConvertTCFToTCP>());
}