diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index a43be0f83..1de407a8a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -10919,6 +10919,7 @@ def Torch_AtenAnyBoolOp : Torch_Op<"aten.any.bool", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index d346d9db4..a7279d347 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1434,6 +1434,24 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenAnyBoolOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) { + auto inputConstruct = getSelf().getDefiningOp(); + if (!inputConstruct || isListPotentiallyMutated(inputConstruct)) + return nullptr; + // If any operand is a constant true, return true. + for (auto operand : inputConstruct.getOperands()) { + bool b; + if (matchPattern(operand, m_TorchConstantBool(&b)) && b) { + return getI1IntegerAttr(getContext(), true); + } + } + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenFloatScalarOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 813e69ac3..0d2ae9af8 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -646,7 +646,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::slice.t : (t[], int?, int?, int) -> (t[])", has_canonicalizer=True) emit("aten::insert.t : (t[], int, t) -> ()") emit("aten::ne.int_list : (int[], int[]) -> (bool)") - emit("aten::any.bool : (bool[]) -> (bool)") + emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True) emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True) emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 52936c53b..933031e16 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -259,27 +259,6 @@ func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float { return %0 : !torch.float } -// CHECK-LABEL: func.func @torch.aten.any.bool() -> !torch.bool { -// CHECK: %[[CST_FALSE:.*]] = arith.constant false -// CHECK: %[[FALSE:.*]] = torch_c.from_i1 %[[CST_FALSE]] -// CHECK: %[[CST_TRUE:.*]] = arith.constant true -// CHECK: %[[TRUE:.*]] = torch_c.from_i1 %[[CST_TRUE]] -// CHECK: %[[INPUT:.*]] = torch.prim.ListConstruct %[[FALSE]], %[[TRUE]], %[[FALSE]] : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list -// CHECK: %[[TMP1:.*]] = torch_c.to_i1 %[[FALSE]] -// CHECK: %[[TMP2:.*]] = torch_c.to_i1 %[[TRUE]] -// CHECK: %[[TMP3:.*]] = torch_c.to_i1 %[[FALSE]] -// CHECK: %[[CMP:.*]] = arith.ori %[[TMP1]], %[[TMP2]] : i1 -// CHECK: %[[CMP_RESULT:.*]] = arith.ori %[[CMP]], %[[TMP3]] : i1 -// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CMP_RESULT]] -// CHECK: return %[[RESULT]] : !torch.bool -func.func @torch.aten.any.bool() -> !torch.bool { - %false = torch.constant.bool false - %true = torch.constant.bool true - %input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %0 = torch.aten.any.bool %input : !torch.list -> !torch.bool - return %0 : !torch.bool -} - // CHECK-LABEL: func.func @torch.aten.Bool.float( // CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.bool { // CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]] diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a056aa2f0..88b73ed5c 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2087,3 +2087,14 @@ func.func @torch.aten.add$fold() -> !torch.float { %0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float return %0 : !torch.float } + +// CHECK-LABEL: func.func @torch.aten.any.bool$fold() -> !torch.bool { +// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[CST_TRUE]] : !torch.bool +func.func @torch.aten.any.bool$fold() -> !torch.bool { + %false = torch.constant.bool false + %true = torch.constant.bool true + %input = torch.prim.ListConstruct %false, %true, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %0 = torch.aten.any.bool %input : !torch.list -> !torch.bool + return %0 : !torch.bool +} \ No newline at end of file