mirror of https://github.com/llvm/torch-mlir
[linalg] Fix bug for conversion of complex dtype (#3269)
The conversion of complex type wasn't supported or checked; the support and required tests were added. Fixes: https://github.com/iree-org/iree/issues/17226#issuecomment-2087779158 (branch: pull/3273/head)
parent
0a2d21b108
commit
8c48135a42
|
@ -10,6 +10,7 @@
|
|||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
|
||||
}
|
||||
|
||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
|
||||
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
|
||||
auto dtypeElemType = dtypeComplex.getElementType();
|
||||
|
||||
// Extract the real and imaginary parts of the scalar.
|
||||
// Cast them to the target element type, and create a new complex
|
||||
// value with the target complex type.
|
||||
Value realVal = b.create<complex::ReOp>(loc, scalar);
|
||||
Value imgVal = b.create<complex::ImOp>(loc, scalar);
|
||||
|
||||
realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
|
||||
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
|
||||
|
||||
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||
}
|
||||
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype
|
||||
<< "(dtype)";
|
||||
}
|
||||
|
||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
}
|
||||
|
||||
|
|
|
@ -575,6 +575,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseLogitModule_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
|
@ -2314,6 +2315,7 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwiseMulTensorComplexDiffModule_basic",
|
||||
"ElementwiseOrTensorModule_basic",
|
||||
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
|
|
|
@ -1839,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
# torch.complex32 is not supported by the refbackend.
class ElementwiseMulTensorComplexDiffModule(torch.nn.Module):
    """Elementwise multiply where the two complex operands differ in width.

    Exercises the complex64 x complex128 combination so the mixed-complex
    dtype-conversion path is covered by the e2e suite.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.complex64, True),
        ([-1], torch.complex128, True),
    ])
    def forward(self, a, b):
        return torch.mul(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule())
def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils):
    # Two 1-D integer tensors, cast to the two different complex dtypes.
    # (Evaluation order matches the original inline-argument form, so the
    # RNG draws are identical.)
    lhs = tu.randint(4, high=10).type(torch.complex64)
    rhs = tu.randint(4, high=10).type(torch.complex128)
    module.forward(lhs, rhs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMishModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue