mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add folder for aten.any.bool (#2388)
* update * update * update * update * update * update * updatepull/2426/head
parent
1682b540bf
commit
17d02811d5
|
@ -10919,6 +10919,7 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [
|
def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [
|
||||||
|
|
|
@ -1434,6 +1434,24 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenAnyBoolOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
|
||||||
|
return nullptr;
|
||||||
|
// If any operand is a constant true, return true.
|
||||||
|
for (auto operand : inputConstruct.getOperands()) {
|
||||||
|
bool b;
|
||||||
|
if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
|
||||||
|
return getI1IntegerAttr(getContext(), true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenFloatScalarOp
|
// AtenFloatScalarOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -646,7 +646,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True)
|
emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True)
|
||||||
emit("aten::insert.t : (t[], int, t) -> ()")
|
emit("aten::insert.t : (t[], int, t) -> ()")
|
||||||
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
|
emit("aten::ne.int_list : (int[], int[]) -> (bool)")
|
||||||
emit("aten::any.bool : (bool[]) -> (bool)")
|
emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True)
|
||||||
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
|
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
|
||||||
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
|
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||||
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
|
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
|
||||||
|
|
|
@ -259,27 +259,6 @@ func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float {
|
||||||
return %0 : !torch.float
|
return %0 : !torch.float
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.any.bool() -> !torch.bool {
|
|
||||||
// CHECK: %[[CST_FALSE:.*]] = arith.constant false
|
|
||||||
// CHECK: %[[FALSE:.*]] = torch_c.from_i1 %[[CST_FALSE]]
|
|
||||||
// CHECK: %[[CST_TRUE:.*]] = arith.constant true
|
|
||||||
// CHECK: %[[TRUE:.*]] = torch_c.from_i1 %[[CST_TRUE]]
|
|
||||||
// CHECK: %[[INPUT:.*]] = torch.prim.ListConstruct %[[FALSE]], %[[TRUE]], %[[FALSE]] : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
|
||||||
// CHECK: %[[TMP1:.*]] = torch_c.to_i1 %[[FALSE]]
|
|
||||||
// CHECK: %[[TMP2:.*]] = torch_c.to_i1 %[[TRUE]]
|
|
||||||
// CHECK: %[[TMP3:.*]] = torch_c.to_i1 %[[FALSE]]
|
|
||||||
// CHECK: %[[CMP:.*]] = arith.ori %[[TMP1]], %[[TMP2]] : i1
|
|
||||||
// CHECK: %[[CMP_RESULT:.*]] = arith.ori %[[CMP]], %[[TMP3]] : i1
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CMP_RESULT]]
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.bool
|
|
||||||
func.func @torch.aten.any.bool() -> !torch.bool {
|
|
||||||
%false = torch.constant.bool false
|
|
||||||
%true = torch.constant.bool true
|
|
||||||
%input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
|
||||||
%0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool
|
|
||||||
return %0 : !torch.bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.Bool.float(
|
// CHECK-LABEL: func.func @torch.aten.Bool.float(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool {
|
||||||
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
|
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
|
||||||
|
|
|
@ -2087,3 +2087,14 @@ func.func @torch.aten.add$fold() -> !torch.float {
|
||||||
%0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float
|
%0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float
|
||||||
return %0 : !torch.float
|
return %0 : !torch.float
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.any.bool$fold() -> !torch.bool {
|
||||||
|
// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
|
||||||
|
// CHECK: return %[[CST_TRUE]] : !torch.bool
|
||||||
|
func.func @torch.aten.any.bool$fold() -> !torch.bool {
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
|
||||||
|
%0 = torch.aten.any.bool %input : !torch.list<bool> -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
Loading…
Reference in New Issue