//===- E2E.cpp - E2E lowering pipeline ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base file for our "end-to-end" npcomp lowering pipeline.
// At the moment, the first "end" is TCF ops and the second "end" is `llvm`
// dialect suitable for jitting.
//
// This is still a work in progress and not yet working end-to-end for even
// the most trivial examples; see the TODOs in createE2ELoweringPipeline for
// the current status.
//
// As a pragmatic matter, I generally tend to drop random passes and stuff
// inside this top-level file and then shard it out to separate files once
// a clear organizing principle arises (to avoid premature organizing).
//
// Once we have end-to-end functionality working, we will throw
// increasingly complex programs and augment this pass pipeline, likely
// introducing better structure and more clear principles.
//
// I wish I had a clear view of how this pipeline should perfectly layer
// ahead of time, but unfortunately I don't since it crosses half a dozen
// abstraction levels / dialects, some of which have no precedent that I'm
// aware of (dynamic-shape-aware, error-aware TCF -> TCP) or very little
// (tensor -> memref/buffer with dynamic shapes, shape -> SSA values for
// ranked shape extents).
//
// Right now there's lots of stuff in this pipeline that is limited to
// special cases where I have an idea of how to elaborate it to the general
// case. The priority is getting an end-to-end flow working that we can
// grow organically into a curriculum of more complex cases, elaborating
// on the design principles and layering as necessitated by the curriculum.
//
// This should be fun :)
//
//===----------------------------------------------------------------------===//
#include "npcomp/E2E/E2E.h"
#include "PassDetail.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Conversion/TCPToLinalg/TCPToLinalg.h"
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;

//===----------------------------------------------------------------------===//
// ResolveShapeOfOps
//===----------------------------------------------------------------------===//
namespace {
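// A rough sketch (approximate IR syntax, for illustration only; assumes a
// shape value %shape feeding a tcp.alloc_memref) of the rewrite this
// pattern performs:
//
//   %memref = tcp.alloc_memref(%shape) : memref<?xf32>
//   %tensor = tensor_load %memref : memref<?xf32>
//   %0 = shape.shape_of %tensor
//
// All uses of %0 are replaced with %shape, leaving the shape.shape_of op
// trivially dead.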
class ResolveShapeOfOpViaAllocMemRefOp
: public OpRewritePattern<shape::ShapeOfOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
if (auto tensorLoad = llvm::dyn_cast_or_null<TensorLoadOp>(
op.getOperand().getDefiningOp())) {
if (auto allocMemRef = llvm::dyn_cast_or_null<tcp::AllocMemRefOp>(
tensorLoad.getOperand().getDefiningOp())) {
rewriter.replaceOp(op, allocMemRef.getOperand());
return success();
}
}
return failure();
}
};
} // namespace
namespace {
class ResolveShapeOfOps : public ResolveShapeOfOpsBase<ResolveShapeOfOps> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<ResolveShapeOfOpViaAllocMemRefOp>(context);
ConversionTarget target(*context);
// target.addIllegalOp<shape::ShapeOfOp>();
target.addDynamicallyLegalOp<shape::ShapeOfOp>(
[](shape::ShapeOfOp shapeOf) {
// Only shape.shape_of ops on arguments to the entry block are legal at
// this point. They are assumed to be resolved eventually via
// the lowering of the tensor argument to some ABI that has the
// relevant information available. But this is ABI dependent.
// TODO: Convince myself that we never need to deal with general
// block operands, or implement general handling of block
// operands (need to add new bb operands of !shape.shape type).
if (auto blockArg = shapeOf.getOperand().dyn_cast<BlockArgument>()) {
Block *block = blockArg.getOwner();
if (&block->getParent()->front() == block) {
return true;
}
}
return false;
});
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createResolveShapeOfOpsPass() {
return std::make_unique<ResolveShapeOfOps>();
}

//===----------------------------------------------------------------------===//
// ResolveTensorLoadStoreOps
//===----------------------------------------------------------------------===//
namespace {
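// A rough sketch (approximate IR syntax, for illustration only) of the
// rewrite this pattern performs: a tensor_store whose tensor operand is
// produced by a tensor_load becomes a direct memref-to-memref copy.
//
//   %tensor = tensor_load %src : memref<?xf32>
//   tensor_store %tensor, %dst : memref<?xf32>
//  -->
//   linalg.copy(%src, %dst) : memref<?xf32>, memref<?xf32>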
class ReplaceTensorStoreWithCopyPattern
: public OpRewritePattern<TensorStoreOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TensorStoreOp op,
PatternRewriter &rewriter) const override {
auto tensorLoad =
llvm::dyn_cast_or_null<TensorLoadOp>(op.tensor().getDefiningOp());
if (!tensorLoad)
return rewriter.notifyMatchFailure(op, "not fed by tensor_load op");
rewriter.replaceOpWithNewOp<linalg::CopyOp>(op, tensorLoad.memref(),
op.memref());
return success();
}
};
} // namespace
namespace {
class EraseUnusedTensorLoadOpPattern : public OpRewritePattern<TensorLoadOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TensorLoadOp op,
PatternRewriter &rewriter) const override {
if (!op.use_empty())
return rewriter.notifyMatchFailure(op, "has uses");
rewriter.eraseOp(op);
return success();
}
};
} // namespace
namespace {
class ResolveTensorLoadStoreOps
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<ReplaceTensorStoreWithCopyPattern>(context);
patterns.insert<EraseUnusedTensorLoadOpPattern>(context);
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect>();
target.addDynamicallyLegalOp<TensorLoadOp>([](TensorLoadOp op) {
for (auto user : op.getResult().getUsers())
if (!isa<ReturnOp>(user))
return false;
return true;
});
target.addDynamicallyLegalOp<TensorStoreOp>(
[](TensorStoreOp op) { return op.tensor().isa<BlockArgument>(); });
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createResolveTensorLoadStoreOpsPass() {
return std::make_unique<ResolveTensorLoadStoreOps>();
}

//===----------------------------------------------------------------------===//
// LowerLinalgLoopDimOps
//===----------------------------------------------------------------------===//
namespace {
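// A rough sketch (approximate IR syntax, for illustration only) of the
// rewrite this pattern performs: a `dim` op with a constant index, whose
// operand is produced by a tcp.alloc_memref, is resolved against the shape
// operand of that alloc_memref.
//
//   %memref = tcp.alloc_memref(%shape) : memref<?x?xf32>
//   %c1 = constant 1 : index
//   %d = dim %memref, %c1 : memref<?x?xf32>
//  -->
//   %d = tcp.get_extent %shape, 1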
class LowerLinalgLoopDimOp : public OpRewritePattern<DimOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
// TODO: Remove this const pattern when lowering to shape.get_extent.
auto constIndex = op.getConstantIndex();
if (!constIndex)
return failure();
auto allocMemRef = op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
if (!allocMemRef)
return rewriter.notifyMatchFailure(op, "could not find alloc_memref");
rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, allocMemRef.shape(),
*constIndex);
return success();
}
};
} // namespace
namespace {
class LowerLinalgLoopDimOps
: public LowerLinalgLoopDimOpsBase<LowerLinalgLoopDimOps> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<LowerLinalgLoopDimOp>(context);
ConversionTarget target(*context);
target.addDynamicallyLegalOp<DimOp>([](DimOp op) -> bool {
// TODO: We only need this because we use `dim` ops for the memref
// ABI. Once we layer that out into our own runtime types, we can
// remove this.
return !op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
});
target.addLegalOp<tcp::GetExtentOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerLinalgLoopDimOpsPass() {
return std::make_unique<LowerLinalgLoopDimOps>();
}

//===----------------------------------------------------------------------===//
// LowerAllocMemRefOps
//===----------------------------------------------------------------------===//
namespace {
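// A rough sketch (approximate IR syntax, for illustration only) of the
// rewrite this pattern performs: the dynamic extents are read out of the
// shape with tcp.get_extent and passed to std.alloc, which only takes the
// dynamic extents as operands.
//
//   %memref = tcp.alloc_memref(%shape) : memref<?x2xf32>
//  -->
//   %e0 = tcp.get_extent %shape, 0
//   %memref = alloc(%e0) : memref<?x2xf32>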
class LowerAllocMemRefOp : public OpRewritePattern<tcp::AllocMemRefOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tcp::AllocMemRefOp op,
PatternRewriter &rewriter) const override {
auto memrefType = op.getType().cast<MemRefType>();
auto shape = op.getOperand();
// std.alloc only accepts the dynamic extents as operands, so only
// collect those.
SmallVector<Value, 6> dynamicExtents;
for (int i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i)) {
auto extent = rewriter.create<tcp::GetExtentOp>(op.getLoc(), shape, i);
dynamicExtents.push_back(extent);
}
}
rewriter.replaceOpWithNewOp<AllocOp>(op, memrefType, dynamicExtents);
return success();
}
};
} // namespace
namespace {
class LowerAllocMemRefOps
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<LowerAllocMemRefOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<tcp::AllocMemRefOp>();
target.addLegalOp<tcp::GetExtentOp>();
target.addLegalOp<AllocOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerAllocMemRefOpsPass() {
return std::make_unique<LowerAllocMemRefOps>();
}

//===----------------------------------------------------------------------===//
// createE2ELoweringPipeline
//===----------------------------------------------------------------------===//
void mlir::NPCOMP::createE2ELoweringPipeline(
OpPassManager &pm, const E2ELoweringPipelineOptions &options) {
// Input IR is TCF ops.
// Convert to TCP.
pm.addPass(createConvertTCFToTCPPass());
// TODO: Do tcp.island coarsening here.
// TODO: This is approximately the place that we would fork off when
// lowering to IREE.
// --------------------------------------------------------------------------
// Tensor to buffer (memref) conversion.
// --------------------------------------------------------------------------
// Convert tcp ops to Linalg where possible, as we want the generic linalg
// tensor->memref conversion to do most of the mechanical work of rewriting
// ops in terms of tensors into ops in terms of memrefs (since that is easy
// on the linalg representation).
pm.addPass(createConvertTCPToLinalgPass());
// Lower to hybrid tensor/memref
//
// The hybrid tensor/memref representation guarantees:
// - every use of a tensor is a tensor_store op writing it into a memref.
// - every def of a tensor is a tensor_load op loading out of some memref.
// - every memref is allocated by a `tcp.alloc_memref(%shape)` op.
// - every memref is only ever written once, and never mutated.
//
// Exceptions: "boundaries" such as function arguments and island
// live-outs.
//
// Or, another way to say this: the hybrid tensor/memref representation
// doesn't attempt to eliminate the original tensors from the program,
// but rather locally expands operations on tensors into small subgraphs
// with tensor_load/tensor_store at the boundaries, leaving enough
// invariants that we can clean it up later.
//
// The core invariants that are needed for this step are that the
// tensor-level ops we receive as input have a way of calculating the
// sizes for their outputs. This is equivalent to saying that
// `shape.shape_of` on the result of an op must be computable in terms
// of the shapes of the inputs to the op.
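// For example (a rough sketch with approximate IR syntax, not actual pass
// output), an elementwise tensor op is locally expanded along these lines:
//
//   %shape = "compute result shape from input shapes"
//   %memref = tcp.alloc_memref(%shape) : memref<?xf32>
//   "the op, rewritten to write its result into %memref"
//   %result = tensor_load %memref : memref<?xf32>
//
// Downstream passes then rely only on these local invariants.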
createLowerToHybridTensorMemRefPipeline(pm);
// At this point, the invariants of the hybrid tensor/memref
// representation allow us to resolve `shape.shape_of` ops to shape
// computations earlier in the program. Specifically, every
// `shape.shape_of` can be resolved to the shape argument of the
// `tcp.alloc_memref` op backing the tensor_load that produced that
// tensor.
pm.addPass(createResolveShapeOfOpsPass());
// Now, we use the hybrid tensor/memref invariants to replace the
// tensor_store ops with memref copy operations and erase the now-unused
// tensor_load ops.
pm.addPass(createResolveTensorLoadStoreOpsPass());
// At this point, the IR is in a form where there are no tensor ops
// (except tensor_store's of arguments, tensor_load's of returns, and
// constants).
//
// This is a reasonable representation for doing buffer assignment.
// TODO: Do buffer assignment here.
// We need to finalize the removal of tensors from the program. To do
// that, we need to interface with a runtime ABI.
// We have a specialized dialect, npcomprt, which models the runtime's data
// structures; function signatures (and presumably, eventually, other ABI
// boundaries like external calls if we ever support them) are converted to
// interface with it.
pm.addPass(createLowerToNpcomprtABIPass());
// TODO: Might want a different kind of island to better represent this.
// This island op would explicitly capture all tensors as inputs, and it
// would establish a more formalized ABI with the interior of the body
// region (much like IREE does with dispatch regions). For now, we are
// planning on just inlining the islands, so there is little value in
// doing this, but we should look at the layering aspects here later.
// At this point, we have loose shape calculations floating around, so
// it's a good time to do some general cleanups.
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
// --------------------------------------------------------------------------
// Preparation for converting to an LLVM module.
// --------------------------------------------------------------------------
// Now, we begin the process of lowering to LLVM's level of abstraction
// (after which LLVM will take over lowering to machine code).
// Lower linalg ops to loops.
// TODO: Do some linalg optimizations like tiling here.
pm.addPass(createConvertLinalgToLoopsPass());
// Lowering linalg to loops introduces `dim` ops. Here we look through
// use-def chains to find `tcp.alloc_memref` ops that we can get a shape
// out of.
// Currently, this is trivial, but after more aggressive buffer
// allocation optimizations or linalg tiling this step will need to look
// through slices/views and stuff.
// TODO: It seems that "dim on memrefs" is being resolved in a
// fundamentally different way from "dim on tensors" is earlier in the
// pipeline. Investigate.
// We could somewhat unify them by having enough folding patterns for
// `shape.shape_of`. Above, we used the pattern
// "shape_of(tensor_load(alloc_memref(%shape))) -> %shape". Here we are
// doing `shape_of(alloc_memref(%shape)) -> %shape". It seems
// dangerous to just have a pile of these patterns and hope that one of
// them resolves things at any given point. So what we do is to use a
// very narrowly focused set of patterns that exploit just the invariants
// at each point.
pm.addPass(createLowerLinalgLoopDimOpsPass());
// AllocMemRefOp's take a `!shape.shape` as an argument. We need to
// resolve this to individual extents before we lower ranked shapes.
pm.addPass(createLowerAllocMemRefOpsPass());
// Lower shapes to SSA values.
// This replaces all tcp::GetExtentOp's with explicit SSA computations
// for the scalar extent. This requires shapes which are ranked. Any
// unranked shapes will need to be handled by a runtime shape type,
// though we don't currently support that.
//
// At this point, in the case of programs with only ranked shapes, all
// !shape.shape types will be gone.
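// For example (a rough sketch with approximate IR syntax), given
//   %shape = shape.from_extents %e0, %e1
//   %x = tcp.get_extent %shape, 1
// uses of %x are resolved directly to %e1, and the intermediate
// !shape.shape value becomes dead.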
// TODO: Better demarcate the invariants here, such as having a verifier
// pass that checks that no !shape.shape types are left.
pm.addPass(createLowerRankedShapesPass());
// Run some final cleanups.
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
// --------------------------------------------------------------------------
// Final conversion to an LLVM module.
// --------------------------------------------------------------------------
// Convert scf to std control flow in preparation for going to LLVM.
pm.addPass(createLowerToCFGPass());
// Finally, convert to the LLVM dialect using our custom LowerToLLVM pass,
// which reuses the upstream patterns and gives us a place to add our own
// patterns for any custom ops and types we wish to lower.
pm.addPass(createLowerToLLVMPass());
}
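// A minimal sketch (hypothetical; any actual registration lives with the
// rest of the npcomp pass/pipeline registration code) of how this pipeline
// could be exposed to a tool like npcomp-opt:
//
//   static mlir::PassPipelineRegistration<E2ELoweringPipelineOptions>
//       pipeline("e2e-lowering-pipeline",
//                "Pipeline lowering TCF ops to the LLVM dialect.",
//                mlir::NPCOMP::createE2ELoweringPipeline);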