mirror of https://github.com/llvm/torch-mlir
parent
a8fd275a00
commit
27b55b1d5f
|
@ -567,6 +567,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
||||||
|
} else if(dtype.isa<mlir::ComplexType>()) {
|
||||||
|
return b.create<complex::MulOp>(loc, lhs, rhs);
|
||||||
} else {
|
} else {
|
||||||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1042,6 +1042,28 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseMulTensorComplexModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.complex64, True),
|
||||||
|
([-1], torch.complex64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.mul(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexModule())
|
||||||
|
def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(4, high=10).type(torch.complex64), tu.randint(4, high=10).type(torch.complex64))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ElementwiseMishModule(torch.nn.Module):
|
class ElementwiseMishModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue