diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index d40389072..b3eff6827 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -154,4 +154,6 @@ TOSA_PASS_SET = {
     "ReshapeCollapseModule_basic",
     "ElementwiseGeluModule_basic",
     "GeluBackwardModule_basic",
+    "ElementwiseNeIntScalarModule_basic",
+    "ElementwiseNeFloatTensorModule_basic",
 }
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 920d9710c..55df899d6 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -173,6 +173,22 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
       b, loc, elementalType, lhs, rhs);
 }
 
+static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
+                         Value lhs, Value rhs) {
+  return createComparisonTemplate<arith::CmpFPredicate::UEQ,
+                                  arith::CmpIPredicate::eq,
+                                  arith::CmpIPredicate::eq>(
+      b, loc, elementalType, lhs, rhs);
+}
+
+static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
+                            Value lhs, Value rhs) {
+  return createComparisonTemplate<arith::CmpFPredicate::UNE,
+                                  arith::CmpIPredicate::ne,
+                                  arith::CmpIPredicate::ne>(
+      b, loc, elementalType, lhs, rhs);
+}
+
 static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
                             Value sigma) {
   Type elementType = x.getType();
@@ -607,9 +623,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value otherPromoted =
         convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
 
-    if (dtype.isa<mlir::FloatType>())
-      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
-                                     payloadArgs[0], otherPromoted);
     if (dtype.isa<mlir::IntegerType>()) {
       if (!operands[1].getType().isa<mlir::IntegerType>()) {
         // TODO: Promote tensor operand from integer to float.
@@ -617,11 +630,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
             "unimplemented: type promotion from tensor to scalar");
         return nullptr;
       }
-      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                     payloadArgs[0], otherPromoted);
     }
-    eqScalar.emitError("unimplemented: dtype isn't supported");
-    return nullptr;
+    return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
+  }
+
+  if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
+    Type dtype = neScalar.self().getType().cast<BaseTensorType>().getDtype();
+    Value otherPromoted =
+        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
+
+    if (dtype.isa<mlir::IntegerType>()) {
+      if (!operands[1].getType().isa<mlir::IntegerType>()) {
+        // TODO: Promote tensor operand from integer to float.
+        neScalar.emitError(
+            "unimplemented: type promotion from tensor to scalar");
+        return nullptr;
+      }
+    }
+    return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
   }
 
   if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
@@ -629,8 +655,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value otherPromoted =
         convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
 
-    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share a
-    // lot of code that can be refactored.
+    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
+    // a lot of code that can be refactored.
     if (dtype.isa<mlir::FloatType>())
       return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                      payloadArgs[0], otherPromoted);
@@ -657,8 +683,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value otherPromoted =
         convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
 
-    // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that
-    // can be refactored.
+    // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
+    // that can be refactored.
     if (dtype.isa<mlir::FloatType>())
       return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
                                      payloadArgs[0], otherPromoted);
@@ -908,7 +934,8 @@ public:
             AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
             AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
             AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-            AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp>(op))
+            AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+            AtenNeScalarOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1626,7 +1653,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
       AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
       AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
-      AtenSinOp, AtenCosOp>();
+      AtenSinOp, AtenCosOp, AtenNeScalarOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index dae6023f8..5ef85241b 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -2367,7 +2367,7 @@ module {
       torch.prim.If.yield %int0 : !torch.int
     } else {
       %25 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
-      %26 = torch.aten.ge.int %12, %25 : !torch.int, !torch.int -> !torch.bool
+      %26 = torch.aten.gt.int %12, %25 : !torch.int, !torch.int -> !torch.bool
       %27 = torch.prim.If %26 -> (!torch.int) {
         %28 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int
         torch.prim.If.yield %28 : !torch.int
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
index 56f8137af..247dd8bc1 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
@@ -457,3 +457,41 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
 def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
 
+# ==============================================================================
+
+class ElementwiseNeFloatScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ne(x, 2.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseNeFloatScalarModule())
+def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
+
+# ==============================================================================
+
+class ElementwiseNeIntScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.ne(x, 3)
+
+
+@register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule())
+def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(2, 4, (8, 5)))
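
For reference, the elementwise semantics that the two new e2e tests lock down can be checked against eager PyTorch on its own. The snippet below is only an illustrative sketch, independent of the torch-mlir test harness, mirroring the inputs used by ElementwiseNeFloatScalarModule and ElementwiseNeIntScalarModule:

import torch

# Float tensor vs. float scalar, same input as ElementwiseNeFloatScalarModule.
x = torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]], dtype=torch.float32)
print(torch.ne(x, 2.0))
# tensor([[ True,  True, False],
#         [ True, False,  True]])

# int64 tensor vs. int scalar, same shape/range as ElementwiseNeIntScalarModule.
y = torch.randint(2, 4, (8, 5))
print(torch.ne(y, 3))  # True wherever the element is 2, False where it is 3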