diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 826d191c6..15ab80bec 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -408,7 +408,8 @@ public: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns .add>( typeConverter, context); @@ -418,9 +419,15 @@ public: patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); + patterns.add< + ConvertAtenIntComparisonOp>( + typeConverter, context); patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); + patterns.add< + ConvertAtenIntComparisonOp>( + typeConverter, context); target.addIllegalOp(); patterns.add< diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 933031e16..6b44426f5 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -79,6 +79,33 @@ func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo return %0 : !torch.bool } + +// CHECK-LABEL: func.func @torch.aten.lt.int( +// CHECK-SAME: %[[LHS:.*]]: !torch.int, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[LHS_I64]], %[[RHS_I64]] : i64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func.func @torch.aten.lt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.lt.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.le.int( +// CHECK-SAME: %[[LHS:.*]]: !torch.int, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[LHS_I64]], %[[RHS_I64]] : i64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func.func @torch.aten.le.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.le.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],f32>