[TORCH][MLIR] Add run-time assert support in Torch-dialect

- This commit adds `aten.assert` op in the Torch dialect.
- The `aten.assert` op is lowered to `mlir::Assert` op.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/571/head
Gaurav Shukla 2022-02-08 19:48:08 +05:30 committed by Yi Zhang
parent e09e2cbe70
commit bd177bdfc7
3 changed files with 59 additions and 0 deletions

View File

@ -931,4 +931,33 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
}
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
// But TS compiler introduces control flow for `torch._assert` operation. The
// `torch._assert` would introduce control flow like:
//
// %cond = "torch.aten.Bool.Tensor"(%0) : (!torch.tensor) -> !torch.bool
// "torch.prim.If"(%cond) ({
// "torch.prim.If.yield"() : () -> ()
// }, {
// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
// "torch.prim.If.yield"() : () -> ()
// }) : (!torch.bool) -> ()
//
// This new operation `torch.runtime.assert` is added to simplify the IR control
// flow by avoiding unnecessary branches. It also makes insertion of the runtime
// assert in the source code easier.
def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
AllowsTypeRefinement,
HasValueSemantics,
]> {
let summary = "Runtime Assertion";
let arguments = (ins
Torch_BoolType:$condition,
StrAttr:$message
);
let results = (outs
);
let assemblyFormat = "$condition `,` $message attr-dict";
}
#endif // TORCH_OPS

View File

@ -45,6 +45,20 @@ public:
};
} // namespace
namespace {
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AssertOp>(op, adaptor.condition(),
adaptor.message());
return success();
}
};
} // namespace
namespace {
template <typename AtenOp, typename BinOp>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
@ -173,6 +187,8 @@ public:
RewritePatternSet patterns(context);
target.addIllegalOp<AtenDimOp>();
patterns.add<ConvertAtenDimOp>(typeConverter, context);
target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
target.addIllegalOp<AtenNeIntOp>();
patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
target.addIllegalOp<AtenGtIntOp>();

View File

@ -13,6 +13,20 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
return %0 : !torch.int
}
// CHECK-LABEL: func @torch.runtime.assert(
// CHECK-SAME: %[[X:.*]]: !torch.int,
// CHECK-SAME: %[[Y:.*]]: !torch.int) {
// CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]]
// CHECK: %[[Y_I64:.*]] = torch_c.to_i64 %[[Y]]
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64
// CHECK: assert %[[CMP]], "x must not be equal to y"
// CHECK: return
func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
torch.runtime.assert %0, "x must not be equal to y"
return
}
// CHECK-LABEL: func @torch.aten.ne.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {