torch-mlir/lib/E2E/LowerRankedShapes.cpp

269 lines
11 KiB
C++
Raw Normal View History

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
namespace {
// Rewrites `shape.const_shape` into a `shape.from_extents` op whose operands
// are one `constant <extent> : index` per entry of the constant shape, in
// order.
class LowerConstShapeOp : public OpConversionPattern<shape::ConstShapeOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::ConstShapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value, 6> extentValues;
    for (int64_t extent : op.shape().getValues<int64_t>())
      extentValues.push_back(rewriter.create<ConstantIndexOp>(loc, extent));
    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
        op, rewriter.getType<shape::ShapeType>(), extentValues);
    return success();
  }
};
} // namespace
namespace {
// Given an operand that is either a Shape or an Extent Tensor, returns the
// Extent Tensor, or a null Value if it cannot be determined locally.
// Only 1-D RankedTensorType operands are returned. When `requireKnownRank` is
// true, the tensor must additionally have a static shape, i.e. a statically
// known number of extents. Operands of any other type yield a null Value.
Value findExtentsFromShape(Value operand, bool requireKnownRank) {
  auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
  if (!tensorType || tensorType.getRank() != 1)
    return nullptr;
  if (requireKnownRank && !tensorType.hasStaticShape())
    return nullptr;
  return operand;
}

// Lowers `shape.broadcast` of two extent tensors with statically known ranks
// into straight-line standard ops that compute each result extent. It also
// emits `npcomprt.abort_if` checks that enforce numpy broadcasting legality.
// The result is a `shape.from_extents` of the computed extents.
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::BroadcastOp::Adaptor adaptor(operands);
    Location loc = op.getLoc();
    // When the ranks are statically known, generate non-branchy code.
    // TODO: Generate rank-generic code.
    Value lhsExtents = findExtentsFromShape(adaptor.lhs(), true);
    Value rhsExtents = findExtentsFromShape(adaptor.rhs(), true);
    if (!lhsExtents || !rhsExtents)
      return rewriter.notifyMatchFailure(op, "dynamic extents not supported");

    // Establish invariant that rank(lhs) >= rank(rhs).
    int64_t lhsSize =
        lhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
    int64_t rhsSize =
        rhsExtents.getType().cast<RankedTensorType>().getDimSize(0);
    if (lhsSize < rhsSize) {
      std::swap(lhsExtents, rhsExtents);
      std::swap(lhsSize, rhsSize);
    }
    int64_t rankDiscrepancy = lhsSize - rhsSize;

    // Helper that creates IR
    // ```
    // abort_if(extent != resultExtent && extent != 1)
    // ```
    // This is the numpy broadcasting legality check for one input extent.
    auto createAbortIfIllegalBroadcastExtent =
        [&](Value extent, Value resultExtent, Value c1) {
          Value extentNeResult = rewriter.create<CmpIOp>(
              loc, CmpIPredicate::ne, extent, resultExtent);
          Value extentNeOne =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, extent, c1);
          Value bothTrue =
              rewriter.create<AndOp>(loc, extentNeResult, extentNeOne);
          // TODO: Should there be a more generic error-handling dialect?
          // It seems a bit awkward to hardcode npcomprt here.
          rewriter.create<npcomprt::AbortIfOp>(loc, bothTrue);
        };

    SmallVector<Value, 6> resultExtents;
    for (int64_t i = 0; i < lhsSize; ++i) {
      Value lhsDim = rewriter.create<ConstantIndexOp>(loc, i);
      Value lhsExtent = rewriter.create<ExtractElementOp>(
          loc, lhsExtents, ValueRange{lhsDim});
      if (i < rankDiscrepancy) {
        // Padded extent: the lower-rank side implicitly contributes a 1 here,
        // so the lhs extent is the result unconditionally.
        resultExtents.push_back(lhsExtent);
        continue;
      }
      // Non-padded extent.
      Value rhsDim = rewriter.create<ConstantIndexOp>(loc, i - rankDiscrepancy);
      Value rhsExtent = rewriter.create<ExtractElementOp>(
          loc, rhsExtents, ValueRange{rhsDim});
      // result = (lhs == 1) ? rhs : lhs.
      // This is deliberately not max(lhs, rhs). Numpy broadcasts a 0 extent
      // against a 1 extent to 0; max would pick 1 and then abort on the
      // perfectly legal 0 extent.
      Value c1 = rewriter.create<ConstantIndexOp>(loc, 1);
      Value lhsIsOne =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsExtent, c1);
      Value resultExtent =
          rewriter.create<SelectOp>(loc, lhsIsOne, rhsExtent, lhsExtent);
      // With this selection, lhs always satisfies the legality check: either
      // lhs == 1, or lhs == result. So only rhs needs a runtime check. It
      // aborts exactly when lhs != 1, rhs != 1 and rhs != lhs.
      createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent, c1);
      resultExtents.push_back(resultExtent);
    }
    // TODO: Remove the return type once ODS is fixed to do proper inference.
    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
        op, shape::ShapeType::get(rewriter.getContext()), resultExtents);
    return success();
  }
};
} // namespace
namespace {
// Lowers `shape.to_extent_tensor`.
//  - If the (converted) input has `!shape.shape` type, it must be produced by a
//    `shape.from_extents` op. The op is then replaced by a
//    `tensor_from_elements` of those extents.
//  - Otherwise the input is assumed to already be an extent tensor and is
//    forwarded unchanged.
class LowerShapeToExtentTensorOp
    : public OpConversionPattern<shape::ToExtentTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::ToExtentTensorOp::Adaptor adaptor(operands);
    Value input = adaptor.input();
    if (!input.getType().isa<shape::ShapeType>()) {
      // Assume that it is already an extent tensor.
      // TODO: Since these ops are all multi-type, there should be a utility
      // for switching on the allowable types instead of just assuming that it
      // is an extent tensor.
      rewriter.replaceOp(op, input);
      return success();
    }
    // Convert by matching to a producing FromExtentsOp.
    auto producer = input.getDefiningOp<shape::FromExtentsOp>();
    if (!producer)
      return rewriter.notifyMatchFailure(op, "not a from_extents op");
    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op, producer.extents());
    return success();
  }
};

