From 1f102cc400a72fa0926eb79f657e347af7c2d142 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 25 Apr 2022 19:07:43 +0530 Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.ge.float_int op This commit adds lowering of `aten.ge.float_int` op. Signed-Off By: Vivek Khandelwal --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 +++++++++++++++++++ lib/Conversion/TorchToStd/TorchToStd.cpp | 13 ++++++---- .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../test_suite/scalar_comparison.py | 20 ++++++++++++++++ test/Conversion/TorchToStd/basic.mlir | 14 +++++++++++ 5 files changed, 68 insertions(+), 4 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5d61bef5d..008240ff9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6592,6 +6592,30 @@ def Torch_AtenLtFloatIntOp : Torch_Op<"aten.lt.float_int", [ }]; } +def Torch_AtenGeFloatIntOp : Torch_Op<"aten.ge.float_int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ge.float_int : (float, int) -> (bool)`"; + let arguments = (ins + Torch_FloatType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGeFloatIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGeFloatIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index b4e89410f..e207c99c2 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" @@ -94,7 +95,7 @@ public: } // namespace namespace { -// Lowers aten float comparison ops. +// Lowers aten float and float_int comparison ops. template class ConvertAtenFloatComparisonOp : public OpConversionPattern { public: @@ -103,8 +104,9 @@ public: matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, Pred, adaptor.a(), - adaptor.b()); + Value lhs = adaptor.a(), rhs = adaptor.b(); + rhs = convertScalarToDtype(rewriter, op.getLoc(), rhs, lhs.getType()); + rewriter.replaceOpWithNewOp(op, Pred, lhs, rhs); return success(); } }; @@ -209,10 +211,13 @@ public: patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add< ConvertAtenFloatComparisonOp>( typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 7972b7435..d67e4bc19 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -495,6 +495,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::ge.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float : (float, float) -> (bool)", has_folder=True) emit("aten::lt.float_int : (float, int) -> (bool)") + emit("aten::ge.float_int : (float, int) -> (bool)") emit("aten::__and__.bool : (bool, bool) -> (bool)") emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index f23c375b9..ff60e3821 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -88,3 +88,23 @@ class GeFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: GeFloatModule()) def GeFloatModule_basic(module, tu: TestUtils): module.forward(torch.randn(()), torch.randn(())) + +# ============================================================================== + +class GeFloatIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) >= int(rhs) + + +@register_test_case(module_factory=lambda: GeFloatIntModule()) +def GeFloatIntModule_basic(module, tu: TestUtils): + module.forward(torch.randn(()), torch.randint(-100, 100, ())) diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir index 1f9dc2dac..3c9d16db2 100644 --- a/test/Conversion/TorchToStd/basic.mlir +++ b/test/Conversion/TorchToStd/basic.mlir @@ -179,3 +179,17 @@ func @torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.bo %0 = torch.aten.ge.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.bool return %0 : !torch.bool } + +// CHECK-LABEL: func @torch.aten.ge.float_int( +// CHECK-SAME: %[[LHS:.*]]: !torch.float, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 +// CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.ge.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.bool + return %0 : !torch.bool +}