torch-mlir/lib/E2E/LowerRankedShapes.cpp

197 lines
8.0 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/E2E/E2E.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
Rework e2e flow to use new "npcomprt" This ~totally reworks the existing "runtime" stuff to be more principled and usable, such as from Python. It's still not fully production-quality, mainly in the department of memory management (e.g. it currently leaks memory; we need to figure out "who frees memrefs" + the analysis and transformation needed to do that (maybe use upstream buffer allocation pass?)). The user API is in include/npcomp/runtime/UserAPI.h, though include/npcomp/JITRuntime/JITModule.h is a friendlier wrapper. The stuff under {include,lib}/runtime is totally firewalled from the compiler and tiny (<6kB, though no attention has gone into optimizing that size). For example, we don't link in libSupport into the runtime, instead having our own bare bones replacements for basics like ArrayRef (the JITRuntime helps with bridging that gap, since it *can* depend on all common LLVM utilities). The overall features of npcomprt is that it exposes a module that with multiple function entry points. Each function has arguments and results that are tensor-valued, and npcomprt::Tensor is the runtime type that is used to interact with that (and a npcomprt::Ref<T> reference-counting wrapper is provided to wrap npcomprt::Tensor in the common case). From an implementation perspective, an npcomprt module at the LLVM/object/binary level exposes a single module descriptor struct that has pointers to other metadata (currently just a list of function metadata descriptors). All interactions with the npcomp runtime are keyed off of that module descriptor, including function lookups and dispatching. This is done to dodge platform ABI issues and also allow enough reflection to e.g. verify provided arguments. Most of the compiler-side work here was in LowerToNpcomprtABI and LowerToLLVM. Also, - Rename npcomp_rt/NpcompRt to npcomprt/Npcomprt; it was getting annoying to type the underscores/caps. - misc improvements to bash_helpers.sh
2020-07-09 08:15:40 +08:00
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
// This has to be a "conversion pattern" since the `operands` argument
// gives access to the post-conversion operands from earlier ops.
namespace {
class LowerShapeBroadcastOp : public OpConversionPattern<shape::BroadcastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    shape::BroadcastOp::Adaptor adaptor(operands);
    auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
    auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
    if (!lhs || !rhs)
      return rewriter.notifyMatchFailure(op, "operands not converted");
    // Establish invariant that rank(lhs) >= rank(rhs).
    if (lhs.extents().size() < rhs.extents().size())
      std::swap(lhs, rhs);
    auto rankDiscrepancy = lhs.extents().size() - rhs.extents().size();
    // Helper that creates IR
    // ```
    // abort_if(extent != resultExtent && extent != 1)
    // ```
    // This is the numpy broadcasting legality check.
    auto createAbortIfIllegalBroadcastExtent = [&](Value extent,
                                                   Value resultExtent) {
      auto c1 = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
      auto extentNeMax = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne,
                                                 extent, resultExtent);
      auto extentNeOne =
          rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ne, extent, c1);
      auto bothTrue =
          rewriter.create<AndOp>(op.getLoc(), extentNeMax, extentNeOne);
      // TODO: Should there be a more generic error-handling dialect?
      // It seems a bit awkward to hardcode npcomprt here.
      rewriter.create<npcomprt::AbortIfOp>(op.getLoc(), bothTrue);
    };
    auto resultExtents = llvm::to_vector<6>(lhs.extents());
    for (int i = 0, e = rhs.extents().size(); i < e; i++) {
      auto lhsExtent = lhs.extents()[rankDiscrepancy + i];
      auto rhsExtent = rhs.extents()[i];
      auto ugt = rewriter.create<CmpIOp>(op.getLoc(), CmpIPredicate::ugt,
                                         lhsExtent, rhsExtent);
      auto max =
          rewriter.create<SelectOp>(op.getLoc(), ugt, lhsExtent, rhsExtent);
      auto &resultExtent = resultExtents[rankDiscrepancy + i];
      resultExtent = max;
      createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
      createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
    }
    // TODO: Remove the return type once ODS is fixed to do proper inference.
    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(
        op, shape::ShapeType::get(rewriter.getContext()), resultExtents);
    return success();
  }
};
} // namespace
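// Illustrative sketch only, not part of this pass: the extent arithmetic that
// LowerShapeBroadcastOp emits as IR can be modeled on plain integers. The
// helper name `broadcastExtents` is hypothetical, and returning std::nullopt
// stands in for the point where the lowered IR would execute `abort_if`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical helper (assumption, not an npcomp API): computes the broadcast
// result extents the same way the pattern does, by right-aligning the two
// extent lists, taking the per-dimension max, and rejecting any extent that
// is neither equal to that max nor equal to 1.
std::optional<std::vector<int64_t>>
broadcastExtents(std::vector<int64_t> lhs, std::vector<int64_t> rhs) {
  // Establish invariant that rank(lhs) >= rank(rhs).
  if (lhs.size() < rhs.size())
    std::swap(lhs, rhs);
  assert(lhs.size() >= rhs.size());
  const std::size_t rankDiscrepancy = lhs.size() - rhs.size();

  // Leading dimensions of the higher-ranked shape pass through unchanged.
  std::vector<int64_t> result = lhs;
  for (std::size_t i = 0; i < rhs.size(); ++i) {
    const int64_t lhsExtent = lhs[rankDiscrepancy + i];
    const int64_t rhsExtent = rhs[i];
    const int64_t resultExtent = std::max(lhsExtent, rhsExtent);
    // abort_if(extent != resultExtent && extent != 1), for both operands.
    const bool lhsIllegal = lhsExtent != resultExtent && lhsExtent != 1;
    const bool rhsIllegal = rhsExtent != resultExtent && rhsExtent != 1;
    if (lhsIllegal || rhsIllegal)
      return std::nullopt;
    result[rankDiscrepancy + i] = resultExtent;
  }
  return result;
}
```

// Under these assumptions, broadcastExtents({4, 1, 3}, {5, 1}) yields
// {4, 5, 3}, while broadcastExtents({2, 3}, {4}) returns std::nullopt because
// 3 is neither 4 nor 1.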
// Rewrite `get_extent(from_extents(x1,x2,x3), N) -> xN`
//
// TODO: this should be a fold on tcp::GetExtentOp.
// (though then the contract of this pass depends on that set of folds,
// which isn't great)
//
// Also, we use OpConversionPattern to get post-rewrite operands as above.
namespace {
class LowerShapeGetExtentOp : public OpConversionPattern<tcp::GetExtentOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    tcp::GetExtentOp::Adaptor adaptor(operands);
    auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
    if (!fromExtents)
      return rewriter.notifyMatchFailure(op, "not a from_extents op");
    int64_t dim = op.dim().getLimitedValue();
    rewriter.replaceOp(op, ValueRange(fromExtents.extents())[dim]);
    return success();
  }
};
} // namespace
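// Illustrative sketch only, not part of this pass: the rewrite
// `get_extent(from_extents(x1,x2,x3), N) -> xN` modeled on plain values. The
// names `ShapeValue` and `resolveGetExtent` are hypothetical, and the bounds
// check is an assumption added for the sketch; the pattern above indexes the
// extents directly.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a !shape.shape SSA value. The extents are present
// only when the value was defined by a from_extents-style op.
struct ShapeValue {
  std::optional<std::vector<int64_t>> fromExtents;
};

// Returns the N-th extent when the shape is defined by from_extents, and
// std::nullopt otherwise (the analogue of notifyMatchFailure).
std::optional<int64_t> resolveGetExtent(const ShapeValue &shape, uint64_t dim) {
  if (!shape.fromExtents)
    return std::nullopt; // "not a from_extents op"
  if (dim >= shape.fromExtents->size())
    return std::nullopt; // Assumption: out-of-range dims are rejected here.
  return (*shape.fromExtents)[dim];
}
```

// The lookup succeeds only when the defining op is known, which is why the
// pass relies on every shape def having been rewritten to from_extents first.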
namespace {
// Now that we have lowered ranked shapes, which reifies the eager
// error-handling code, the tcp::ShapeObserveErrorOp's are no longer
// needed.
class EraseShapeObserveErrorOp
    : public OpConversionPattern<tcp::ShapeObserveErrorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
// Basic invariant of this pass:
// Every def of a !shape.shape type is replaced with a
// `shape.from_extents` op.
// When converting an op, look for the `shape.from_extents` op that
// defined all operands, then do a computation on the extents (i.e.
// operands to the `shape.from_extents` op) and produce a
// `shape.from_extents` op.
//
// We expect that previous passes have inserted a "root" set of
// shape::FromExtentsOp's that allow this process to get started.
//
// We then use this to resolve get_extent ops by using a rewrite
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
// maximally many places due to the above invariant.
//
// This is similar to the approach that is used in IREE. It is basically a
// combination of the ConvertShapeToShapex pass and the
// "ranked_dim(make_ranked_shape(x1, x2), N) -> xN" folding pattern.
//
// This pass depends heavily on ranked shapes, since only ranked shapes can
// be statically expanded to a fixed set of SSA extents.
//
// TODO: This approach doesn't naively work with control flow.
// In the presence of non-cyclic control flow, we can just generalize the
// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
// look through block arguments and rewrite "phi of shapes -> phi of extents".
// In the presence of cyclic control flow, we need to somehow resolve the
// ranks of use-def cycles ahead of time or optimistically assume that
// backedges will match the rank of forward edges, and somehow be robust
// when that assumption fails.
namespace {
class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
  void runOnOperation() {
    auto func = getOperation();
    auto *context = &getContext();
    OwningRewritePatternList patterns;
    patterns.insert<LowerShapeBroadcastOp>(context);
    patterns.insert<LowerShapeGetExtentOp>(context);
    patterns.insert<EraseShapeObserveErrorOp>(context);
    ConversionTarget target(*context);
    target.addIllegalOp<shape::ShapeOfOp>();
    target.addIllegalOp<shape::BroadcastOp>();
    target.addIllegalOp<tcp::GetExtentOp>();
    target.addLegalOp<shape::FromExtentsOp>();
    target.addLegalOp<npcomprt::AbortIfOp>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addIllegalOp<tcp::ShapeObserveErrorOp>();
    if (failed(applyPartialConversion(func, target, patterns))) {
      return signalPassFailure();
    }
    // Erase some stray shape ops from the program. They can't be
    // deleted during conversion because they become unused only after
    // subsequent patterns bypass them.
    auto walkResult = func.walk([](Operation *op) {
      if (!(isa<shape::FromExtentsOp>(op) || isa<shape::ConstShapeOp>(op)))
        return WalkResult::advance();
      if (!op->use_empty()) {
        op->emitError("could not be eliminated");
        return WalkResult::interrupt();
      }
      op->erase();
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return signalPassFailure();
  }
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerRankedShapesPass() {
  return std::make_unique<LowerRankedShapes>();
}