diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b7475a8a8..aa2566711 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -15548,6 +15548,31 @@ def Torch_Aten__Not__Op : Torch_Op<"aten.__not__", [ let hasFolder = 1; } +def Torch_Aten__Or__BoolOp : Torch_Op<"aten.__or__.bool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__or__.bool : (bool, bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$a, + Torch_BoolType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Or__BoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Or__BoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index a2208a797..ca46ca62f 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -732,6 +732,21 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { return IntegerAttr::get(IntegerType::get(getContext(), 1), !value); } +//===----------------------------------------------------------------------===// +// Aten__Or__Op +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) { + auto valueA = dyn_cast_or_null(adaptor.getA()); + auto valueB = dyn_cast_or_null(adaptor.getB()); + if (!valueA || !valueB) { + return nullptr; + } + + return IntegerAttr::get(IntegerType::get(getContext(), 1), + valueA.getValue() | valueB.getValue()); +} + //===----------------------------------------------------------------------===// // AtenNeBoolOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fb215a303..c54a9023b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -855,6 +855,7 @@ STABLEHLO_PASS_SET = { "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", "AddIntModule_basic", "AliasModule_basic", + "TrueFalseOrBoolOpModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", @@ -1576,6 +1577,7 @@ TOSA_PASS_SET = { "AtenInstanceNormModule_basic", "AtenToDeviceModule_basic", "Aten_CastFloatModule_basic", + "TrueFalseOrBoolOpModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", "BaddbmmDynamicModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 47c4b721c..30758f457 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1077,6 +1077,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True) + emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 5576e850a..3dacb9872 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -528,3 +528,21 @@ class AtenItemFpOpModule(torch.nn.Module): @register_test_case(module_factory=lambda: AtenItemFpOpModule()) def AtenItemFpOpModule_basic(module, tu: TestUtils): module.forward(tu.rand(1)) + + +# ============================================================================== + + +class TrueFalseOrBoolOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([], torch.bool, True), ([], torch.bool, True)]) + def forward(self, a, b): + return a | b + + +@register_test_case(module_factory=lambda: TrueFalseOrBoolOpModule()) +def TrueFalseOrBoolOpModule_basic(module, tu: TestUtils): + module.forward(tu.randint(low=0, high=1).bool(), tu.randint(low=1, high=2).bool())