From 6dabf185f5c819fddcfa167c82833e83fa9b90a1 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Mon, 13 Dec 2021 21:02:47 +0530
Subject: [PATCH] Add support for int types in gtScalar op.

Support for integer types in the gtScalar op has been added. The code
shares the same logic as the gtTensor op and the two can be merged,
which is noted as a TODO.
---
 e2e_testing/torchscript/elementwise.py        | 44 +++++++++++++++++--
 .../TorchToLinalg/TorchToLinalg.cpp           | 33 +++++++++++---
 2 files changed, 67 insertions(+), 10 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 0c9ae4cbe..e21f2d1ec 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -326,7 +326,7 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class ElementwiseGtScalarModule(torch.nn.Module):
+class ElementwiseGtFloatScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -339,10 +339,47 @@ class ElementwiseGtScalarModule(torch.nn.Module):
         return torch.gt(x, 0.6)
 
 
-@register_test_case(module_factory=lambda: ElementwiseGtScalarModule())
-def ElementwiseGtScalarModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
+def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
+
+class ElementwiseGtIntScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.gt(x, 10)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
+def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 15, (3, 4)))
+
+
+class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.gt(x, 7)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
+def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
+
+
 class ElementwiseGtFloatTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -361,6 +398,7 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
 def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5), tu.rand(5))
 
+
 class ElementwiseGtIntTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index f4483bef5..7b832dcac 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1735,14 +1735,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
 
   if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
-    Type dtype = gtScalar.self().getType().cast<ValueTensorType>().getDtype();
-    if (!dtype.isa<mlir::FloatType>()) {
-      gtScalar.emitError("unimplemented: non-floating point operand dtype");
-      return nullptr;
+    Type dtype = gtScalar.self().getType().cast<BaseTensorType>().getDtype();
+
+    // TODO: `gtTensor` and `gtScalar` share similar code and can be called
+    // from one static function.
+    Value otherPromoted =
+        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
+
+    if (dtype.isa<mlir::FloatType>())
+      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                     payloadArgs[0], otherPromoted);
+    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
+      if (!operands[1].getType().isa<mlir::IntegerType>()) {
+        // TODO: Promote tensor args from integer to float.
+        gtScalar.emitError(
+            "unimplemented: type promotion from tensor to scalar.");
+        return nullptr;
+      }
+
+      if (intType.isUnsigned())
+        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                       payloadArgs[0], otherPromoted);
+      if (intType.isSigned())
+        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
+                                       payloadArgs[0], otherPromoted);
     }
-    Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
-    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                   payloadArgs[0], otherPromoted);
+    gtScalar.emitError("unimplemented: dtype isn't supported.");
+    return nullptr;
   }
 
   if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
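For reference, a minimal PyTorch sketch (not part of the patch) of the computation
the new ElementwiseGtIntScalarModule test exercises; the shape and value range are
the arbitrary ones used by the test above:

    import torch

    # Same computation as ElementwiseGtIntScalarModule.forward: an elementwise
    # signed comparison of an int64 tensor against the int scalar 10. With this
    # patch, the gtScalar lowering handles the integer case (arith.cmpi with the
    # "sgt"/"ugt" predicate) instead of failing with
    # "unimplemented: non-floating point operand dtype".
    x = torch.randint(-10, 15, (3, 4))   # int64 tensor, values in [-10, 15)
    result = torch.gt(x, 10)             # bool tensor of shape (3, 4)
    print(result)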