mirror of https://github.com/llvm/torch-mlir
Lower to the upstream memref ABI.
Specifically, we use unranked memrefs which get passed as a fixed-size set of arguments/returns. One big caveat about this is that returning results isn't going to work. See TODO in LowerTensorLoadOp. This is far from enough runtime-wise, but it starts to demarcate a plausible layering. Notice for example how this removes the runtime-dependence from LowerRankedShapes. Eventually, we want to have an `npcomp_rt` or `npcomp_hal` dialect with its own set of runtime types that will supercede this. See comments in LowerTensorLoadOp for more direction about where this is going to evolve.pull/1/head
parent
7687a6d8d2
commit
993338a12d
|
@ -15,6 +15,11 @@
|
|||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
||||
// Look in createE2ELoweringPipeline for more information about how these
|
||||
// passes fit together.
|
||||
//
|
||||
// Pass summaries are in Passes.td.
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
|
@ -29,6 +34,8 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerLinalgLoopDimOpsPass();
|
|||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLowerRankedShapesPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLowerToMemRefABIPass();
|
||||
|
||||
void createLowerToHybridTensorMemRefPipeline(OpPassManager &pm);
|
||||
|
||||
// The main pipeline that encapsulates the full E2E lowering.
|
||||
|
|
|
@ -43,4 +43,9 @@ def LowerRankedShapes : Pass<"lower-ranked-shapes", "FuncOp"> {
|
|||
let constructor = "mlir::NPCOMP::createLowerRankedShapesPass()";
|
||||
}
|
||||
|
||||
def LowerToMemRefABI : Pass<"lower-to-memref-abi", "FuncOp"> {
|
||||
let summary = "Lower tensors at ABI boundaries to memref";
|
||||
let constructor = "mlir::NPCOMP::createLowerToMemRefABIPass()";
|
||||
}
|
||||
|
||||
#endif // NPCOMP_E2E_PASSES
|
||||
|
|
|
@ -2,6 +2,7 @@ add_mlir_library(NPCOMPE2E
|
|||
E2E.cpp
|
||||
LowerRankedShapes.cpp
|
||||
LowerToHybridTensorMemRef.cpp
|
||||
LowerToMemRefABI.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SRC_DIR}/include/npcomp/E2E
|
||||
|
|
|
@ -218,7 +218,12 @@ class LowerLinalgLoopDimOps
|
|||
OwningRewritePatternList patterns;
|
||||
patterns.insert<LowerLinalgLoopDimOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<DimOp>();
|
||||
target.addDynamicallyLegalOp<DimOp>([](DimOp op) -> bool {
|
||||
// TODO: We only need this because we use `dim` ops for the memref
|
||||
// ABI. Once we layer that out into our own runtime types, we can
|
||||
// remove this.
|
||||
return op.getOperand().getDefiningOp<tcp::AllocMemRefOp>();
|
||||
});
|
||||
target.addLegalOp<tcp::GetExtentOp>();
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
return signalPassFailure();
|
||||
|
@ -291,12 +296,22 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
|
|||
// tensor_load/tensor_store ops.
|
||||
pm.addPass(createResolveTensorLoadStoreOpsPass());
|
||||
|
||||
// At this point, the IR is in a form where the interiors of islands
|
||||
// don't have tensor ops (except tensor_store's of arguments and
|
||||
// tensor_load's of returns).
|
||||
// At this point, the IR is in a form where there are no tensor ops
|
||||
// (except tensor_store's of arguments and tensor_load's of returns).
|
||||
//
|
||||
// This is a reasonable representation for doing buffer assignment.
|
||||
// TODO: Do buffer assignment here.
|
||||
|
||||
// We need to finalize the removal of tensors from the program. To do
|
||||
// that, we need to interface with a runtime ABI.
|
||||
// We currently use a canonicalized version of upstream MLIR's memref
|
||||
// ABI, where we canonically use unranked memref's for all
|
||||
// arguments/returns (which makes the C-level ABI very predictable).
|
||||
//
|
||||
// TODO: This pass is very tentative. See comments on LowerTensorLoadOp
|
||||
// for where we need to take it.
|
||||
pm.addPass(createLowerToMemRefABIPass());
|
||||
|
||||
// TODO: Might want a different kind of island to better represent this.
|
||||
// This island op would explicitly capture all tensors as inputs, and it
|
||||
// would establish a more formalized ABI with the interior of the body
|
||||
|
|
|
@ -17,38 +17,6 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
// Lowers ShapeOfOp's (which at this point should only operating on tensors
|
||||
// that need to have a full runtime-reified representation) to low-level
|
||||
// runtime interfaces.
|
||||
//
|
||||
// This is the "root" ranked shape lowering which creates the first
|
||||
// ShapeFromExtentsOp which is needed to start the whole ranked conversion
|
||||
// process.
|
||||
//
|
||||
// TODO: Move this ABI-specific lowering to a separate pass that only does
|
||||
// that and make this pass require an invariant something like "a 'root'
|
||||
// set of tcp::ShapeFromExtentsOp exist".
|
||||
namespace {
|
||||
class LowerRootRankedShape : public OpRewritePattern<shape::ShapeOfOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto tensor = op.getOperand();
|
||||
auto type = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return rewriter.notifyMatchFailure(op, "not a ranked tensor");
|
||||
SmallVector<Value, 6> extents;
|
||||
for (int i = 0, e = type.getRank(); i < e; i++) {
|
||||
extents.push_back(rewriter.create<tcp::RtGetTensorExtentOp>(
|
||||
op.getLoc(), tensor, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, extents);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This has to be a "conversion pattern" since the `operands` argument
|
||||
// gives access to the post-conversion operands from earlier ops.
|
||||
namespace {
|
||||
|
@ -137,6 +105,9 @@ public:
|
|||
// operands to the `tcp.shape_from_extents` op) and produce a
|
||||
// `tcp.shape_from_extents` op.
|
||||
//
|
||||
// We expect that previous passes have inserted a "root" set of
|
||||
// tcp::ShapeFromExtentsOp's that allow this process to get started.
|
||||
//
|
||||
// We then use this to resolve get_extent ops by using a rewrite
|
||||
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
|
||||
// maximally many places due to the above invariant.
|
||||
|
@ -163,7 +134,6 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
|||
auto *context = &getContext();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<LowerRootRankedShape>(context);
|
||||
patterns.insert<LowerShapeBroadcastOp>(context);
|
||||
patterns.insert<LowerShapeGetExtentOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/E2E/E2E.h"
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
namespace {
|
||||
class LowerTensorStoreOp : public OpConversionPattern<TensorStoreOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(TensorStoreOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TensorStoreOp::OperandAdaptor adaptor(operands);
|
||||
// The tensor has been converted to an unranked memref. We need to cast
|
||||
// it to the original memref type and copy it to the destination.
|
||||
//
|
||||
// TODO: Can we have a conversion infrastructure that doesn't have
|
||||
// patterns that doesn't couple type conversions and the patterns. That
|
||||
// is, patterns should be "context free" and locally expand to always
|
||||
// valid IR without relying on some side-channel TypeConverter to do
|
||||
// something else to make the IR valid.
|
||||
auto memref = rewriter.create<MemRefCastOp>(
|
||||
op.getLoc(), op.memref().getType(), adaptor.tensor());
|
||||
rewriter.replaceOpWithNewOp<linalg::CopyOp>(op, memref, adaptor.memref());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TensorLoadOp::OperandAdaptor adaptor(operands);
|
||||
auto type = UnrankedMemRefType::get(op.getType().getElementType(), 0);
|
||||
// TODO: This won't work. The LLVM unranked memref calling convention
|
||||
// doesn't allow returning an unranked memref becuase it lowers it to
|
||||
// 'int64 rank, void *descriptor' but in this case the descriptor will
|
||||
// likely be on the stack, so when returning the descriptor pointer it
|
||||
// will be use-after-return.
|
||||
//
|
||||
// We could directly emit LLVM IR mallocing the memref struct on the
|
||||
// heap or do a conversion to out params and require a preallocated
|
||||
// memref out descriptor (perhaps preallocated to a fixed upper bound
|
||||
// rank).
|
||||
//
|
||||
// But a more holistic approach seems needed:
|
||||
// 1. Use custom npcomp runtime types at function boundaries. These can
|
||||
// be approximately like IREE's !hal.buffer_view, namely a type-erased,
|
||||
// shape-erased, ref-counted multidimensional array of dense primitive
|
||||
// types. (Something like Py_buffer from the python buffer protocol is
|
||||
// another potential inspiration)
|
||||
// - [IREE HAL buffer view](https://github.com/google/iree/blob/634136f03c144ad3acd2f28cd87785b0b6b572ac/iree/hal/api_detail.h#L26)
|
||||
// - [Python buffer protocol](https://docs.python.org/3/c-api/buffer.html)
|
||||
// 2. Use a custom LLVM conversion that creates the memref types.
|
||||
// For example, have an op
|
||||
// ```
|
||||
// npcomp_rt.to_memref %buf_view : !npcomp_rt.buffer_view -> memref<?xf32>
|
||||
// ```
|
||||
// with a custom LLVM lowering that expands to all the right stuff.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, type, adaptor.memref());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerShapeOfOp : public OpConversionPattern<shape::ShapeOfOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(shape::ShapeOfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
shape::ShapeOfOp::OperandAdaptor adaptor(operands);
|
||||
auto tensorType = op.arg().getType().cast<RankedTensorType>();
|
||||
auto rankedMemRefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
auto rankedMemRef = rewriter.create<MemRefCastOp>(
|
||||
op.getLoc(), rankedMemRefType, adaptor.arg());
|
||||
SmallVector<Value, 6> extents;
|
||||
for (int i = 0, e = tensorType.getRank(); i < e; i++)
|
||||
extents.push_back(rewriter.create<DimOp>(op.getLoc(), rankedMemRef, i));
|
||||
rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, extents);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// This pass lowers tensor types to a calling convention where all tensors
|
||||
// are passed as UnrankedMemRefType. This allows the current StandardToLLVM
|
||||
// lowering to return them as `size_t rank, void *descriptor` which is easy
|
||||
// to bridge across a fixed C ABI. (otherwise it specializes the signature
|
||||
// to the memref rank, which is very difficult to interoperate with).
|
||||
class LowerToMemRefABI : public LowerToMemRefABIBase<LowerToMemRefABI> {
|
||||
void runOnOperation() {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter converter;
|
||||
converter.addConversion([](TensorType type) {
|
||||
return UnrankedMemRefType::get(type.getElementType(), /*memorySpace=*/0);
|
||||
});
|
||||
// Mark UnrankedMemRefType as "legal". This is the awkward way of doing
|
||||
// that.
|
||||
// TODO: Commenting this out causes a seemingly unrelated crash.
|
||||
// Redesign MLIR's type conversion system to have a clearer mental
|
||||
// model and not be so flaky.
|
||||
converter.addConversion([](UnrankedMemRefType type) { return type; });
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(*context);
|
||||
|
||||
populateFuncOpTypeConversionPattern(patterns, context, converter);
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
|
||||
return converter.isSignatureLegal(op.getType());
|
||||
});
|
||||
|
||||
patterns.insert<LowerTensorStoreOp>(context);
|
||||
target.addIllegalOp<TensorStoreOp>();
|
||||
target.addLegalOp<DimOp>();
|
||||
target.addLegalOp<MemRefCastOp>();
|
||||
target.addLegalOp<linalg::CopyOp>();
|
||||
|
||||
patterns.insert<LowerTensorLoadOp>(context);
|
||||
target.addIllegalOp<TensorLoadOp>();
|
||||
|
||||
patterns.insert<LowerShapeOfOp>(context);
|
||||
target.addIllegalOp<shape::ShapeOfOp>();
|
||||
target.addLegalOp<tcp::ShapeFromExtentsOp>();
|
||||
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createLowerToMemRefABIPass() {
|
||||
return std::make_unique<LowerToMemRefABI>();
|
||||
}
|
|
@ -2,12 +2,11 @@
|
|||
|
||||
|
||||
// CHECK-LABEL: func @broadcast_rank2_rank1
|
||||
func @broadcast_rank2_rank1(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> (index, index) {
|
||||
// CHECK-NOT: shape.shape_of
|
||||
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
|
||||
// CHECK-NOT: shape.broadcast
|
||||
// CHECK-NOT: tcp.get_extent
|
||||
%0 = shape.shape_of %arg0 : tensor<?x?xf32>
|
||||
%1 = shape.shape_of %arg1 : tensor<?xf32>
|
||||
%0 = "tcp.shape_from_extents"(%arg0, %arg1) : (index, index) -> !shape.shape
|
||||
%1 = "tcp.shape_from_extents"(%arg2) : (index) -> !shape.shape
|
||||
%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%e0 = tcp.get_extent %2, 0
|
||||
%e1 = tcp.get_extent %2, 1
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
// RUN: npcomp-opt -lower-to-memref-abi <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: func @identity
|
||||
func @identity(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: return %arg0 : memref<*xf32>
|
||||
return %arg0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @basic(
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: memref<*xf32>) -> memref<*xf32> {
|
||||
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
|
||||
// CHECK: %[[VAL_2:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = dim %[[VAL_2]], 0 : memref<?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tcp.shape_from_extents"(%[[VAL_3]]) : (index) -> !shape.shape
|
||||
%shape = shape.shape_of %arg0 : tensor<?xf32>
|
||||
|
||||
// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>
|
||||
%memref = tcp.alloc_memref %shape : memref<?xf32>
|
||||
|
||||
// CHECK: %[[VAL_6:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32>
|
||||
// CHECK: linalg.copy(%[[VAL_6]], %[[VAL_5]]) : memref<?xf32>, memref<?xf32>
|
||||
tensor_store %arg0, %memref : memref<?xf32>
|
||||
|
||||
// CHECK: %[[VAL_7:.*]] = memref_cast %[[VAL_5]] : memref<?xf32> to memref<*xf32>
|
||||
%ret = tensor_load %memref : memref<?xf32>
|
||||
|
||||
// CHECK: return %[[VAL_7]] : memref<*xf32>
|
||||
return %ret: tensor<?xf32>
|
||||
}
|
Loading…
Reference in New Issue