implemented complex tensor aten mul (#2444)

pull/2446/head snapshot-20230908.955
Bruce Kim 2023-09-07 13:29:15 -07:00 committed by GitHub
parent a8fd275a00
commit 27b55b1d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 0 deletions

View File

@ -567,6 +567,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) { if (dtype.isa<mlir::FloatType>()) {
return b.create<arith::MulFOp>(loc, lhs, rhs); return b.create<arith::MulFOp>(loc, lhs, rhs);
} else if(dtype.isa<mlir::ComplexType>()) {
return b.create<complex::MulOp>(loc, lhs, rhs);
} else { } else {
return b.create<arith::MulIOp>(loc, lhs, rhs); return b.create<arith::MulIOp>(loc, lhs, rhs);
} }

View File

@ -1042,6 +1042,28 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseMulTensorComplexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.complex64, True),
([-1], torch.complex64, True),
])
def forward(self, a, b):
return torch.mul(a, b)
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexModule())
def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.complex64), tu.randint(4, high=10).type(torch.complex64))
# ==============================================================================
class ElementwiseMishModule(torch.nn.Module): class ElementwiseMishModule(torch.nn.Module):