// Mirror of https://github.com/llvm/torch-mlir (original listing: 269 lines,
// 11 KiB, C++).
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
#include "npcomp/E2E/E2E.h"
|
|
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::NPCOMP;
|
|
|
|
namespace {
|
|
class LowerConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(shape::ConstShapeOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto extents = llvm::to_vector<6>(llvm::map_range(
|
|
op.shape().getValues<int64_t>(), [&](int64_t extent) -> Value {
|
|
return rewriter.create<ConstantIndexOp>(op.getLoc(), extent);
|
|
}));
|
|
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
|
|
op, rewriter.getType<shape::ShapeType>(), extents);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
|
|
// Given an operand that is either a Shape or Extent Tensor, returns an
|
|
// Extent Tensor or nullptr if this cannot be locally determined.
|
|
// The return value, if !nullptr, will be a 1D RankedTensorType (with possibly
|
|
// unknown element).
|
|
Value findExtentsFromShape(Value operand, bool requireKnownRank) {
|
|
if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
|
|
if (tensorType.getRank() == 1 &&
|
|
(!requireKnownRank || tensorType.hasStaticShape())) {
|
|
return operand;
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
shape::BroadcastOp::Adaptor adaptor(operands);
|
|
// When the ranks are statically known, generate non-branchy code.
|
|
// TODO: Generate rank-generic code.
|
|
auto lhsExtents = findExtentsFromShape(adaptor.lhs(), true);
|
|
auto rhsExtents = findExtentsFromShape(adaptor.rhs(), true);
|
|
if (!lhsExtents || !rhsExtents)
|
|
return rewriter.notifyMatchFailure(op, "dynamic extents not supported");
|
|
|
|
// Establish invariant that rank(lhs) >= rank(rhs).
|
|
auto lhsSize = lhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
|
|
auto rhsSize = rhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
|
|
if (lhsSize < rhsSize) {
|
|
std::swap(lhsExtents, rhsExtents);
|
|
std::swap(lhsSize, rhsSize);
|
|
}
|
|
auto rankDiscrepancy = lhsSize - rhsSize;
|
|
|
|
// Helper that creates IR
|
|
// ```
|
|
// abort_if(extent != resultExtent && extent != 1)
|
|
// ```
|
|
// This is the numpy broadcasting legality check.
|
|
auto createAbortIfIllegalBroadcastExtent = [&](Value extent,
|
|
Value resultExtent) {
|
|
auto c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
|
|
auto extentNeMax = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne,
|
|
extent, resultExtent);
|
|
auto extentNeOne =
|
|
rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, extent, c1);
|
|
auto bothTrue =
|
|
rewriter.create<AndOp>(op.getLoc(), extentNeMax, extentNeOne);
|
|
// TODO: Should there be a more generic error-handling dialect?
|
|
// It seems a bit awkward to hardcode npcomprt here.
|
|
rewriter.create<npcomprt::AbortIfOp>(op.getLoc(), bothTrue);
|
|
};
|
|
|
|
SmallVector<Value, 6> resultExtents;
|
|
for (int i = 0, e = lhsSize; i < e; i++) {
|
|
auto lhsDim = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
|
auto lhsExtent = rewriter.create<ExtractElementOp>(
|
|
op.getLoc(), lhsExtents, ValueRange{lhsDim});
|
|
if (i < rankDiscrepancy) {
|
|
// Padded extent.
|
|
resultExtents.push_back(lhsExtent);
|
|
continue;
|
|
}
|
|
|
|
// Non-padded extent.
|
|
auto rhsDim =
|
|
rewriter.create<ConstantIndexOp>(op.getLoc(), i - rankDiscrepancy);
|
|
auto rhsExtent = rewriter.create<ExtractElementOp>(
|
|
op.getLoc(), rhsExtents, ValueRange{rhsDim});
|
|
auto ugt = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ugt,
|
|
lhsExtent, rhsExtent);
|
|
auto resultExtent =
|
|
rewriter.create<SelectOp>(op.getLoc(), ugt, lhsExtent, rhsExtent);
|
|
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
|
|
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
|
|
resultExtents.push_back(resultExtent);
|
|
}
|
|
|
|
// TODO: Remove the return type once ODS is fixed to do proper inference.
|
|
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
|
|
op, shape::ShapeType::get(rewriter.getContext()), resultExtents);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
class LowerShapeToExtentTensorOp
|
|
: public OpConversionPattern<shape::ToExtentTensorOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
shape::ToExtentTensorOpAdaptor adaptor(operands);
|
|
if (adaptor.input().getType().isa<shape::ShapeType>()) {
|
|
// Convert by matching to a producing FromExtentsOp.
|
|
auto fromExtents = adaptor.input().getDefiningOp<shape::FromExtentsOp>();
|
|
if (!fromExtents) {
|
|
return rewriter.notifyMatchFailure(op, "not a from_extents op");
|
|
}
|
|
rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op,
|
|
fromExtents.extents());
|
|
return success();
|
|
}
|
|
|
|
// Assume that it is already an extent tensor.
|
|
// TODO: Since these ops are all multi-type, there should be a utility
|
|
// for switching on the allowable types instead of just assuming that it
|
|
// is an extent tensor.
|
|
rewriter.replaceOp(op, adaptor.input());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class LowerShapeGetExtentOp : public OpConversionPattern<shape::GetExtentOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(shape::GetExtentOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
shape::GetExtentOp::Adaptor adaptor(operands);
|
|
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, adaptor.shape(),
|
|
adaptor.dim());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
// Now that we have lowered ranked shapes, which reifies the eager
|
|
// error-handling code, the tcp::ShapeObserveErrorOp's are no longer
|
|
// needed.
|
|
class EraseShapeObserveErrorOp
|
|
: public OpConversionPattern<tcp::ShapeObserveErrorOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
// Basic invariant of this pass:
// Every `shape.from_extents` op operating on an extent tensor
// (`tensor<?xindex>`) is replaced by corresponding standard ops and folded
// away (for the ranked case, it should be possible to eliminate these).
//
// We expect that previous passes have inserted a "root" set of
// `shape::FromExtentsOp` ops that allow this process to get started.
//
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
// These patterns have to be "conversion patterns" because the `operands`
// argument gives access to the post-conversion operands from earlier ops.
//
// This pass depends heavily on ranked shapes, since only ranked shapes can
// be statically expanded to a fixed set of SSA extents.
//
// TODO: This approach does not work naively with control flow.
// In the presence of non-cyclic control flow, we can just generalize the
// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
// look through block arguments and rewrite "phi of shapes -> phi of extents".
// In the presence of cyclic control flow, we need to somehow resolve the
// ranks of use-def cycles ahead of time, or optimistically assume that
// backedges will match the rank of forward edges and somehow be robust
// when that assumption fails.
//
// TODO: Add a fold of
// `extract_element(tensor_from_elements(x0, x1, ...), n) -> xn` to restore
// the above invariant without relying on a subsequent canonicalization
// step.
|
|
namespace {
|
|
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
|
void runOnOperation() {
|
|
auto func = getOperation();
|
|
auto *context = &getContext();
|
|
|
|
OwningRewritePatternList patterns;
|
|
patterns.insert<LowerConstShapeOp>(context);
|
|
patterns.insert<LowerShapeBroadcastOp>(context);
|
|
patterns.insert<LowerShapeGetExtentOp>(context);
|
|
patterns.insert<LowerShapeToExtentTensorOp>(context);
|
|
patterns.insert<EraseShapeObserveErrorOp>(context);
|
|
ConversionTarget target(*context);
|
|
target.addIllegalOp<shape::ShapeOfOp>();
|
|
target.addIllegalOp<shape::BroadcastOp>();
|
|
target.addIllegalOp<shape::GetExtentOp>();
|
|
target.addLegalOp<shape::FromExtentsOp>();
|
|
target.addIllegalOp<shape::ToExtentTensorOp>();
|
|
target.addLegalOp<npcomprt::AbortIfOp>();
|
|
target.addLegalDialect<StandardOpsDialect>();
|
|
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
|
|
if (failed(applyPartialConversion(func, target, patterns))) {
|
|
return signalPassFailure();
|
|
}
|
|
|
|
// Erase some stray shape ops from the program. They can't be
|
|
// deleted during conversion because they become unused only after
|
|
// subsequent patterns bypass them.
|
|
auto walkResult = func.walk([](Operation *op) {
|
|
if (!isa<shape::FromExtentsOp>(op))
|
|
return WalkResult::advance();
|
|
if (op->use_empty()) {
|
|
op->erase();
|
|
} else {
|
|
op->emitError("could not be eliminated");
|
|
return WalkResult::interrupt();
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
if (walkResult.wasInterrupted())
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Factory for the LowerRankedShapes function pass.
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerRankedShapesPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LowerRankedShapes>();
  return pass;
}
|