mirror of https://github.com/llvm/torch-mlir
Use upstream shape.from_extents
Replace our local `tcp.shape_from_extents` op with the upstream `shape.from_extents` op.pull/1/head
parent
1fed1cb016
commit
3a09455540
|
@ -109,17 +109,6 @@ This op has undefined behavior if the shape is an error.
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: This op belongs in the shape dialect as `shape.from_extents`.
|
|
||||||
def TCP_ShapeFromExtentsOp : TCP_Op<"shape_from_extents",
|
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
|
||||||
let summary = "Constructs a shape from extents";
|
|
||||||
let description = [{
|
|
||||||
Constructs a shape from the extents passed as arguments.
|
|
||||||
}];
|
|
||||||
let arguments = (ins Variadic<Index>:$extents);
|
|
||||||
let results = (outs Shape_ShapeType:$shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
def TCP_AbortIfOp : TCP_Op<"abort_if"> {
|
def TCP_AbortIfOp : TCP_Op<"abort_if"> {
|
||||||
let summary = "Aborts the program if the argument is true.";
|
let summary = "Aborts the program if the argument is true.";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -37,18 +37,6 @@ LogicalResult GetExtentOp::inferReturnTypes(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ShapeFromExtentsOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
LogicalResult ShapeFromExtentsOp::inferReturnTypes(
|
|
||||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
|
||||||
DictionaryAttr attributes, RegionRange regions,
|
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
|
||||||
inferredReturnTypes.push_back(shape::ShapeType::get(context));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace tcp {
|
namespace tcp {
|
||||||
|
|
|
@ -424,9 +424,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
|
||||||
// pass that checks no !shape.shape types left.
|
// pass that checks no !shape.shape types left.
|
||||||
pm.addPass(createLowerRankedShapesPass());
|
pm.addPass(createLowerRankedShapesPass());
|
||||||
|
|
||||||
|
|
||||||
// Run a final canonicalization pass to delete dead
|
// Run a final canonicalization pass to delete dead
|
||||||
// `tcp.shape_from_extents` ops.
|
// `shape.from_extents` ops.
|
||||||
// This is needed for correctness, since we can't currently lower that op
|
// This is needed for correctness, since we can't currently lower that op
|
||||||
// to LLVM, since we don't have a runtime representation of `!shape.shape`.
|
// to LLVM, since we don't have a runtime representation of `!shape.shape`.
|
||||||
// TODO: Change LowerRankedShapes to delete these ops itself.
|
// TODO: Change LowerRankedShapes to delete these ops itself.
|
||||||
|
|
|
@ -27,8 +27,8 @@ public:
|
||||||
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
|
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
shape::BroadcastOp::OperandAdaptor adaptor(operands);
|
shape::BroadcastOp::OperandAdaptor adaptor(operands);
|
||||||
auto lhs = adaptor.lhs().getDefiningOp<tcp::ShapeFromExtentsOp>();
|
auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
|
||||||
auto rhs = adaptor.rhs().getDefiningOp<tcp::ShapeFromExtentsOp>();
|
auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return rewriter.notifyMatchFailure(op, "operands not converted");
|
return rewriter.notifyMatchFailure(op, "operands not converted");
|
||||||
// Establish invariant that rank(lhs) >= rank(rhs).
|
// Establish invariant that rank(lhs) >= rank(rhs).
|
||||||
|
@ -66,7 +66,7 @@ public:
|
||||||
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
|
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
|
||||||
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
|
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, resultExtents);
|
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, resultExtents);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -87,7 +87,7 @@ public:
|
||||||
matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
|
matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
tcp::GetExtentOp::OperandAdaptor adaptor(operands);
|
tcp::GetExtentOp::OperandAdaptor adaptor(operands);
|
||||||
auto fromExtents = adaptor.shape().getDefiningOp<tcp::ShapeFromExtentsOp>();
|
auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
|
||||||
if (!fromExtents)
|
if (!fromExtents)
|
||||||
return rewriter.notifyMatchFailure(op, "not a from_extents op");
|
return rewriter.notifyMatchFailure(op, "not a from_extents op");
|
||||||
int64_t dim = op.dim().getLimitedValue();
|
int64_t dim = op.dim().getLimitedValue();
|
||||||
|
@ -116,14 +116,14 @@ public:
|
||||||
|
|
||||||
// Basic invariant of this pass:
|
// Basic invariant of this pass:
|
||||||
// Every def of a !shape.shape type is replaced with a
|
// Every def of a !shape.shape type is replaced with a
|
||||||
// `tcp.shape_from_extents` op.
|
// `shape.from_extents` op.
|
||||||
// When converting an op, look for the `tcp.shape_from_extents` op that
|
// When converting an op, look for the `shape.from_extents` op that
|
||||||
// defined all operands, then do a computation on the extents (i.e.
|
// defined all operands, then do a computation on the extents (i.e.
|
||||||
// operands to the `tcp.shape_from_extents` op) and produce a
|
// operands to the `shape.from_extents` op) and produce a
|
||||||
// `tcp.shape_from_extents` op.
|
// `shape.from_extents` op.
|
||||||
//
|
//
|
||||||
// We expect that previous passes have inserted a "root" set of
|
// We expect that previous passes have inserted a "root" set of
|
||||||
// tcp::ShapeFromExtentsOp's that allow this process to get started.
|
// shape::FromExtentsOp's that allow this process to get started.
|
||||||
//
|
//
|
||||||
// We then use this to resolve get_extent ops by using a rewrite
|
// We then use this to resolve get_extent ops by using a rewrite
|
||||||
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
|
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
|
||||||
|
@ -138,7 +138,7 @@ public:
|
||||||
//
|
//
|
||||||
// TODO: This approach doesn't naively work with control flow.
|
// TODO: This approach doesn't naively work with control flow.
|
||||||
// In the presence of non-cyclic control flow, we can just generalize the
|
// In the presence of non-cyclic control flow, we can just generalize the
|
||||||
// `getDefiningOp<tcp::ShapeFromExtentsOp>()` calls into something that will
|
// `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
|
||||||
// look through block arguments and rewrite "phi of shapes -> phi of extents".
|
// look through block arguments and rewrite "phi of shapes -> phi of extents".
|
||||||
// In the presence of cyclic control flow, we need to somehow resolve the
|
// In the presence of cyclic control flow, we need to somehow resolve the
|
||||||
// ranks of use-def cycles ahead of time or optimistically assume that
|
// ranks of use-def cycles ahead of time or optimistically assume that
|
||||||
|
@ -158,7 +158,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
||||||
target.addIllegalOp<shape::ShapeOfOp>();
|
target.addIllegalOp<shape::ShapeOfOp>();
|
||||||
target.addIllegalOp<shape::BroadcastOp>();
|
target.addIllegalOp<shape::BroadcastOp>();
|
||||||
target.addIllegalOp<tcp::GetExtentOp>();
|
target.addIllegalOp<tcp::GetExtentOp>();
|
||||||
target.addLegalOp<tcp::ShapeFromExtentsOp>();
|
target.addLegalOp<shape::FromExtentsOp>();
|
||||||
target.addLegalOp<tcp::AbortIfOp>();
|
target.addLegalOp<tcp::AbortIfOp>();
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
|
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
|
||||||
|
|
|
@ -101,7 +101,7 @@ public:
|
||||||
SmallVector<Value, 6> extents;
|
SmallVector<Value, 6> extents;
|
||||||
for (int i = 0, e = tensorType.getRank(); i < e; i++)
|
for (int i = 0, e = tensorType.getRank(); i < e; i++)
|
||||||
extents.push_back(rewriter.create<DimOp>(op.getLoc(), rankedMemRef, i));
|
extents.push_back(rewriter.create<DimOp>(op.getLoc(), rankedMemRef, i));
|
||||||
rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, extents);
|
rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, extents);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -148,7 +148,7 @@ class LowerToMemRefABI : public LowerToMemRefABIBase<LowerToMemRefABI> {
|
||||||
|
|
||||||
patterns.insert<LowerShapeOfOp>(context);
|
patterns.insert<LowerShapeOfOp>(context);
|
||||||
target.addIllegalOp<shape::ShapeOfOp>();
|
target.addIllegalOp<shape::ShapeOfOp>();
|
||||||
target.addLegalOp<tcp::ShapeFromExtentsOp>();
|
target.addLegalOp<shape::FromExtentsOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
|
@ -5,8 +5,8 @@
|
||||||
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
|
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
|
||||||
// CHECK-NOT: shape.broadcast
|
// CHECK-NOT: shape.broadcast
|
||||||
// CHECK-NOT: tcp.get_extent
|
// CHECK-NOT: tcp.get_extent
|
||||||
%0 = "tcp.shape_from_extents"(%arg0, %arg1) : (index, index) -> !shape.shape
|
%0 = shape.from_extents %arg0, %arg1
|
||||||
%1 = "tcp.shape_from_extents"(%arg2) : (index) -> !shape.shape
|
%1 = shape.from_extents %arg2
|
||||||
%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||||
%e0 = tcp.get_extent %2, 0
|
%e0 = tcp.get_extent %2, 0
|
||||||
%e1 = tcp.get_extent %2, 1
|
%e1 = tcp.get_extent %2, 1
|
||||||
|
@ -16,7 +16,7 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
|
||||||
// CHECK-LABEL: func @erase_shape_observe_error
|
// CHECK-LABEL: func @erase_shape_observe_error
|
||||||
func @erase_shape_observe_error(%arg0: index) {
|
func @erase_shape_observe_error(%arg0: index) {
|
||||||
// CHECK-NOT tcp.shape_observe_error
|
// CHECK-NOT tcp.shape_observe_error
|
||||||
%0 = "tcp.shape_from_extents"(%arg0) : (index) -> !shape.shape
|
%0 = shape.from_extents %arg0
|
||||||
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
|
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
|
||||||
// CHECK: %[[VAL_2:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32>
|
// CHECK: %[[VAL_2:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32>
|
||||||
// CHECK: %[[VAL_3:.*]] = dim %[[VAL_2]], 0 : memref<?xf32>
|
// CHECK: %[[VAL_3:.*]] = dim %[[VAL_2]], 0 : memref<?xf32>
|
||||||
// CHECK: %[[VAL_4:.*]] = "tcp.shape_from_extents"(%[[VAL_3]]) : (index) -> !shape.shape
|
// CHECK: %[[VAL_4:.*]] = shape.from_extents %[[VAL_3]]
|
||||||
%shape = shape.shape_of %arg0 : tensor<?xf32>
|
%shape = shape.shape_of %arg0 : tensor<?xf32>
|
||||||
|
|
||||||
// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>
|
// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>
|
||||||
|
|
Loading…
Reference in New Issue