mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add folder for aten.any.bool (#2388)
* update * update * update * update * update * update * updatepull/2426/head
parent
1682b540bf
commit
17d02811d5
|
@ -10919,6 +10919,7 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [
|
||||
|
|
|
@ -1434,6 +1434,24 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenAnyBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
|
||||
auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
|
||||
if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
|
||||
return nullptr;
|
||||
// If any operand is a constant true, return true.
|
||||
for (auto operand : inputConstruct.getOperands()) {
|
||||
bool b;
|
||||
if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
|
||||
return getI1IntegerAttr(getContext(), true);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenFloatScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -646,7 +646,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True)
|
||||
emit("aten::insert.t : (t[], int, t) -> ()")
|
||||
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
|
||||
emit("aten::any.bool : (bool[]) -> (bool)")
|
||||
emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True)
|
||||
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
|
||||
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
|
||||
|
|
|
@ -259,27 +259,6 @@ func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float {
|
|||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.any.bool() -> !torch.bool {
|
||||
// CHECK: %[[CST_FALSE:.*]] = arith.constant false
|
||||
// CHECK: %[[FALSE:.*]] = torch_c.from_i1 %[[CST_FALSE]]
|
||||
// CHECK: %[[CST_TRUE:.*]] = arith.constant true
|
||||
// CHECK: %[[TRUE:.*]] = torch_c.from_i1 %[[CST_TRUE]]
|
||||
// CHECK: %[[INPUT:.*]] = torch.prim.ListConstruct %[[FALSE]], %[[TRUE]], %[[FALSE]] : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
||||
// CHECK: %[[TMP1:.*]] = torch_c.to_i1 %[[FALSE]]
|
||||
// CHECK: %[[TMP2:.*]] = torch_c.to_i1 %[[TRUE]]
|
||||
// CHECK: %[[TMP3:.*]] = torch_c.to_i1 %[[FALSE]]
|
||||
// CHECK: %[[CMP:.*]] = arith.ori %[[TMP1]], %[[TMP2]] : i1
|
||||
// CHECK: %[[CMP_RESULT:.*]] = arith.ori %[[CMP]], %[[TMP3]] : i1
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CMP_RESULT]]
|
||||
// CHECK: return %[[RESULT]] : !torch.bool
|
||||
func.func @torch.aten.any.bool() -> !torch.bool {
|
||||
%false = torch.constant.bool false
|
||||
%true = torch.constant.bool true
|
||||
%input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
||||
%0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.Bool.float(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool {
|
||||
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
|
||||
|
|
|
@ -2087,3 +2087,14 @@ func.func @torch.aten.add$fold() -> !torch.float {
|
|||
%0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.any.bool$fold() -> !torch.bool {
|
||||
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[CST_TRUE]] : !torch.bool
|
||||
func.func @torch.aten.any.bool$fold() -> !torch.bool {
|
||||
%false = torch.constant.bool false
|
||||
%true = torch.constant.bool true
|
||||
%input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
||||
%0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool
|
||||
return %0 : !torch.bool
|
||||
}
|
Loading…
Reference in New Issue