Use upstream shape.from_extents

Replace our local `tcp.shape_from_extents` op with the upstream
`shape.from_extents` op.
pull/1/head
Sean Silva 2020-05-21 14:51:01 -07:00
parent 1fed1cb016
commit 3a09455540
7 changed files with 18 additions and 42 deletions

View File

@ -109,17 +109,6 @@ This op has undefined behavior if the shape is an error.
]; ];
} }
// TODO: This op belongs in the shape dialect as `shape.from_extents`.
def TCP_ShapeFromExtentsOp : TCP_Op<"shape_from_extents",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Constructs a shape from extents";
let description = [{
Constructs a shape from the extents passed as arguments.
}];
let arguments = (ins Variadic<Index>:$extents);
let results = (outs Shape_ShapeType:$shape);
}
def TCP_AbortIfOp : TCP_Op<"abort_if"> { def TCP_AbortIfOp : TCP_Op<"abort_if"> {
let summary = "Aborts the program if the argument is true."; let summary = "Aborts the program if the argument is true.";
let description = [{ let description = [{

View File

@ -37,18 +37,6 @@ LogicalResult GetExtentOp::inferReturnTypes(
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// ShapeFromExtentsOp
//===----------------------------------------------------------------------===//
LogicalResult ShapeFromExtentsOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(shape::ShapeType::get(context));
return success();
}
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
namespace tcp { namespace tcp {

View File

@ -424,9 +424,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline(OpPassManager &pm) {
// pass that checks no !shape.shape types left. // pass that checks no !shape.shape types left.
pm.addPass(createLowerRankedShapesPass()); pm.addPass(createLowerRankedShapesPass());
// Run a final canonicalization pass to delete dead // Run a final canonicalization pass to delete dead
// `tcp.shape_from_extents` ops. // `shape.from_extents` ops.
// This is needed for correctness, since we can't currently lower that op // This is needed for correctness, since we can't currently lower that op
// to LLVM, since we don't have a runtime representation of `!shape.shape`. // to LLVM, since we don't have a runtime representation of `!shape.shape`.
// TODO: Change LowerRankedShapes to delete these ops itself. // TODO: Change LowerRankedShapes to delete these ops itself.

View File

@ -27,8 +27,8 @@ public:
matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands, matchAndRewrite(shape::BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
shape::BroadcastOp::OperandAdaptor adaptor(operands); shape::BroadcastOp::OperandAdaptor adaptor(operands);
auto lhs = adaptor.lhs().getDefiningOp<tcp::ShapeFromExtentsOp>(); auto lhs = adaptor.lhs().getDefiningOp<shape::FromExtentsOp>();
auto rhs = adaptor.rhs().getDefiningOp<tcp::ShapeFromExtentsOp>(); auto rhs = adaptor.rhs().getDefiningOp<shape::FromExtentsOp>();
if (!lhs || !rhs) if (!lhs || !rhs)
return rewriter.notifyMatchFailure(op, "operands not converted"); return rewriter.notifyMatchFailure(op, "operands not converted");
// Establish invariant that rank(lhs) >= rank(rhs). // Establish invariant that rank(lhs) >= rank(rhs).
@ -66,7 +66,7 @@ public:
createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent); createAbortIfIllegalBroadcastExtent(lhsExtent, resultExtent);
createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent); createAbortIfIllegalBroadcastExtent(rhsExtent, resultExtent);
} }
rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, resultExtents); rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, resultExtents);
return success(); return success();
} }
}; };
@ -87,7 +87,7 @@ public:
matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands, matchAndRewrite(tcp::GetExtentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
tcp::GetExtentOp::OperandAdaptor adaptor(operands); tcp::GetExtentOp::OperandAdaptor adaptor(operands);
auto fromExtents = adaptor.shape().getDefiningOp<tcp::ShapeFromExtentsOp>(); auto fromExtents = adaptor.shape().getDefiningOp<shape::FromExtentsOp>();
if (!fromExtents) if (!fromExtents)
return rewriter.notifyMatchFailure(op, "not a from_extents op"); return rewriter.notifyMatchFailure(op, "not a from_extents op");
int64_t dim = op.dim().getLimitedValue(); int64_t dim = op.dim().getLimitedValue();
@ -116,14 +116,14 @@ public:
// Basic invariant of this pass: // Basic invariant of this pass:
// Every def of a !shape.shape type is replaced with a // Every def of a !shape.shape type is replaced with a
// `tcp.shape_from_extents` op. // `shape.from_extents` op.
// When converting an op, look for the `tcp.shape_from_extents` op that // When converting an op, look for the `shape.from_extents` op that
// defined all operands, then do a computation on the extents (i.e. // defined all operands, then do a computation on the extents (i.e.
// operands to the `tcp.shape_from_extents` op) and produce a // operands to the `shape.from_extents` op) and produce a
// `tcp.shape_from_extents` op. // `shape.from_extents` op.
// //
// We expect that previous passes have inserted a "root" set of // We expect that previous passes have inserted a "root" set of
// tcp::ShapeFromExtentsOp's that allow this process to get started. // shape::FromExtentsOp's that allow this process to get started.
// //
// We then use this to resolve get_extent ops by using a rewrite // We then use this to resolve get_extent ops by using a rewrite
// `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in // `get_extent(from_extents(x1,x2,x3), N) -> xN`, which should apply in
@ -138,7 +138,7 @@ public:
// //
// TODO: This approach doesn't naively work with control flow. // TODO: This approach doesn't naively work with control flow.
// In the presence of non-cyclic control flow, we can just generalize the // In the presence of non-cyclic control flow, we can just generalize the
// `getDefiningOp<tcp::ShapeFromExtentsOp>()` calls into something that will // `getDefiningOp<shape::FromExtentsOp>()` calls into something that will
// look through block arguments and rewrite "phi of shapes -> phi of extents". // look through block arguments and rewrite "phi of shapes -> phi of extents".
// In the presence of cyclic control flow, we need to somehow resolve the // In the presence of cyclic control flow, we need to somehow resolve the
// ranks of use-def cycles ahead of time or optimistically assume that // ranks of use-def cycles ahead of time or optimistically assume that
@ -158,7 +158,7 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
target.addIllegalOp<shape::ShapeOfOp>(); target.addIllegalOp<shape::ShapeOfOp>();
target.addIllegalOp<shape::BroadcastOp>(); target.addIllegalOp<shape::BroadcastOp>();
target.addIllegalOp<tcp::GetExtentOp>(); target.addIllegalOp<tcp::GetExtentOp>();
target.addLegalOp<tcp::ShapeFromExtentsOp>(); target.addLegalOp<shape::FromExtentsOp>();
target.addLegalOp<tcp::AbortIfOp>(); target.addLegalOp<tcp::AbortIfOp>();
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
target.addIllegalOp<tcp::ShapeObserveErrorOp>(); target.addIllegalOp<tcp::ShapeObserveErrorOp>();

View File

@ -101,7 +101,7 @@ public:
SmallVector<Value, 6> extents; SmallVector<Value, 6> extents;
for (int i = 0, e = tensorType.getRank(); i < e; i++) for (int i = 0, e = tensorType.getRank(); i < e; i++)
extents.push_back(rewriter.create<DimOp>(op.getLoc(), rankedMemRef, i)); extents.push_back(rewriter.create<DimOp>(op.getLoc(), rankedMemRef, i));
rewriter.replaceOpWithNewOp<tcp::ShapeFromExtentsOp>(op, extents); rewriter.replaceOpWithNewOp<shape::FromExtentsOp>(op, extents);
return success(); return success();
} }
}; };
@ -148,7 +148,7 @@ class LowerToMemRefABI : public LowerToMemRefABIBase<LowerToMemRefABI> {
patterns.insert<LowerShapeOfOp>(context); patterns.insert<LowerShapeOfOp>(context);
target.addIllegalOp<shape::ShapeOfOp>(); target.addIllegalOp<shape::ShapeOfOp>();
target.addLegalOp<tcp::ShapeFromExtentsOp>(); target.addLegalOp<shape::FromExtentsOp>();
if (failed(applyPartialConversion(func, target, patterns))) { if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure(); return signalPassFailure();

View File

@ -5,8 +5,8 @@
func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) { func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index, index) {
// CHECK-NOT: shape.broadcast // CHECK-NOT: shape.broadcast
// CHECK-NOT: tcp.get_extent // CHECK-NOT: tcp.get_extent
%0 = "tcp.shape_from_extents"(%arg0, %arg1) : (index, index) -> !shape.shape %0 = shape.from_extents %arg0, %arg1
%1 = "tcp.shape_from_extents"(%arg2) : (index) -> !shape.shape %1 = shape.from_extents %arg2
%2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %2 = "shape.broadcast"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%e0 = tcp.get_extent %2, 0 %e0 = tcp.get_extent %2, 0
%e1 = tcp.get_extent %2, 1 %e1 = tcp.get_extent %2, 1
@ -16,7 +16,7 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
// CHECK-LABEL: func @erase_shape_observe_error // CHECK-LABEL: func @erase_shape_observe_error
func @erase_shape_observe_error(%arg0: index) { func @erase_shape_observe_error(%arg0: index) {
// CHECK-NOT tcp.shape_observe_error // CHECK-NOT tcp.shape_observe_error
%0 = "tcp.shape_from_extents"(%arg0) : (index) -> !shape.shape %0 = shape.from_extents %arg0
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none "tcp.shape_observe_error"(%0) : (!shape.shape) -> none
return return
} }

View File

@ -12,7 +12,7 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[VAL_2:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32> // CHECK: %[[VAL_2:.*]] = memref_cast %[[VAL_1]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[VAL_3:.*]] = dim %[[VAL_2]], 0 : memref<?xf32> // CHECK: %[[VAL_3:.*]] = dim %[[VAL_2]], 0 : memref<?xf32>
// CHECK: %[[VAL_4:.*]] = "tcp.shape_from_extents"(%[[VAL_3]]) : (index) -> !shape.shape // CHECK: %[[VAL_4:.*]] = shape.from_extents %[[VAL_3]]
%shape = shape.shape_of %arg0 : tensor<?xf32> %shape = shape.shape_of %arg0 : tensor<?xf32>
// CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32> // CHECK: %[[VAL_5:.*]] = tcp.alloc_memref %[[VAL_4]] : memref<?xf32>