torch-mlir/lib/E2E/LowerRankedShapes.cpp

216 lines
8.7 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
namespace {
// Rewrites `shape.const_shape` into a `shape.from_extents` op whose operands
// are one index constant per element of the constant shape attribute.
class LowerConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Materialize every static extent, in order, as its own SSA constant.
    llvm::SmallVector<Value, 6> extentValues;
    for (int64_t staticExtent : op.shape().getValues<int64_t>()) {
      Value extentValue = rewriter.create<ConstantIndexOp>(loc, staticExtent);
      extentValues.push_back(extentValue);
    }

    Type shapeType = shape::ShapeType::get(rewriter.getContext());
    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, shapeType,
                                                      extentValues);
    return success();
  }
};
} // namespace
namespace {
// Rewrites `shape.broadcast` of two shapes that are both defined by
// `shape.from_extents` into a new `shape.from_extents` holding the broadcasted
// extents, together with runtime checks that abort on an illegal broadcast.
//
// The match fails (without modifying IR) if either converted operand is not
// produced by a `shape.from_extents` op.
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::BroadcastOp::Adaptor adaptor(operands);
    auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
    auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
    if (!lhs || !rhs)
      return rewriter.notifyMatchFailure(op, "operands not converted");

    // Establish invariant that rank(lhs) >= rank(rhs).
    if (lhs.extents().size() < rhs.extents().size())
      std::swap(lhs, rhs);
    auto rankDiscrepancy = lhs.extents().size() - rhs.extents().size();

    Location loc = op.getLoc();

    // The constant 1 is shared by all per-dimension checks. It is only
    // materialized when there is at least one dimension to check, so that no
    // dead constant is left behind for a rank-0 `rhs`.
    Value c1;
    if (!rhs.extents().empty())
      c1 = rewriter.create<ConstantIndexOp>(loc, 1);

    // The leading `rankDiscrepancy` extents of `lhs` have no counterpart in
    // `rhs` and pass through unchanged; the trailing ones are overwritten
    // below.
    auto resultExtents = llvm::to_vector<6>(lhs.extents());
    for (size_t i = 0, e = rhs.extents().size(); i != e; ++i) {
      Value lhsExtent = lhs.extents()[rankDiscrepancy + i];
      Value rhsExtent = rhs.extents()[i];

      // Result extent per the numpy rule: an extent of 1 stretches to match
      // the other extent; otherwise the extent is kept as is.
      //
      // This is deliberately not `max(lhs, rhs)`: broadcasting a zero extent
      // against 1 must yield 0, whereas a max would yield 1 and make the
      // legality check below reject a legal broadcast.
      Value lhsIsOne =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsExtent, c1);
      resultExtents[rankDiscrepancy + i] =
          rewriter.create<SelectOp>(loc, lhsIsOne, rhsExtent, lhsExtent);

      // Creates IR
      // ```
      // abort_if(lhsExtent != rhsExtent && lhsExtent != 1 && rhsExtent != 1)
      // ```
      // This is the numpy broadcasting legality check.
      Value extentsDiffer =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, lhsExtent, rhsExtent);
      Value lhsNeOne =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, lhsExtent, c1);
      Value rhsNeOne =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, rhsExtent, c1);
      Value neitherIsOne = rewriter.create<AndOp>(loc, lhsNeOne, rhsNeOne);
      Value isIllegal = rewriter.create<AndOp>(loc, extentsDiffer, neitherIsOne);
      // TODO: Should there be a more generic error-handling dialect?
      // It seems a bit awkward to hardcode npcomprt here.
      rewriter.create<npcomprt::AbortIfOp>(loc, isIllegal);
    }

    // TODO: Remove the return type once ODS is fixed to do proper inference.
    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
        op, shape::ShapeType::get(rewriter.getContext()), resultExtents);
    return success();
  }
};
} // namespace
// Rewrite `get_extent(from_extents(x1,x2,x3), N) -> xN`
//
// TODO: this should be a fold on tcp::GetExtentOp.
// (though then the contract of this pass depends on that set of folds,
// which isn't great)
//
// Also, we use OpConversionPattern to get post-rewrite operands as above.
namespace {
// Resolves `tcp.get_extent` on a shape defined by `shape.from_extents` by
// forwarding the `dim`-th operand of that `shape.from_extents` op.
//
// The match fails (without modifying IR) when the converted shape operand is
// not defined by `shape.from_extents`, or when `dim` does not index one of its
// extents.
class LowerShapeGetExtentOp : public OpConversionPattern<tcp::GetExtentOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    tcp::GetExtentOp::Adaptor adaptor(operands);
    auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
    if (!fromExtents)
      return rewriter.notifyMatchFailure(op, "not a from_extents op");

