diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index ea43fc4c4..fb4147601 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
 class DropoutModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index d1872dd03..af9734299 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 # ==============================================================================
+
+
 class ElementwiseMulScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -378,7 +380,52 @@ class ElementwiseMulScalarModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
 def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
-
+
+
+
+class ElementwiseMulTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorFloatModule())
+def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+class ElementwiseMulTensorIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorIntModule())
+def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.randint(10, [4]).type(torch.int32),
+        torch.randint(10, [4]))
+
+
 # ==============================================================================
 class ElementwiseLogModule(torch.nn.Module):
     def __init__(self):
@@ -553,7 +600,32 @@ class ElementwiseDivScalarModule(torch.nn.Module):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
+class ElementwiseDivTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseDivTensorFloatModule())
+def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+
 # ==============================================================================
+
+
 class ElementwiseAndIntegerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -573,3 +645,5 @@ class ElementwiseAndIntegerModule(torch.nn.Module):
 def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
                    torch.randint(-10, 10, (3, 4)))
+
+
diff --git a/e2e_testing/torchscript/type_promotion.py b/e2e_testing/torchscript/type_promotion.py
index 6cad4ef03..a7a5491c5 100644
--- a/e2e_testing/torchscript/type_promotion.py
+++ b/e2e_testing/torchscript/type_promotion.py
@@ -111,3 +111,5 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
 def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
+
+
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 1439654ca..451598901 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
   }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
-    if (!mul.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      mul.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
+    AtenMulTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(mul.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    } else {
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
     }
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
-    if (!div.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
+    AtenDivTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(div.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>())
       div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::DivFOp>(loc, lhs, rhs);
   }
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()