diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7300d5310..aad26c846 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -982,6 +982,53 @@ def Torch_AtenDiv_TensorOp : Torch_Op<"aten.div_.Tensor", [
   }];
 }
 
+def Torch_AtenLogicalOrOp : Torch_Op<"aten.logical_or", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logical_or : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalOrOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogicalOrOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::logical_or_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalOr_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogicalOr_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 01bdbc41b..2111145a7 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -190,6 +190,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::DivFOp>(loc, lhs, rhs);
   }
+  if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
+    MLIRContext *context = op->getContext();
+    Type floatDtype = mlir::FloatType::getF64(context);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
+    Value zero =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
+    Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
+    Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
+    return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
+  }
   if (isa<AtenAbsOp>(op))
     return b.create<math::AbsOp>(loc, payloadArgs[0]);
   if (isa<AtenSigmoidOp>(op)) {
@@ -844,7 +855,8 @@ public:
         AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
         AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
         AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-        AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp>(op))
+        AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+        AtenLogicalOrOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
@@ -1581,7 +1593,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
       AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
-      AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp>();
+      AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenLogicalOrOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 87e6c8a5e..3caa86655 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -673,7 +673,7 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Dtype is always i1.
   if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
           AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
-          AtenGtTensorOp, AtenLtTensorOp>(op)) {
+          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = IntegerType::get(op->getContext(), 1);
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 2b644628a..5a4281e35 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -2155,6 +2155,10 @@ module {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.logical_or"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.threshold"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 0e0cb7382..41ed2d177 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -710,6 +710,9 @@ def aten〇maximum(self: List[int], other: List[int]) -> List[int]:
 def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_helpers.broadcast(self, other)
 
+def aten〇logical_or(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_helpers.broadcast(self, other)
+
 def aten〇threshold(self: List[int], threshold: float, value: float) -> List[int]:
     return upstream_shape_helpers.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 6ccd25f18..3b8302b6f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -249,6 +249,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
         "aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::logical_or : (Tensor, Tensor) -> (Tensor)",
         "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
         "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index b959ce05c..08b7b9793 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1302,3 +1302,156 @@ class ElementwiseNegModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseNegModule()) def ElementwiseNegModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class ElementwiseAtenLogicalOrOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.bool, True), + ([-1], torch.bool, True), + ]) + def forward(self, x, y): + return torch.ops.aten.logical_or(x, y) + +@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpModule()) +def ElementwiseAtenLogicalOrOpModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([False, True]), torch.tensor([False, False])) + +class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.logical_or(x, y) + +@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module()) +def ElementwiseAtenLogicalOrOpDiffArgs1Module_basic(module, tu: TestUtils): + module.forward(torch.tensor([0.2, 0.1]), torch.tensor([0, 1])) + +# ============================================================================== + +class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.bool, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.logical_or(x, y) + +@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module()) +def ElementwiseAtenLogicalOrOpDiffArgs2Module_basic(module, tu: TestUtils): + module.forward(torch.tensor([True, False]), torch.tensor([0, 1])) + +# ============================================================================== + +class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int64, True), + ([-1], torch.bool, True), + ]) + def forward(self, x, y): + return torch.ops.aten.logical_or(x, y) + +@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module()) +def ElementwiseAtenLogicalOrOpDiffArgs3Module_basic(module, tu: TestUtils): + module.forward(torch.tensor([1, 2]), torch.tensor([False, True])) + +# ============================================================================== + +class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.int64, True), + ([-1, -1, -1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.logical_or(x, y) + +@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule()) +def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils): + module.forward(torch.randint(3, 10, (2, 3, 4, 5)), torch.randint(10, 100, (2, 3, 4, 5))) + +# 
+class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_or(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
+def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(2, 3, 3, 5), torch.rand(2, 3, 3, 5))
+
+# ==============================================================================
+
+class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.int64, True),
+        ([-1, -1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_or(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
+def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils):
+    module.forward(torch.neg(torch.randint(3, 10, (2, 3, 4, 5))),
+                   torch.neg(torch.randint(10, 100, (2, 3, 4, 5))))
+
+# ==============================================================================
+
+class ElementwiseAtenLogicalOrOpBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_or(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBroadcastModule())
+def ElementwiseAtenLogicalOrOpBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(3, (3,)), torch.randint(3, (4, 3)))
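
Note on the lowering: the payload added in Uncategorized.cpp widens both operands to f64, tests each against 0.0 with createNotEqual, and combines the two i1 results with arith.ori. Widening to a common float type is what lets one code path serve all of the bool/int/float operand mixes exercised by the tests above. Below is a minimal PyTorch sketch of the same semantics for reference; the helper name reference_logical_or is illustrative only and is not part of the patch.

    import torch

    def reference_logical_or(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Mirror the linalg payload: convert each operand to f64, compare it
        # against zero, then OR the two boolean (i1) results.
        lhs_test = x.to(torch.float64) != 0.0
        rhs_test = y.to(torch.float64) != 0.0
        return lhs_test | rhs_test

    # Should agree with aten::logical_or for the dtype mixes tested above.
    x = torch.tensor([0.2, 0.0])
    y = torch.tensor([0, 1])
    assert torch.equal(reference_logical_or(x, y), torch.logical_or(x, y))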