mirror of https://github.com/llvm/torch-mlir
[TorchToArith] add lowerings for some scalar bool binary ops (#3823)
Added lit tests since these scalar operations don't trace well through the `fx_importer` route. `XOR` and `NE` are equivalent binary operators, so `aten.ne.bool` is lowered to `arith.xori`.pull/3848/head
parent
3dbeda9082
commit
a82ba1c422
|
@ -496,6 +496,13 @@ public:
|
|||
patterns.add<ConvertAtenBinaryOp<PrimMinIntOp, arith::MinSIOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenCeilFloatOp>();
|
||||
target.addIllegalOp<Aten__Or__BoolOp, Aten__And__BoolOp, AtenNeBoolOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<Aten__Or__BoolOp, arith::OrIOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<Aten__And__BoolOp, arith::AndIOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenNeBoolOp, arith::XOrIOp>>(
|
||||
typeConverter, context);
|
||||
patterns
|
||||
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
|
||||
typeConverter, context);
|
||||
|
|
|
@ -40,6 +40,48 @@ func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo
|
|||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ne.bool(
|
||||
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
|
||||
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
|
||||
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
|
||||
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
|
||||
// CHECK: %[[XOR:.*]] = arith.xori %[[LHS]], %[[RHS]] : i1
|
||||
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[XOR]]
|
||||
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
|
||||
func.func @torch.aten.ne.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
|
||||
%0 = torch.aten.ne.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.__and__.bool(
|
||||
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
|
||||
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
|
||||
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
|
||||
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
|
||||
// CHECK: %[[AND:.*]] = arith.andi %[[LHS]], %[[RHS]] : i1
|
||||
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[AND]]
|
||||
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
|
||||
func.func @torch.aten.__and__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
|
||||
%0 = torch.aten.__and__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.__or__.bool(
|
||||
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
|
||||
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
|
||||
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
|
||||
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
|
||||
// CHECK: %[[OR:.*]] = arith.ori %[[LHS]], %[[RHS]] : i1
|
||||
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[OR]]
|
||||
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
|
||||
func.func @torch.aten.__or__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
|
||||
%0 = torch.aten.__or__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.eq.int(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
|
||||
|
|
Loading…
Reference in New Issue