mirror of https://github.com/llvm/torch-mlir
Rename tcp.abort_if_error to tcp.shape_observe_error
This more clearly captures its semantics as a structural "observer" of code that we currently mark as NoSideEffect but that eventually lowers to eager error-handling code. Also, update LowerRankedShapes to erase it, now that the layering here is clear: that pass reifies the eager error-handling code, so the dummy op is no longer needed to keep the shape computations alive. With this change, we are now ready to start lowering to LLVM! This is the current print-ir-after-all from e2e-lowering-pipeline: https://reviews.llvm.org/P8221
parent 9191efdfc1
commit be1971c4fc
@@ -63,13 +63,18 @@ Allocates a memref of the given shape.
 // This op is also not correctly modeled right now, since it itself doesn't
 // produce the error in practice. The ops like shape.broadcast itself, when
 // lowered, immediately produce errors.
-// Right now, it's more of an "observe_error" which just keeps NoSideEffect
-// shape ops alive.
-def TCP_AbortIfErrorOp : TCP_Op<"abort_if_error",
+// TODO: This should eventually be moved to a shape dialect.
+def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error",
     [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "Aborts the program if the argument is an error shape.";
+  let summary = "Observes the fact that a shape might be an error.";
   let description = [{
-    Aborts the program if its `shape` argument is an error shape.
+    This op is a structural placeholder that captures a shape such that it
+    is not erased. This will keep around shape computations that are later
+    lowered into eager error handling code.
+
+    The interaction of this op, especially with control flow and side
+    effecting ops, is not very well-defined, and needs to be worked
+    on/redesigned.
   }];
   let arguments = (ins Shape_ShapeType:$shape);
   // TODO: ODS seems to create redeclared class members if we remove this,
@@ -33,7 +33,7 @@ public:
     Value rhsShape = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.rhs());
     Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
         op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
-    rewriter.create<tcp::AbortIfErrorOp>(op.getLoc(), broadcastedShape);
+    rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
     // TODO: It's annoying to do the dynamic broadcast above then
     // do the static transfer function here. Would be nice if they could
     // somehow be unified.
@@ -14,10 +14,10 @@ using namespace mlir::NPCOMP;
 using namespace mlir::NPCOMP::tcp;

 //===----------------------------------------------------------------------===//
-// AbortIfErrorOp
+// ShapeObserveErrorOp
 //===----------------------------------------------------------------------===//

-LogicalResult AbortIfErrorOp::inferReturnTypes(
+LogicalResult ShapeObserveErrorOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
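The hunk above shows only the signature of the renamed type-inference hook. Since the test at the end of this diff shows the op producing a `none` value, the body presumably just reports a single NoneType result. A minimal sketch of such a body (an assumption for illustration, not code taken from this commit):

  // Hypothetical body sketch: report one NoneType result, matching the
  // "(!shape.shape) -> none" form seen in the test below.
  LogicalResult ShapeObserveErrorOp::inferReturnTypes(
      MLIRContext *context, Optional<Location> location, ValueRange operands,
      DictionaryAttr attributes, RegionRange regions,
      SmallVectorImpl<Type> &inferredReturnTypes) {
    inferredReturnTypes.push_back(NoneType::get(context));
    return success();
  }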
@@ -97,6 +97,23 @@ public:
 };
 } // namespace

+namespace {
+// Now that we have lowered ranked shapes, which reifies the eager
+// error-handling code, the tcp::ShapeObserveErrorOp's are no longer
+// needed.
+class EraseShapeObserveErrorOp
+    : public OpConversionPattern<tcp::ShapeObserveErrorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tcp::ShapeObserveErrorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
 // Basic invariant of this pass:
 // Every def of a !shape.shape type is replaced with a
 // `tcp.shape_from_extents` op.
@@ -136,14 +153,15 @@ class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
     OwningRewritePatternList patterns;
     patterns.insert<LowerShapeBroadcastOp>(context);
     patterns.insert<LowerShapeGetExtentOp>(context);
+    patterns.insert<EraseShapeObserveErrorOp>(context);
     ConversionTarget target(*context);
     target.addIllegalOp<shape::ShapeOfOp>();
     target.addIllegalOp<shape::BroadcastOp>();
     target.addIllegalOp<tcp::GetExtentOp>();
     target.addLegalOp<tcp::ShapeFromExtentsOp>();
-    target.addLegalOp<tcp::RtGetTensorExtentOp>();
     target.addLegalOp<tcp::AbortIfOp>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addIllegalOp<tcp::ShapeObserveErrorOp>();
     if (failed(applyPartialConversion(func, target, patterns))) {
       return signalPassFailure();
     }
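For orientation, this is roughly how the snippet above sits inside the pass. The surrounding boilerplate (runOnFunction, getFunction, getContext) is assumed from the usual MLIR function-pass structure of the time and is not part of this diff; only the pattern and target setup mirrors the hunk above. The key interaction: because tcp.shape_observe_error is now marked illegal, applyPartialConversion fails unless EraseShapeObserveErrorOp removes every instance of it.

  // Sketch of the assumed surrounding pass structure (not shown in this
  // commit); the pattern and target setup is copied from the diff above.
  class LowerRankedShapes : public LowerRankedShapesBase<LowerRankedShapes> {
    void runOnFunction() override {
      auto func = getFunction();
      auto *context = &getContext();

      OwningRewritePatternList patterns;
      patterns.insert<LowerShapeBroadcastOp>(context);
      patterns.insert<LowerShapeGetExtentOp>(context);
      patterns.insert<EraseShapeObserveErrorOp>(context);

      ConversionTarget target(*context);
      target.addIllegalOp<shape::ShapeOfOp>();
      target.addIllegalOp<shape::BroadcastOp>();
      target.addIllegalOp<tcp::GetExtentOp>();
      target.addLegalOp<tcp::ShapeFromExtentsOp>();
      // tcp.abort_if stays legal, presumably because the reified eager
      // error handling is expressed with it.
      target.addLegalOp<tcp::AbortIfOp>();
      target.addLegalDialect<StandardOpsDialect>();
      // Marking the observer op illegal forces EraseShapeObserveErrorOp to
      // remove every instance, or the conversion (and the pass) fails.
      target.addIllegalOp<tcp::ShapeObserveErrorOp>();

      if (failed(applyPartialConversion(func, target, patterns)))
        return signalPassFailure();
    }
  };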
@@ -12,3 +12,11 @@ func @broadcast_rank2_rank1(%arg0: index, %arg1: index, %arg2: index) -> (index,
   %e1 = tcp.get_extent %2, 1
   return %e0, %e1 : index, index
 }
+
+// CHECK-LABEL: func @erase_shape_observe_error
+func @erase_shape_observe_error(%arg0: index) {
+  // CHECK-NOT tcp.shape_observe_error
+  %0 = "tcp.shape_from_extents"(%arg0) : (index) -> !shape.shape
+  "tcp.shape_observe_error"(%0) : (!shape.shape) -> none
+  return
+}