diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 8a6aeec22..a07b9e4ef 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -931,4 +931,33 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
   let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
 }
 
+// To handle runtime assertions, TorchScript provides the `torch._assert`
+// operation. However, the TorchScript compiler lowers each `torch._assert`
+// into control flow of the form:
+//
+// %cond = "torch.aten.Bool.Tensor"(%0) : (!torch.tensor) -> !torch.bool
+// "torch.prim.If"(%cond) ({
+//   "torch.prim.If.yield"() : () -> ()
+// }, {
+//   "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
+//   "torch.prim.If.yield"() : () -> ()
+// }) : (!torch.bool) -> ()
+//
+// The `torch.runtime.assert` op is added to express the same check as a single
+// operation, avoiding the unnecessary branches and making it easier to insert
+// runtime assertions when constructing IR.
+def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+  ]> {
+  let summary = "Runtime Assertion";
+  let arguments = (ins
+    Torch_BoolType:$condition,
+    StrAttr:$message
+  );
+  let results = (outs
+  );
+  let assemblyFormat = "$condition `,` $message attr-dict";
+}
+
 #endif // TORCH_OPS
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index 069db650d..861b368ba 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -45,6 +45,20 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AssertOp>(op, adaptor.condition(),
+                                          adaptor.message());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenOp, typename BinOp>
 class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
@@ -173,6 +187,8 @@ public:
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenDimOp>();
     patterns.add<ConvertAtenDimOp>(typeConverter, context);
+    target.addIllegalOp<RuntimeAssertOp>();
+    patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
     target.addIllegalOp<AtenNeIntOp>();
     patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
     target.addIllegalOp<AtenGtIntOp>();
diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir
index 5cc4af52b..53cba8907 100644
--- a/test/Conversion/TorchToStd/basic.mlir
+++ b/test/Conversion/TorchToStd/basic.mlir
@@ -13,6 +13,20 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
   return %0 : !torch.int
 }
 
+// CHECK-LABEL:   func @torch.runtime.assert(
+// CHECK-SAME:                               %[[X:.*]]: !torch.int,
+// CHECK-SAME:                               %[[Y:.*]]: !torch.int) {
+// CHECK:           %[[X_I64:.*]] = torch_c.to_i64 %[[X]]
+// CHECK:           %[[Y_I64:.*]] = torch_c.to_i64 %[[Y]]
+// CHECK:           %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64
+// CHECK:           assert %[[CMP]], "x must not be equal to y"
+// CHECK:           return
+func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
+  %0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
+  torch.runtime.assert %0, "x must not be equal to y"
+  return
+}
+
 // CHECK-LABEL:   func @torch.aten.ne.int(
 // CHECK-SAME:                            %[[LHS:.*]]: !torch.int,
 // CHECK-SAME:                            %[[RHS:.*]]: !torch.int) -> !torch.bool {