Rename tcp.abort_if to tcp.shape_observe_error

This more clearly captures its semantics as a structural "observer" of
code that we currently mark as NoSideEffect but eventually lowers to
eager error handling code.

Also, update LowerRankedShapes to erase it, now that the layering here
is clear. That pass reifies the eager error handling code, so the need
for the dummy op to keep things alive isn't needed.

With this change, we are now ready to start lowering to LLVM!
This is the current print-ir-after-all from e2e-lowering-pipeline:
https://reviews.llvm.org/P8221
pull/1/head
Sean Silva 2020-05-18 13:35:25 -07:00
parent 9191efdfc1
commit be1971c4fc
5 changed files with 40 additions and 9 deletions

View File

@ -63,13 +63,18 @@ Allocates a memref of the given shape.
// This op is also not correctly modeled right now, since it itself doesn't
// produce the error in practice. The ops like shape.broadcast itself, when
// lowered, immediately produce errors.
// Right now, it's more of an "observe_error" which just keeps NoSideEffect
// shape ops alive.
def TCP_AbortIfErrorOp : TCP_Op<"abort_if_error",
// TODO: This should eventually be moved to a shape dialect.
def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Aborts the program if the argument is an error shape.";
let summary = "Observes the fact that a shape might be an error.";
let description = [{
Aborts the program if its `shape` argument is an error shape.
This op is a structural placeholder that captures a shape such that it
is not erased. This will keep around shape computations that are later
lowered into eager error handling code.
The interaction of this op, especially with control flow and side
effecting ops, is not very well-defined, and needs to be worked
on/redesigned.
}];
let arguments = (ins Shape_ShapeType:$shape);
// TODO: ODS seems to create redeclared class members if we remove this,

View File

@ -33,7 +33,7 @@ public:
Value rhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.rhs());
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
rewriter.create<tcp::AbortIfErrorOp>(op.getLoc(), broadcastedShape);
rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
// TODO: It's annoying to do the dynamic broadcast above then
// do the static transfer function here. Would be nice if they could
// somehow be unified.

View File

@ -14,10 +14,10 @@ using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::tcp;
//===----------------------------------------------------------------------===//
// AbortIfErrorOp
// ShapeObserveErrorOp
//===----------------------------------------------------------------------===//
LogicalResult AbortIfErrorOp::inferReturnTypes(
LogicalResult ShapeObserveErrorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {

View File

@ -97,6 +97,23 @@ public:
};
} // namespace
namespace {
// Now that we have lowered ranked shapes, which reifies the eager
// error-handling code, the tcp::ShapeObserveErrorOp's are no longer
// needed.
class EraseShapeObserveErrorOp
: public OpConversionPattern<tcp::ShapeObserveErrorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
} // namespace
// Basic invariant of this pass:
// Every def of a !shape.shape type is replaced with a
// `tcp.shape_from_extents` op.
@ -136,14 +153,15 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
OwningRewritePatternList patterns;
patterns.insert<LowerShapeBroadcastOp>(context);
patterns.insert<LowerShapeGetExtentOp>(context);
patterns.insert<EraseShapeObserveErrorOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<shape::ShapeOfOp>();
target.addIllegalOp<shape::BroadcastOp>();
target.addIllegalOp<tcp::GetExtentOp>();
target.addLegalOp<tcp::ShapeFromExtentsOp>();
target.addLegalOp<tcp::RtGetTensorExtentOp>();
target.addLegalOp<tcp::AbortIfOp>();
target.addLegalDialect<StandardOpsDialect>();
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}

View File

@ -12,3 +12,11 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
%e1 = tcp.get_extent %2, 1
return %e0, %e1 : index, index
}
// CHECK-LABEL: func @erase_shape_observe_error
func @erase_shape_observe_error(%arg0: index) {
// CHECK-NOT tcp.shape_observe_error
%0 = "tcp.shape_from_extents"(%arg0) : (index) -> !shape.shape
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
return
}