// Lowers `shape.get_extent` to an `extract_element` that reads the (converted)
// `shape` operand at the (converted) `dim` operand.
class LowerShapeGetExtentOp : public OpConversionPattern<shape::GetExtentOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::GetExtentOp::Adaptor adaptor(operands);
    Value extents = adaptor.shape();
    Value position = adaptor.dim();
    rewriter.replaceOpWithNewOp<ExtractElementOp>(op, extents, position);
    return success();
  }
};
} // namespace
namespace {
// Once ranked shapes are lowered, the eager error-handling code has been
// materialized (e.g. the `npcomprt.abort_if` checks emitted for
// `shape.broadcast` in this file). At that point `tcp.shape_observe_error`
// ops are no longer needed, so this pattern simply erases them.
struct EraseShapeObserveErrorOp
    : OpConversionPattern<tcp::ShapeObserveErrorOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> /*operands*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
// Basic invariant of this pass:
// Every `shape.from_extents` op operating on an extent tensor
// (`tensor<?xindex>`) is replaced by corresponding standard ops and folded
// away (for the ranked case, it should be possible to eliminate these).
//
// We expect that previous passes have inserted a "root" set of
// shape::FromExtentsOp's that allow this process to get started.
//
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
// These patterns have to be "conversion patterns" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
//
// This pass depends heavily on ranked shapes, since only ranked shapes can
// be statically expanded to a fixed set of SSA extents.
//
// TODO: This approach doesn't naively work with control flow.
// In the presence of non-cyclic control flow, we can just generalize the
// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
// look through block arguments and rewrite "phi of shapes -> phi of extents".
// In the presence of cyclic control flow, we need to somehow resolve the
// ranks of use-def cycles ahead of time or optimistically assume that
// backedges will match the rank of forward edges, and somehow be robust
// when that assumption fails.
//
// TODO: Add in a fold of
// `extract_element(tensor_from_elements(x0, x1, ...), n) -> xn` to restore
// the above invariant without relying on a subsequent canonicalization
// step.
namespace {
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<LowerConstShapeOp>(context);
patterns.insert<LowerShapeBroadcastOp>(context);
patterns.insert<LowerShapeGetExtentOp>(context);
patterns.insert<LowerShapeToExtentTensorOp>(context);
patterns.insert<EraseShapeObserveErrorOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<shape::ShapeOfOp>();
target.addIllegalOp<shape::BroadcastOp>();
target.addIllegalOp<shape::GetExtentOp>();
target.addLegalOp<shape::FromExtentsOp>();
target.addIllegalOp<shape::ToExtentTensorOp>();
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
target.addLegalOp<npcomprt::AbortIfOp>();
target.addLegalDialect<StandardOpsDialect>();
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
// Erase some stray shape ops from the program. They can't be
// deleted during conversion because they become unused only after
// subsequent patterns bypass them.
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
auto walkResult = func.walk([](Operation *op) {
if (!isa<shape::FromExtentsOp>(op))
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
return WalkResult::advance();
if (op->use_empty()) {
op->erase();
} else {
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
op->emitError("could not be eliminated");
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerRankedShapesPass() {
  // Build a fresh pass instance; ownership is handed to the caller.
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LowerRankedShapes>();
  return pass;
}