    ValueRange extents(fromExtents.extents());
    // `getLimitedValue` yields an unsigned value, so a negative `dim`
    // attribute shows up as a very large index and is rejected by the same
    // bounds check instead of indexing past the end of the extents.
    uint64_t dim = op.dim().getLimitedValue();
    if (dim >= extents.size())
      return rewriter.notifyMatchFailure(op, "dim is out of bounds for shape");

    rewriter.replaceOp(op, extents[dim]);
    return success();
  }
};
} // namespace
namespace {
// Lowering ranked shapes reifies the eager error-handling code, so any
// tcp::ShapeObserveErrorOp left in the program has served its purpose.
// This pattern unconditionally erases such ops.
struct EraseShapeObserveErrorOp
    : OpConversionPattern<tcp::ShapeObserveErrorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The op is dropped outright; no replacement values are provided.
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
// Basic invariant of this pass:
// Every def of a !shape.shape type is replaced with a
// `shape.from_extents` op.
// When converting an op, look for the `shape.from_extents` op that
// defined all operands, then do a computation on the extents (i.e.
// operands to the `shape.from_extents` op) and produce a
// `shape.from_extents` op.
//
// We expect that previous passes have inserted a "root" set of
// shape::FromExtentsOp's that allow this process to get started.
//
// We then use this to resolve get_extent ops by using a rewrite
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
// maximally many places due to the above invariant.
//
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
// These patterns have to be "conversion patterns" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
//
// This pass depends heavily on ranked shapes, since only ranked shapes can
// be statically expanded to a fixed set of SSA extents.
//
// TODO: This approach doesn't naively work with control flow.
// In the presence of non-cyclic control flow, we can just generalize the
// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
// look through block arguments and rewrite "phi of shapes -> phi of extents".
// In the presence of cyclic control flow, we need to somehow resolve the
// ranks of use-def cycles ahead of time or optimistically assume that
// backedges will match the rank of forward edges, and somehow be robust
// when that assumption fails.
namespace {
// Pass driver: runs the ranked-shape lowering patterns as a partial dialect
// conversion over the operation, then removes the `shape.from_extents` ops
// that the conversion left behind.
//
// The pass fails if the partial conversion fails, or if any
// `shape.from_extents` op still has uses once the conversion has finished.
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
  void runOnOperation() override {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    patterns.insert<LowerConstShapeOp>(context);
    patterns.insert<LowerShapeBroadcastOp>(context);
    patterns.insert<LowerShapeGetExtentOp>(context);
    patterns.insert<EraseShapeObserveErrorOp>(context);

    ConversionTarget target(*context);
    // Ops that must be gone once this pass has run.
    target.addIllegalOp<shape::ShapeOfOp>();
    target.addIllegalOp<shape::BroadcastOp>();
    target.addIllegalOp<tcp::GetExtentOp>();
    target.addIllegalOp<tcp::ShapeObserveErrorOp>();
    // Ops that the patterns are allowed to produce.
    target.addLegalOp<shape::FromExtentsOp>();
    target.addLegalOp<npcomprt::AbortIfOp>();
    target.addLegalDialect<StandardOpsDialect>();

    if (failed(applyPartialConversion(func, target, patterns))) {
      return signalPassFailure();
    }

    // Erase some stray shape ops from the program. They can't be
    // deleted during conversion because they become unused only after
    // subsequent patterns bypass them.
    auto walkResult = func.walk([](Operation *op) {
      if (!isa<shape::FromExtentsOp>(op))
        return WalkResult::advance();
      // A remaining use means some consumer of the shape was not lowered.
      if (!op->use_empty()) {
        op->emitError("could not be eliminated");
        return WalkResult::interrupt();
      }
      op->erase();
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return signalPassFailure();
  }
};
} // namespace
// Creates a new instance of the LowerRankedShapes pass and hands ownership
// to the caller.
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerRankedShapesPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LowerRankedShapes>();
  return pass;
}