diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py index e21f2d1ec..d5de9a6ed 100644 --- a/e2e_testing/torchscript/elementwise.py +++ b/e2e_testing/torchscript/elementwise.py @@ -419,6 +419,194 @@ def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLtFloatScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.lt(x, 0.6) + + +@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule()) +def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5)) + + +class ElementwiseLtIntScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return torch.lt(x, 0) + + +@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule()) +def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-10, 15, (3,4))) + + +class ElementwiseLtDiffWidthScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.lt(x, 2) + + +@register_test_case(module_factory=lambda: ElementwiseLtDiffWidthScalarModule()) +def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-10, 15, (3,4)).to(torch.int32)) + + +class ElementwiseLtFloatTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.lt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule()) +def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5), tu.rand(5)) + + +class ElementwiseLtIntTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.lt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule()) +def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): + module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5,))) + +# ============================================================================== + +class ElementwiseEqFloatScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.eq(x, 6.0) + + +@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule()) +def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]) + .to(torch.float32)) + + +class ElementwiseEqIntScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return torch.eq(x, 2) + + +@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule()) +def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils): + module.forward(torch.randint(2, 4, (5,8))) + + +class ElementwiseEqDiffWidthScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.eq(x, 2) + + +@register_test_case(module_factory=lambda: ElementwiseEqDiffWidthScalarModule()) +def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils): + module.forward(torch.randint(2, 4, (5,8)).to(torch.int32)) + + +class ElementwiseEqFloatTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, y): + return torch.eq(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule()) +def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]) + .to(torch.float32), + torch.tensor([1.0, 2.4, 6.0]).to(torch.float32)) + + +class ElementwiseEqIntTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.eq(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule()) +def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils): + module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5,))) + +# ============================================================================== class ElementwiseClampModule(torch.nn.Module): def __init__(self): diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index 4248fcb91..fce289490 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -570,6 +570,36 @@ def Torch_AtenGt_TensorOp : Torch_Op<"aten.gt_.Tensor", [ let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; } +def Torch_AtenLtTensorOp : Torch_Op<"aten.lt.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenLt_TensorOp : Torch_Op<"aten.lt_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::lt_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + def Torch_AtenNeTensorOp : Torch_Op<"aten.ne.Tensor", [ AllowsTypeRefinement, HasValueSemantics @@ -844,6 +874,36 @@ def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; } +def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + +def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)"; +} + def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [ AllowsTypeRefinement, HasValueSemantics diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 92ccd5ce3..7e558a84f 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -1691,8 +1691,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs // to be handled. - if (lhsDtype != rhsDtype) + if (lhsDtype != rhsDtype) { gtTensor.emitError("unimplemented: different lhs and rhs dtype"); + return nullptr; + } Type elementalType = gtTensor.self().getType().cast().getDtype(); @@ -1709,6 +1711,61 @@ static Value createLinalgPayloadCalculationForElementwiseOp( payloadArgs[0], payloadArgs[1]); } gtTensor.emitError("unimplemented: dtype isn't supported."); + return nullptr; + } + if (auto eqTensor = dyn_cast(op)) { + AtenEqTensorOp::Adaptor adaptor(operands); + Type lhsDtype = payloadArgs[0].getType(); + Type rhsDtype = payloadArgs[1].getType(); + + // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs + // to be handled. + if (lhsDtype != rhsDtype) { + eqTensor.emitError("unimplemented: lhs and rhs dtype must be same"); + return nullptr; + } + + Type elementalType = + eqTensor.self().getType().cast().getDtype(); + + if (elementalType.isa()) + return b.create(loc, arith::CmpFPredicate::UEQ, + payloadArgs[0], payloadArgs[1]); + if (elementalType.isa()) { + return b.create(loc, arith::CmpIPredicate::eq, + payloadArgs[0], payloadArgs[1]); + } + eqTensor.emitError("unimplemented: dtype isn't supported."); + return nullptr; + } + if (auto ltTensor = dyn_cast(op)) { + AtenLtTensorOp::Adaptor adaptor(operands); + Type lhsDtype = payloadArgs[0].getType(); + Type rhsDtype = payloadArgs[1].getType(); + + // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs + // to be handled. + if (lhsDtype != rhsDtype) { + ltTensor.emitError("unimplemented: lhs and rhs dtype must be same"); + return nullptr; + } + + Type elementalType = + ltTensor.self().getType().cast().getDtype(); + + if (elementalType.isa()) + return b.create(loc, arith::CmpFPredicate::ULT, + payloadArgs[0], payloadArgs[1]); + if (IntegerType intType = elementalType.dyn_cast()) { + if (intType.isUnsigned()) + return b.create(loc, arith::CmpIPredicate::ult, + payloadArgs[0], payloadArgs[1]); + if (intType.isSigned()) + return b.create(loc, arith::CmpIPredicate::slt, + payloadArgs[0], payloadArgs[1]); + } + ltTensor.emitError("unimplemented: dtype isn't supported."); + return nullptr; } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); @@ -1764,6 +1821,56 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } + if (auto eqScalar = dyn_cast(op)) { + Type dtype = eqScalar.self().getType().cast().getDtype(); + Value otherPromoted = + convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); + + if (dtype.isa()) + return b.create(loc, arith::CmpFPredicate::UEQ, + payloadArgs[0], otherPromoted); + if (dtype.isa()) { + if (!operands[1].getType().isa()) { + // TODO: Promote tensor operand from integer to float. + eqScalar.emitError( + "unimplemented: type promotion from tensor to scalar"); + return nullptr; + } + return b.create(loc, arith::CmpIPredicate::eq, + payloadArgs[0], otherPromoted); + } + eqScalar.emitError("unimplemented: dtype isn't supported"); + return nullptr; + } + + if (auto ltScalar = dyn_cast(op)) { + Type dtype = ltScalar.self().getType().cast().getDtype(); + Value otherPromoted = + convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); + + // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a + // lot of code that can be refactored. + if (dtype.isa()) + return b.create(loc, arith::CmpFPredicate::ULT, + payloadArgs[0], otherPromoted); + if (IntegerType intType = dtype.dyn_cast()) { + if (!operands[1].getType().isa()) { + // TODO: Promote tensor operand from integer to float. + ltScalar.emitError( + "unimplemented: type promotion from tensor to scalar"); + return nullptr; + } + if (intType.isUnsigned()) + return b.create(loc, arith::CmpIPredicate::ult, + payloadArgs[0], otherPromoted); + if (intType.isSigned()) + return b.create(loc, arith::CmpIPredicate::slt, + payloadArgs[0], otherPromoted); + } + ltScalar.emitError("unimplemented: dtype isn't supported."); + return nullptr; + } + if (auto whereSelf = dyn_cast(op)) { Type dtype = converter->convertType(whereSelf.getType()) .cast() @@ -2130,8 +2237,9 @@ struct ConvertElementwiseOp : ConversionPattern { AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenWhereSelfOp, - AtenCeilOp, AtenGtTensorOp>(op)) + AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp, + AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, + AtenEqTensorOp, AtenLtTensorOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -3788,7 +3896,8 @@ public: AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, - AtenWhereSelfOp, AtenGtTensorOp>(); + AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp, + AtenEqTensorOp, AtenLtTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 80bf0b395..d895dd027 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -247,8 +247,8 @@ public: } // These comparison ops return a tensor with 1-bit integer dtype. - if (isa( - op)) { + if (isa(op)) { auto operand = operands[0]->getValue(); auto knowledge = ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); @@ -317,10 +317,10 @@ public: op)) { return visitBinaryTensorScalarOp(op, operands); } else if (isa(op)) { + AtenDivTensorOp, Aten__And__TensorOp, AtenMinimumOp, + AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) { return visitBinaryBroadcastingOp(op, operands); - } else if (isa(op)) { + } else if (isa(op)) { return visitBinaryBroadcastingComparisonOp(op, operands); } else if (auto whereSelf = llvm::dyn_cast(op)) { return visitAtenWhereSelfOp(whereSelf, operands); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 917b0cfd1..f44f8e98e 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -461,6 +461,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", @@ -470,6 +471,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", + "aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",