mirror of https://github.com/llvm/torch-mlir
"Finish" tensor -> memref conversion.
There's a lot of details to flesh out here, but the basic approach seems promising (see comments in createE2ELoweringPipeline). This approach will be put to the test when we try to do our first fusions since that tickles some of the nasty phase ordering issues involved here. But we're not there yet.pull/1/head
parent
fec2ee0072
commit
53c17dbed9
|
@ -23,6 +23,8 @@ createLowerLinalgOnTensorToLinalgOnMemrefPass();
|
|||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createResolveShapeOfOpsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();
|
||||
|
||||
void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);
|
||||
|
||||
// The main pipeline that encapsulates the full E2E lowering.
|
||||
|
|
|
@ -28,4 +28,9 @@ def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
|
|||
let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";
|
||||
}
|
||||
|
||||
def ResolveTensorLoadStoreOps : Pass<"resolve-tensor-load-store-ops", "FuncOp"> {
|
||||
let summary = "Resolve tensor_load/tensor_store ops";
|
||||
let constructor = "mlir::NPCOMP::createResolveTensorLoadStoreOpsPass()";
|
||||
}
|
||||
|
||||
#endif // NPCOMP_E2E_PASSES
|
||||
|
|
129
lib/E2E/E2E.cpp
129
lib/E2E/E2E.cpp
|
@ -58,6 +58,10 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResolveShapeOfOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ResolveShapeOfOpViaAllocMemRefOp : public OpRewritePattern<shape::ShapeOfOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
@ -112,6 +116,68 @@ mlir::NPCOMP::createResolveShapeOfOpsPass() {
|
|||
return std::make_unique<ResolveShapeOfOps>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResolveTensorLoadStoreOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ReplaceTensorStoreWithCopyPattern
|
||||
: public OpRewritePattern<TensorStoreOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(TensorStoreOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto tensorLoad =
|
||||
llvm::dyn_cast_or_null<TensorLoadOp>(op.tensor().getDefiningOp());
|
||||
if (!tensorLoad)
|
||||
return rewriter.notifyMatchFailure(op, "not fed by tensor_load op");
|
||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(op, tensorLoad.memref(),
|
||||
op.memref());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class EraseUnusedTensorLoadOpPattern : public OpRewritePattern<TensorLoadOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(TensorLoadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.use_empty())
|
||||
return rewriter.notifyMatchFailure(op, "has uses");
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ResolveTensorLoadStoreOps
|
||||
: public ResolveTensorLoadStoreOpsBase<ResolveTensorLoadStoreOps> {
|
||||
void runOnOperation() {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ReplaceTensorStoreWithCopyPattern>(context);
|
||||
patterns.insert<EraseUnusedTensorLoadOpPattern>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<linalg::LinalgDialect>();
|
||||
target.addDynamicallyLegalOp<TensorLoadOp>([](TensorLoadOp op) {
|
||||
for (auto user : op.getResult().getUsers())
|
||||
if (!isa<tcp::YieldOp>(user))
|
||||
return false;
|
||||
return true;
|
||||
});
|
||||
target.addDynamicallyLegalOp<TensorStoreOp>(
|
||||
[](TensorStoreOp op) { return op.tensor().isa<BlockArgument>(); });
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createResolveTensorLoadStoreOpsPass() {
|
||||
return std::make_unique<ResolveTensorLoadStoreOps>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// createE2ELoweringPipeline
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -121,28 +187,71 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
|
|||
|
||||
// Convert to TCP.
|
||||
pm.addPass(createConvertTCFToTCPPass());
|
||||
// Convert tcp ops to Linalg where possible.
|
||||
pm.addPass(createConvertTCPToLinalgPass());
|
||||
|
||||
// TODO: Do tcp.island coarsening here.
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Tensor to buffer (memref) conversion.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Convert tcp ops to Linalg where possible, as we want generic linalg
|
||||
// tensor->memref to do most of the mechanical work of rewriting ops in
|
||||
// terms of tensors to ops in terms of memrefs (since it is easy on that
|
||||
// representation).
|
||||
pm.addPass(createConvertTCPToLinalgPass());
|
||||
|
||||
// Lower to hybrid tensor/memref
|
||||
//
|
||||
// The hybrid tensor/memref representation guarantees:
|
||||
// - every use of a tensor is a tensor_store op writing it into a memref
|
||||
// - every def of a tensor is a tensor_load op loading out of some memref.
|
||||
// - every memref is allocated by a `tcp.alloc_memref(%shape)` op.
|
||||
// - every memref is only ever writen once, and never mutated
|
||||
//
|
||||
// Exceptions: "boundaries" such as function arguments and island
|
||||
// live-outs.
|
||||
//
|
||||
// Or, another way to say this: the hybrid tensor/memref representation
|
||||
// doesn't attempt to eliminate the original tensors from the program,
|
||||
// but rather locally expands operations on tensors to be small subgraphs
|
||||
// with tensor_load/tensor_store at the boundaries, leaving enough
|
||||
// invariants that we can clean it up later.
|
||||
//
|
||||
// The core invariants that are needed for this step are that the
|
||||
// tensor-level ops we receive as input have a way of calculating the
|
||||
// sizes for their outputs. This is equivalent to saying that
|
||||
// `shape.shape_of` on the result of an op must be calculatable in terms
|
||||
// of the shapes of the inputs to the op.
|
||||
createLowerToHybridTensorMemRefPipeline(pm);
|
||||
|
||||
// At this point, every tensor in the program is the result of a
|
||||
// `tensor_load` of an `alloc_memref` op (or is an argument). Therefore,
|
||||
// every shape_of can be resolved by looking at the corresponding
|
||||
// alloc_memref of the tensor.
|
||||
// At this point, the invariants of the hybrid tensor/memref
|
||||
// representation allow us to resolve `shape.shape_of` ops to shape
|
||||
// computations earlier in the program. Specifically, every
|
||||
// `shape.shape_of` can be resolved to the shape argument to the
|
||||
// corresponding `tcp.alloc_memref` op of the tensor_load that produced
|
||||
// that tensor.
|
||||
pm.addPass(createResolveShapeOfOpsPass());
|
||||
|
||||
// Now, we use the hybrid tensor/memref invariants to replace the
|
||||
// tensor_store ops with memref copy operations and erase the
|
||||
// tensor_load/tensor_store ops.
|
||||
pm.addPass(createResolveTensorLoadStoreOpsPass());
|
||||
|
||||
// At this point, the IR is in a form where the interiors of islands
|
||||
// don't have tensor ops (except tensor_store's of arguments and
|
||||
// tensor_load's of returns).
|
||||
//
|
||||
// This is a reasonable representation for doing buffer assignment.
|
||||
//
|
||||
// TODO: Might want a different kind of island to better represent this.
|
||||
// This island op would explicitly capture all tensors as inputs, and it
|
||||
// would establish a more formalized ABI with the interior of the body
|
||||
// region (much like IREE does with dispatch regions). For now, we are
|
||||
// planning on just inlining the islands, so there is little value in
|
||||
// doing this, but we should look at the layering aspects here later.
|
||||
|
||||
// TODO:
|
||||
// forward tensor_load/tensor_store (which leaves all tensors with no
|
||||
// uses)
|
||||
// lower linalg to loops: mlir::createConvertLinalgToLoopsPass()
|
||||
// lower shape stuff to rshape?
|
||||
// lower rshape to SSA values?
|
||||
// lower shape stuff to ssa values.
|
||||
// Convert all of it to LLVM?
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
// RUN: npcomp-opt -resolve-tensor-load-store-ops <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: func @basic
|
||||
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
|
||||
%shape = "shape.shape_of"(%arg0) : (tensor<?xf32>) -> !shape.shape
|
||||
|
||||
// CHECK: %[[SRCMEMREF:.+]] = "tcp.alloc_memref"
|
||||
%src_memref = "tcp.alloc_memref"(%shape) : (!shape.shape) -> memref<?xf32>
|
||||
// tensor_store of argument remains.
|
||||
// CHECK: tensor_store %arg0, %[[SRCMEMREF]]
|
||||
tensor_store %arg0, %src_memref : memref<?xf32>
|
||||
%src = tensor_load %src_memref : memref<?xf32>
|
||||
|
||||
// CHECK: %[[DSTMEMREF:.+]] = "tcp.alloc_memref"
|
||||
%dst_memref = "tcp.alloc_memref"(%shape) : (!shape.shape) -> memref<?xf32>
|
||||
// tensor_store of internally created tensor is eliminated.
|
||||
// CHECK-NOT: tensor_store
|
||||
// CHECK: linalg.copy(%[[SRCMEMREF]], %[[DSTMEMREF]])
|
||||
tensor_store %src, %dst_memref : memref<?xf32>
|
||||
%ret = tensor_load %dst_memref : memref<?xf32>
|
||||
|
||||
// The tensor_load feeding into the return remains.
|
||||
// %[[RET:.+]] = tensor_load %[[DSTMEMREF]]
|
||||
// return %[[RET]]
|
||||
return %ret : tensor<?xf32>
|
||||
}
|
Loading…
Reference in New Issue