[TorchToArith] add lowerings for some scalar bool binary ops (#3823)

Added lit tests since these scalar operations don't trace well through
the `fx_importer` route.

`XOR` and `NE` are equivalent binary operators, so `aten.ne.bool` is
lowered to `arith.xori`.
pull/3848/head
zjgarvey 2024-11-01 10:40:20 -05:00 committed by GitHub
parent 3dbeda9082
commit a82ba1c422
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 0 deletions

View File

@ -496,6 +496,13 @@ public:
patterns.add<ConvertAtenBinaryOp<PrimMinIntOp, arith::MinSIOp>>(
typeConverter, context);
target.addIllegalOp<AtenCeilFloatOp>();
target.addIllegalOp<Aten__Or__BoolOp, Aten__And__BoolOp, AtenNeBoolOp>();
patterns.add<ConvertAtenBinaryOp<Aten__Or__BoolOp, arith::OrIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<Aten__And__BoolOp, arith::AndIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenNeBoolOp, arith::XOrIOp>>(
typeConverter, context);
patterns
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
typeConverter, context);

View File

@ -40,6 +40,48 @@ func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.ne.bool(
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
// CHECK: %[[XOR:.*]] = arith.xori %[[LHS]], %[[RHS]] : i1
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[XOR]]
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
func.func @torch.aten.ne.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
%0 = torch.aten.ne.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.__and__.bool(
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
// CHECK: %[[AND:.*]] = arith.andi %[[LHS]], %[[RHS]] : i1
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[AND]]
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
func.func @torch.aten.__and__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
%0 = torch.aten.__and__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.__or__.bool(
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
// CHECK: %[[OR:.*]] = arith.ori %[[LHS]], %[[RHS]] : i1
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[OR]]
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
func.func @torch.aten.__or__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
%0 = torch.aten.__or__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.eq.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {