torch-mlir/lib/E2E/LowerShapeConstraints.cpp

190 lines
8.4 KiB
C++
Raw Normal View History

Totally rework RefE2E tensor to memref flow. (#42) This now gets the overall "RefE2E" compilation stack to a point that I'm fairly happy with. We simplify it by mostly embracing the "descriptor" view of the world. The overall flow is best understood by reading through the createE2ELoweringPipeline function in lib/E2E/E2E.cpp That function creates a pass pipeline that lowers from "TCF" (which is ~numpy level of abstraction) down to LLVM IR. A brief high-level summary of what happens there: 1. TCF to TCP conversion. This involves reifying error handling in the form of shape constraints. See test/Conversion/TCFToTCP/basic.mlir 2. Lowering shape constraints. This converts shape constraints into eager error-handling code. See test/E2E/lower-shape-constraints.mlir This pass will soon go upstream. Because this lowers to std.assert, some later passes like LowerToNpcomprtABI and LowerToLLVM are updated to properly plumb this through e2e. See test/npcomp-run-mlir/invalid-broadcast.mlir for an execution test that properly aborts in case of an error. 3. Lowering tensors to memrefs. This is done via a series of passes rather than an single mega conversion. Unlike the previous code that mixed in the npcomprt ABI stuff here, it's now a very clean "pure memref" conversion. See test/E2E/lower-*-to-memref.mlir and lib/E2E/TensorToMemref/ Most of the changes are concentrated here. 4. As part of the above, we use the upstream ConvertShapeToStandard for lowering shapes. 5. We lower linalg to loops and lower loops to CFG using upstream passes. 6. Rewrite the "ABI" boundaries of the program to npcomprt data structures (LowerToNpcomprtABI). This mainly affects ABI boundaries and how global tensor constants are represented. One of the major improvements in this commit is that now it's a very clean rewrite that just replaces memrefs on ABI boundaries with !npcomprt.tensor (before there was a get_extent function that is not needed). See test/E2E/lower-to-npcomprt-abi.mlir 7. Lower to LLVM with upstream mlir patterns + some patterns for the npcomprt lowerings. One aspect here that is still a remnant of a non-descriptor-based tensor to memref flow is the BypassShapes + LowerShapedResultsToMemref. BypassShapes wraps the "tensor compute" ops in a tcp.shaped_results (basically a "tie_shape" kind of op), and then LowerShapedResultsToMemref uses those annotations to allocate output buffers while lowering the "tensor compute ops". Note that there are very few "tensor compute" ops currently supported (tcp.add + tcp.broadcast_to), so we just hardcode them in both passes. Realistically, I expect this to go away as we fully embrace the descriptor-based approach for simplicity, so don't look too deep into it.
2020-09-17 08:31:40 +08:00
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
using namespace mlir;
using namespace mlir::NPCOMP;
namespace {
class LowerCstrBroadcastableOp
: public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter &rewriter) const override {
// A shape.cstr_* op should be the result of lowering a !shape.shape; it
// should not itself ever consume or produce a !shape.shape.
//
// There is no way to "sink" a !shape.shape type, because one cannot inspect
// if it is an error. The only way to use it safely is to lower the op that
// produced the value to a set of constraints and then use the witness to
// guard a shape.assuming.
//
// Consider for example what we do when lowering TCF to TCP: we need to do a
// shape calculation for the broadcasting. But we create the
// shape.cstr_broadcastable and use its witness to guard a `shape.assuming {
// ... shape.broadcast ...}`. There's never any need to create a
// !shape.shape.
//
// The use of !shape.shape should be restricted to contexts like
// declarations of shape transfer functions, with automatic utilities to
// lower !shape.shape types to corresponding constraints + shape.assuming +
// tensors. In this (npcomp e2e) lowering flow, we don't have any such
// "declarative shape transfer functions" or utilities to expand them to
// constraints. So !shape.shape should never exist in our IR.
//
// Historically, we used !shape.shape type for everything, and
// shape.to_extent_tensor would abort in case of an error. But that's not a
// useful semantics for lowering, since the error is defined to happen as
// part of the shape.to_extent_tensor op, which requires materializing an
// "is error" bit in the IR and carrying it around everywhere that the
// original !shape.shape value was being used. In practice, nobody respects
// that, which opens us up to miscompilations. That is, the lowering
// strategy is either "not emit errors at all" or "emit errors as part of
// lowering e.g. the shape.broadcast op itself" (which technically puts the
// errors in some random location in the IR that is not the
// shape.to_extent_tensor op). E.g. the following code would miscompile with
// either of those ways that these ops get lowered in practice:
// ```
// %shape = shape.broadcast %lhs, %rhs : !shape.shape
// if %cond:
// shape.to_extent_tensor(%shape)
// ```
// It's not possible to correctly compile this code without significant
// contortions (such as carrying an "is error" bit). And to boot, we
// shouldn't be getting into that situation in the first place! But the
// `shape.to_extent_tensor : !shape.shape -> tensor<?xindex>` abstraction
// opens up that possibility.
//
// shape.to_extent_tensor should not really be a thing, since it creates
// these ill-defined situations about where errors are observed. A
// !shape.shape type should only exist (for this compilation flow) as part
// of a utility, something like "I want to do this shape calculation on
// !shape.shape type, create IR that uses tensor<?xindex> and witnesses to
// implement it, on the assumption that the error can be
// observed anywhere inside the shape calculation".
//
// !shape.shape type would still be useful for lowerings that actually
// result in a runtime type that carries an "is error" bit inside it, though
// TBD if such use cases arise.
if (op.getType().isa<shape::ShapeType>() ||
op.lhs().getType().isa<shape::ShapeType>() ||
op.rhs().getType().isa<shape::ShapeType>()) {
return op.emitError() << "Error shapes should not exist at this point";
}
auto loc = op.getLoc();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
// Find smaller and greater rank and extent tensor.
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
Value lhsSmaller =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
Type extentTensorTy = op.lhs().getType();
auto ifOp = rewriter.create<scf::IfOp>(
loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
lhsSmaller,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
},
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
});
Value lesserRank = ifOp.getResult(0);
Value lesserRankOperand = ifOp.getResult(1);
Value greaterRank = ifOp.getResult(2);
Value greaterRankOperand = ifOp.getResult(3);
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
// Compare the shapes extent by extent, and emit errors for
// non-broadcast-compatible shapes.
// Two extents are broadcast-compatible if
// 1. they are both equal, or
// 2. at least one of them is 1.
rewriter.create<scf::ForOp>(
loc, rankDiff, greaterRank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
loc, greaterRankOperand, ValueRange{iv});
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
Value extentsAgree =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
lesserRankOperandExtent);
auto broadcastIsValid =
b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
lesserRankOperandExtentIsOne));
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
b.create<scf::YieldOp>(loc);
});
// Now that we have emitted all the assertions, the witness is trivially
// satisfied.
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
};
} // namespace
namespace {
// This pass eliminates shape constraints from the program.
//
// After this pass finishes, there are no !shape.witness types in the program,
// no shape.assuming, no shape.cstr_*.
//
// TODO: This should move to upstream ShapeToStandard conversions.
class LowerShapeConstraints
: public LowerShapeConstraintsBase<LowerShapeConstraints> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<LowerCstrBroadcastableOp>(context);
// Add in the canonicalization patterns for shape.assuming so that it gets
// inlined when its witness becomes a true constant witness.
shape::AssumingOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
return signalPassFailure();
// TODO: Check that there are no remaining !shape.witness, shape.assuming,
// shape.cstr_* ops, etc.
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerShapeConstraintsPass() {
return std::make_unique<LowerShapeConstraints>();
}