[Torch Dialect] add folder for aten.any.bool (#2388)

* update

* update

* update

* update

* update

* update

* update
pull/2426/head
JianzheXiao 2023-08-30 17:29:03 +08:00 committed by GitHub
parent 1682b540bf
commit 17d02811d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 22 deletions

View File

@ -10919,6 +10919,7 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [ def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [

View File

@ -1434,6 +1434,24 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenAnyBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
return nullptr;
// If any operand is a constant true, return true.
for (auto operand : inputConstruct.getOperands()) {
bool b;
if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
return getI1IntegerAttr(getContext(), true);
}
}
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenFloatScalarOp // AtenFloatScalarOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -646,7 +646,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True) emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True)
emit("aten::insert.t : (t[], int, t) -> ()") emit("aten::insert.t : (t[], int, t) -> ()")
emit("aten::ne.int_list : (int[], int[]) -> (bool)") emit("aten::ne.int_list : (int[], int[]) -> (bool)")
emit("aten::any.bool : (bool[]) -> (bool)") emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True)
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True) emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")

View File

@ -259,27 +259,6 @@ func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float {
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func.func @torch.aten.any.bool() -> !torch.bool {
// CHECK: %[[CST_FALSE:.*]] = arith.constant false
// CHECK: %[[FALSE:.*]] = torch_c.from_i1 %[[CST_FALSE]]
// CHECK: %[[CST_TRUE:.*]] = arith.constant true
// CHECK: %[[TRUE:.*]] = torch_c.from_i1 %[[CST_TRUE]]
// CHECK: %[[INPUT:.*]] = torch.prim.ListConstruct %[[FALSE]], %[[TRUE]], %[[FALSE]] : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
// CHECK: %[[TMP1:.*]] = torch_c.to_i1 %[[FALSE]]
// CHECK: %[[TMP2:.*]] = torch_c.to_i1 %[[TRUE]]
// CHECK: %[[TMP3:.*]] = torch_c.to_i1 %[[FALSE]]
// CHECK: %[[CMP:.*]] = arith.ori %[[TMP1]], %[[TMP2]] : i1
// CHECK: %[[CMP_RESULT:.*]] = arith.ori %[[CMP]], %[[TMP3]] : i1
// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CMP_RESULT]]
// CHECK: return %[[RESULT]] : !torch.bool
func.func @torch.aten.any.bool() -> !torch.bool {
%false = torch.constant.bool false
%true = torch.constant.bool true
%input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool
return %0 : !torch.bool
}
// CHECK-LABEL: func.func @torch.aten.Bool.float( // CHECK-LABEL: func.func @torch.aten.Bool.float(
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool { // CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool {
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]] // CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]

View File

@ -2087,3 +2087,14 @@ func.func @torch.aten.add$fold() -> !torch.float {
%0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float %0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float return %0 : !torch.float
} }
// CHECK-LABEL: func.func @torch.aten.any.bool$fold() -> !torch.bool {
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[CST_TRUE]] : !torch.bool
func.func @torch.aten.any.bool$fold() -> !torch.bool {
%false = torch.constant.bool false
%true = torch.constant.bool true
%input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool
return %0 : !torch.bool
}