From a82ba1c42282bff79d7b5f0a1c25601d265bc7f2 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:40:20 -0500 Subject: [PATCH] [TorchToArith] add lowerings for some scalar bool binary ops (#3823) Added lit tests since these scalar operations don't trace well through the `fx_importer` route. `XOR` and `NE` are equivalent binary operators, so `aten.ne.bool` is lowered to `arith.xori`. --- lib/Conversion/TorchToArith/TorchToArith.cpp | 7 ++++ test/Conversion/TorchToArith/basic.mlir | 42 ++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index a1af190e4..143b46694 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -496,6 +496,13 @@ public: patterns.add>( typeConverter, context); target.addIllegalOp(); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); patterns .add>( typeConverter, context); diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 3d9e9f22a..86ad4e972 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -40,6 +40,48 @@ func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo return %0 : !torch.bool } + +// CHECK-LABEL: func.func @torch.aten.ne.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[XOR:.*]] = arith.xori %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[XOR]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.ne.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.ne.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + + +// CHECK-LABEL: func.func @torch.aten.__and__.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[AND:.*]] = arith.andi %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[AND]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.__and__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.__and__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + + +// CHECK-LABEL: func.func @torch.aten.__or__.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[OR:.*]] = arith.ori %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[OR]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.__or__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.__or__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.eq.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {