mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add run-time assert support in Torch-dialect
- This commit adds `aten.assert` op in the Torch dialect. - The `aten.assert` op is lowered to `mlir::Assert` op. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/571/head
parent
e09e2cbe70
commit
bd177bdfc7
|
@ -931,4 +931,33 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
|
|||
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
|
||||
}
|
||||
|
||||
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
||||
// But TS compiler introduces control flow for `torch._assert` operation. The
|
||||
// `torch._assert` would introduce control flow like:
|
||||
//
|
||||
// %cond = "torch.aten.Bool.Tensor"(%0) : (!torch.tensor) -> !torch.bool
|
||||
// "torch.prim.If"(%cond) ({
|
||||
// "torch.prim.If.yield"() : () -> ()
|
||||
// }, {
|
||||
// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
|
||||
// "torch.prim.If.yield"() : () -> ()
|
||||
// }) : (!torch.bool) -> ()
|
||||
//
|
||||
// This new operation `torch.runtime.assert` is added to simplify the IR control
|
||||
// flow by avoiding unnecessary branches. It also makes insertion of the runtime
|
||||
// assert in the source code easier.
|
||||
def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
]> {
|
||||
let summary = "Runtime Assertion";
|
||||
let arguments = (ins
|
||||
Torch_BoolType:$condition,
|
||||
StrAttr:$message
|
||||
);
|
||||
let results = (outs
|
||||
);
|
||||
let assemblyFormat = "$condition `,` $message attr-dict";
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
|
@ -45,6 +45,20 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<AssertOp>(op, adaptor.condition(),
|
||||
adaptor.message());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename AtenOp, typename BinOp>
|
||||
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
|
||||
|
@ -173,6 +187,8 @@ public:
|
|||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<AtenDimOp>();
|
||||
patterns.add<ConvertAtenDimOp>(typeConverter, context);
|
||||
target.addIllegalOp<RuntimeAssertOp>();
|
||||
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNeIntOp>();
|
||||
patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenGtIntOp>();
|
||||
|
|
|
@ -13,6 +13,20 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
|
|||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.runtime.assert(
|
||||
// CHECK-SAME: %[[X:.*]]: !torch.int,
|
||||
// CHECK-SAME: %[[Y:.*]]: !torch.int) {
|
||||
// CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]]
|
||||
// CHECK: %[[Y_I64:.*]] = torch_c.to_i64 %[[Y]]
|
||||
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64
|
||||
// CHECK: assert %[[CMP]], "x must not be equal to y"
|
||||
// CHECK: return
|
||||
func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
|
||||
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.runtime.assert %0, "x must not be equal to y"
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ne.int(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||
|
|
Loading…
Reference in New Issue