mirror of https://github.com/llvm/torch-mlir
Rename tcp.abort_if to tcp.shape_observe_error
This more clearly captures its semantics as a structural "observer" of code that we currently mark as NoSideEffect but eventually lowers to eager error handling code. Also, update LowerRankedShapes to erase it, now that the layering here is clear. That pass reifies the eager error handling code, so the need for the dummy op to keep things alive isn't needed. With this change, we are now ready to start lowering to LLVM! This is the current print-ir-after-all from e2e-lowering-pipeline: https://reviews.llvm.org/P8221pull/1/head
parent
9191efdfc1
commit
be1971c4fc
|
@ -63,13 +63,18 @@ Allocates a memref of the given shape.
|
|||
// This op is also not correctly modeled right now, since it itself doesn't
|
||||
// produce the error in practice. The ops like shape.broadcast itself, when
|
||||
// lowered, immediately produce errors.
|
||||
// Right now, it's more of an "observe_error" which just keeps NoSideEffect
|
||||
// shape ops alive.
|
||||
def TCP_AbortIfErrorOp : TCP_Op<"abort_if_error",
|
||||
// TODO: This should eventually be moved to a shape dialect.
|
||||
def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Aborts the program if the argument is an error shape.";
|
||||
let summary = "Observes the fact that a shape might be an error.";
|
||||
let description = [{
|
||||
Aborts the program if its `shape` argument is an error shape.
|
||||
This op is a structural placeholder that captures a shape such that it
|
||||
is not erased. This will keep around shape computations that are later
|
||||
lowered into eager error handling code.
|
||||
|
||||
The interaction of this op, especially with control flow and side
|
||||
effecting ops, is not very well-defined, and needs to be worked
|
||||
on/redesigned.
|
||||
}];
|
||||
let arguments = (ins Shape_ShapeType:$shape);
|
||||
// TODO: ODS seems to create redeclared class members if we remove this,
|
||||
|
|
|
@ -33,7 +33,7 @@ public:
|
|||
Value rhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.rhs());
|
||||
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
|
||||
op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
|
||||
rewriter.create<tcp::AbortIfErrorOp>(op.getLoc(), broadcastedShape);
|
||||
rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
|
||||
// TODO: It's annoying to do the dynamic broadcast above then
|
||||
// do the static transfer function here. Would be nice if they could
|
||||
// somehow be unified.
|
||||
|
|
|
@ -14,10 +14,10 @@ using namespace mlir::NPCOMP;
|
|||
using namespace mlir::NPCOMP::tcp;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AbortIfErrorOp
|
||||
// ShapeObserveErrorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AbortIfErrorOp::inferReturnTypes(
|
||||
LogicalResult ShapeObserveErrorOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
|
|
|
@ -97,6 +97,23 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Now that we have lowered ranked shapes, which reifies the eager
|
||||
// error-handling code, the tcp::ShapeObserveErrorOp's are no longer
|
||||
// needed.
|
||||
class EraseShapeObserveErrorOp
|
||||
: public OpConversionPattern<tcp::ShapeObserveErrorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Basic invariant of this pass:
|
||||
// Every def of a !shape.shape type is replaced with a
|
||||
// `tcp.shape_from_extents` op.
|
||||
|
@ -136,14 +153,15 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
|
|||
OwningRewritePatternList patterns;
|
||||
patterns.insert<LowerShapeBroadcastOp>(context);
|
||||
patterns.insert<LowerShapeGetExtentOp>(context);
|
||||
patterns.insert<EraseShapeObserveErrorOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<shape::ShapeOfOp>();
|
||||
target.addIllegalOp<shape::BroadcastOp>();
|
||||
target.addIllegalOp<tcp::GetExtentOp>();
|
||||
target.addLegalOp<tcp::ShapeFromExtentsOp>();
|
||||
target.addLegalOp<tcp::RtGetTensorExtentOp>();
|
||||
target.addLegalOp<tcp::AbortIfOp>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addIllegalOp<tcp::ShapeObserveErrorOp>();
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -12,3 +12,11 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
|
|||
%e1 = tcp.get_extent %2, 1
|
||||
return %e0, %e1 : index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @erase_shape_observe_error
|
||||
func @erase_shape_observe_error(%arg0: index) {
|
||||
// CHECK-NOT tcp.shape_observe_error
|
||||
%0 = "tcp.shape_from_extents"(%arg0) : (index) -> !shape.shape
|
||||
"tcp.shape_observe_error"(%0) : (!shape.shape) -> none
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue