mirror of https://github.com/llvm/torch-mlir
Use upstream shape.from_extents
Replace our local `tcp.shape_from_extents` op with the upstream `shape.from_extents` op.
parent 1fed1cb016
commit 3a09455540
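At the IR level this is a one-for-one substitution: the same index extents go in and the same `!shape.shape` value comes out; only the op name and assembly form change. A minimal sketch (a hypothetical function, not part of this commit; `%ext0`/`%ext1` are placeholder extents):

// Hypothetical example, not part of the diff below.
func @from_extents_example(%ext0: index, %ext1: index) -> !shape.shape {
  // Before: %s = "tcp.shape_from_extents"(%ext0, %ext1) : (index, index) -> !shape.shape
  // After:
  %s = shape.from_extents %ext0, %ext1
  return %s : !shape.shape
}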
@@ -109,17 +109,6 @@ This op has undefined behavior if the shape is an error.
   ];
 }
 
-// TODO: This op belongs in the shape dialect as `shape.from_extents`.
-def TCP_ShapeFromExtentsOp : TCP_Op<"shape_from_extents",
-    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "Constructs a shape from extents";
-  let description = [{
-    Constructs a shape from the extents passed as arguments.
-  }];
-  let arguments = (ins Variadic<Index>:$extents);
-  let results = (outs Shape_ShapeType:$shape);
-}
-
 def TCP_AbortIfOp : TCP_Op<"abort_if"> {
   let summary = "Aborts the program if the argument is true.";
   let description = [{

@@ -37,18 +37,6 @@ LogicalResult GetExtentOp::inferReturnTypes(
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// ShapeFromExtentsOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ShapeFromExtentsOp::inferReturnTypes(
-    MLIRContext *context, Optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  inferredReturnTypes.push_back(shape::ShapeType::get(context));
-  return success();
-}
-
 namespace mlir {
 namespace NPCOMP {
 namespace tcp {

@@ -424,9 +424,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
   // pass that checks no !shape.shape types left.
   pm.addPass(createLowerRankedShapesPass());
-
 
   // Run a final canonicalization pass to delete dead
-  // `tcp.shape_from_extents` ops.
+  // `shape.from_extents` ops.
   // This is needed for correctness, since we can't currently lower that op
   // to LLVM, since we don't have a runtime representation of `!shape.shape`.
   // TODO: Change LowerRankedShapes to delete these ops itself.

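To make the comment in the hunk above concrete: once LowerRankedShapes has rewritten every consumer of a shape to use the extent SSA values directly, the producing `shape.from_extents` is left without uses, and the canonicalizer can erase it because it is side-effect free. A hypothetical post-LowerRankedShapes snippet (not one of the tests in this commit):

// Hypothetical IR after LowerRankedShapes: all users of %s were rewritten to
// use %d0 and %d1 directly, so %s has no uses left.
func @after_lower_ranked_shapes(%d0: index, %d1: index) -> (index, index) {
  %s = shape.from_extents %d0, %d1  // dead; erased by the canonicalizer
  return %d0, %d1 : index, index
}

After that cleanup, no value of type `!shape.shape` survives into the LLVM lowering.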
@@ -27,8 +27,8 @@ public:
   matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     shape::BroadcastOp::OperandAdaptor adaptor(operands);
-    auto lhs = adaptor.lhs().getDefiningOp<tcp::ShapeFromExtentsOp>();
-    auto rhs = adaptor.rhs().getDefiningOp<tcp::ShapeFromExtentsOp>();
+    auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
+    auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
     if (!lhs || !rhs)
       return rewriter.notifyMatchFailure(op, "operands not converted");
     // Establish invariant that rank(lhs) >= rank(rhs).

@@ -66,7 +66,7 @@ public:
       createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
       createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
     }
-    rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, resultExtents);
+    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, resultExtents);
     return success();
   }
 };

@@ -87,7 +87,7 @@ public:
   matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     tcp::GetExtentOp::OperandAdaptor adaptor(operands);
-    auto fromExtents = adaptor.shape().getDefiningOp<tcp::ShapeFromExtentsOp>();
+    auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
     if (!fromExtents)
       return rewriter.notifyMatchFailure(op, "not a from_extents op");
     int64_t dim = op.dim().getLimitedValue();

@@ -116,14 +116,14 @@ public:
 
 // Basic invariant of this pass:
 // Every def of a !shape.shape type is replaced with a
-// `tcp.shape_from_extents` op.
-// When converting an op, look for the `tcp.shape_from_extents` op that
+// `shape.from_extents` op.
+// When converting an op, look for the `shape.from_extents` op that
 // defined all operands, then do a computation on the extents (i.e.
-// operands to the `tcp.shape_from_extents` op) and produce a
-// `tcp.shape_from_extents` op.
+// operands to the `shape.from_extents` op) and produce a
+// `shape.from_extents` op.
 //
 // We expect that previous passes have inserted a "root" set of
-// tcp::ShapeFromExtentsOp's that allow this process to get started.
+// shape::FromExtentsOp's that allow this process to get started.
 //
 // We then use this to resolve get_extent ops by using a rewrite
 // `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in

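A worked instance of the `get_extent(from_extents(x1,x2,x3), N) -> xN` rewrite described in the comment above (hypothetical input, not one of the tests changed in this commit; extent indices are 0-based as in `tcp.get_extent`):

func @resolve_extent(%x0: index, %x1: index, %x2: index) -> index {
  %s = shape.from_extents %x0, %x1, %x2
  // Extent index 1 of the from_extents operand list is %x1.
  %e = tcp.get_extent %s, 1
  return %e : index
}

After the rewrite the function simply returns %x1, and the now-unused `shape.from_extents` is left for the later canonicalization pass to delete.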
@@ -138,7 +138,7 @@ public:
 //
 // TODO: This approach doesn't naively work with control flow.
 // In the presence of non-cyclic control flow, we can just generalize the
-// `getDefiningOp<tcp::ShapeFromExtentsOp>()` calls into something that will
+// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
 // look through block arguments and rewrite "phi of shapes -> phi of extents".
 // In the presence of cyclic control flow, we need to somehow resolve the
 // ranks of use-def cycles ahead of time or optimistically assume that

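The block-argument case from the TODO above, spelled out (hypothetical IR, not part of this commit): once a shape flows through a block argument, `getDefiningOp<shape::FromExtentsOp>()` on that value returns null.

func @phi_of_shapes(%pred: i1, %a: index, %b: index) -> !shape.shape {
  %s0 = shape.from_extents %a, %b
  %s1 = shape.from_extents %b, %a
  cond_br %pred, ^bb1, ^bb2
^bb1:
  br ^bb3(%s0 : !shape.shape)
^bb2:
  br ^bb3(%s1 : !shape.shape)
^bb3(%s: !shape.shape):  // "phi of shapes": %s has no defining op
  return %s : !shape.shape
}

The generalization the TODO suggests would rewrite ^bb3 to take the individual extents as index arguments (a "phi of extents"), which only works when the rank agrees along every incoming edge.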
@@ -158,7 +158,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
     target.addIllegalOp<shape::ShapeOfOp>();
     target.addIllegalOp<shape::BroadcastOp>();
     target.addIllegalOp<tcp::GetExtentOp>();
-    target.addLegalOp<tcp::ShapeFromExtentsOp>();
+    target.addLegalOp<shape::FromExtentsOp>();
     target.addLegalOp<tcp::AbortIfOp>();
     target.addLegalDialect<StandardOpsDialect>();
     target.addIllegalOp<tcp::ShapeObserveErrorOp>();

@@ -101,7 +101,7 @@ public:
     SmallVector<Value, 6> extents;
     for (int i = 0, e = tensorType.getRank(); i < e; i++)
       extents.push_back(rewriter.create<DimOp>(op.getLoc(), rankedMemRef, i));
-    rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, extents);
+    rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, extents);
     return success();
   }
 };

@@ -148,7 +148,7 @@ class LowerToMemRefABI : public LowerToMemRefABIBase<LowerToMemRefABI> {
 
     patterns.insert<LowerShapeOfOp>(context);
     target.addIllegalOp<shape::ShapeOfOp>();
-    target.addLegalOp<tcp::ShapeFromExtentsOp>();
+    target.addLegalOp<shape::FromExtentsOp>();
 
     if (failed(applyPartialConversion(func, target, patterns))) {
       return signalPassFailure();

@@ -5,8 +5,8 @@
 func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
   // CHECK-NOT: shape.broadcast
   // CHECK-NOT: tcp.get_extent
-  %0 = "tcp.shape_from_extents"(%arg0, %arg1) : (index, index) -> !shape.shape
-  %1 = "tcp.shape_from_extents"(%arg2) : (index) -> !shape.shape
+  %0 = shape.from_extents %arg0, %arg1
+  %1 = shape.from_extents %arg2
   %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %e0 = tcp.get_extent %2, 0
   %e1 = tcp.get_extent %2, 1

@@ -16,7 +16,7 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
 // CHECK-LABEL: func @erase_shape_observe_error
 func @erase_shape_observe_error(%arg0: index) {
   // CHECK-NOT tcp.shape_observe_error
-  %0 = "tcp.shape_from_extents"(%arg0) : (index) -> !shape.shape
+  %0 = shape.from_extents %arg0
   "tcp.shape_observe_error"(%0) : (!shape.shape) -> none
   return
 }

@@ -12,7 +12,7 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
 
   // CHECK: %[[VAL_2:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32>
   // CHECK: %[[VAL_3:.*]] = dim %[[VAL_2]], 0 : memref<?xf32>
-  // CHECK: %[[VAL_4:.*]] = "tcp.shape_from_extents"(%[[VAL_3]]) : (index) -> !shape.shape
+  // CHECK: %[[VAL_4:.*]] = shape.from_extents %[[VAL_3]]
   %shape = shape.shape_of %arg0 : tensor<?xf32>
 
   // CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>