diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py index 0c9ae4cbe..e21f2d1ec 100644 --- a/e2e_testing/torchscript/elementwise.py +++ b/e2e_testing/torchscript/elementwise.py @@ -326,7 +326,7 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseGtScalarModule(torch.nn.Module): +class ElementwiseGtFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @@ -339,10 +339,47 @@ class ElementwiseGtScalarModule(torch.nn.Module): return torch.gt(x, 0.6) -@register_test_case(module_factory=lambda: ElementwiseGtScalarModule()) -def ElementwiseGtScalarModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule()) +def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) + +class ElementwiseGtIntScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return torch.gt(x, 10) + + +@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule()) +def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-10, 15, (3,4))) + + +class ElementwiseGtMixed2ScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + return torch.gt(x, 7) + + +@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule()) +def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-10, 15, (3,4)).to(torch.int32)) + + class ElementwiseGtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -361,6 +398,7 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module): def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(5)) + class ElementwiseGtIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index f4483bef5..7b832dcac 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -1735,14 +1735,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto gtScalar = dyn_cast(op)) { - Type dtype = gtScalar.self().getType().cast().getDtype(); - if (!dtype.isa()) { - gtScalar.emitError("unimplemented: non-floating point operand dtype"); - return nullptr; + Type dtype = gtScalar.self().getType().cast().getDtype(); + + // TODO: `gtTensor` and `gtScalar` share similar code and can be called from + // one static function. + Value otherPromoted = + convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); + + if (dtype.isa()) + return b.create(loc, arith::CmpFPredicate::UGT, + payloadArgs[0], otherPromoted); + if (IntegerType intType = dtype.dyn_cast()) { + if (!operands[1].getType().isa()) { + // TODO: Promote tensor args from integer to float. + gtScalar.emitError( + "unimplemented: type promotion from tensor to scalar."); + return nullptr; + } + + if (intType.isUnsigned()) + return b.create(loc, arith::CmpIPredicate::ugt, + payloadArgs[0], otherPromoted); + if (intType.isSigned()) + return b.create(loc, arith::CmpIPredicate::sgt, + payloadArgs[0], otherPromoted); } - Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype); - return b.create(loc, arith::CmpFPredicate::UGT, - payloadArgs[0], otherPromoted); + gtScalar.emitError("unimplemented: dtype isn't supported."); + return nullptr; } if (auto whereSelf = dyn_cast(op)) {