mirror of https://github.com/llvm/torch-mlir
parent
a8fd275a00
commit
27b55b1d5f
|
@ -567,6 +567,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
||||
} else if(dtype.isa<mlir::ComplexType>()) {
|
||||
return b.create<complex::MulOp>(loc, lhs, rhs);
|
||||
} else {
|
||||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||
}
|
||||
|
|
|
@ -1042,6 +1042,28 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMulTensorComplexModule(torch.nn.Module):
    """Elementwise product of two dynamically-sized 1-D complex64 tensors.

    Exercises torch.mul on complex inputs; the annotation marks both
    operands as rank-1 complex64 tensors with a dynamic dimension.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.complex64, True),
        ([-1], torch.complex64, True),
    ])
    def forward(self, lhs, rhs):
        # Call torch.mul explicitly (rather than the * operator) so the
        # multiplication op itself is what this module exercises.
        return torch.mul(lhs, rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexModule())
def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
    # Two random integer tensors cast to complex64.
    # NOTE(review): the cast leaves every imaginary part at zero, so the
    # complex multiply is only exercised on real-valued inputs — confirm
    # whether inputs with nonzero imaginary parts are wanted here.
    lhs = tu.randint(4, high=10).type(torch.complex64)
    rhs = tu.randint(4, high=10).type(torch.complex64)
    module.forward(lhs, rhs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMishModule(torch.nn.Module):
|
||||
|
||||
|
|
Loading…
Reference in New Issue