From 27b55b1d5fbf115bf0c38e4a610a9268e654e353 Mon Sep 17 00:00:00 2001 From: Bruce Kim <92174982+brucekimrokcmu@users.noreply.github.com> Date: Thu, 7 Sep 2023 13:29:15 -0700 Subject: [PATCH] implemented complex tensor aten mul (#2444) --- .../TorchToLinalg/Uncategorized.cpp | 2 ++ .../test_suite/elementwise.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8684b68d9..1d25d2272 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -567,6 +567,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); + } else if(dtype.isa()) { + return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); } diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index f81aae63a..a31e4c41a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1042,6 +1042,28 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseMulTensorComplexModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.complex64, True), + ([-1], torch.complex64, True), + ]) + def forward(self, a, b): + return torch.mul(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexModule()) +def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, high=10).type(torch.complex64), tu.randint(4, high=10).type(torch.complex64)) + + +# ============================================================================== class ElementwiseMishModule(torch.nn.Module):