diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index 00b969f79..974c1c0d2 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -323,7 +323,7 @@ public: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns .add>( typeConverter, context); @@ -333,6 +333,9 @@ public: patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); + patterns.add< + ConvertAtenIntComparisonOp>( + typeConverter, context); target.addIllegalOp(); patterns.add< diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index 8a626d962..52266cb9d 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -81,6 +81,29 @@ def GtIntModule_basic(module, tu: TestUtils): # ============================================================================== +class GeIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.ops.aten.ge(int(lhs), int(rhs)) + + +@register_test_case(module_factory=lambda: GeIntModule()) +def GeIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + + +# ============================================================================== + + class GeFloatModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir index f45388e4e..755c34f69 100644 --- a/test/Conversion/TorchToStd/basic.mlir +++ b/test/Conversion/TorchToStd/basic.mlir @@ -66,6 +66,19 @@ func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo return %0 : !torch.bool } +// CHECK-LABEL: func.func @torch.aten.ge.int( +// CHECK-SAME: %[[LHS:.*]]: !torch.int, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[CMP:.*]] = arith.cmpi sge, %[[LHS_I64]], %[[RHS_I64]] : i64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.ge.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],f32>