mirror of https://github.com/llvm/torch-mlir
[LINALG] Added support for conversion from float to complex. (#3595)
parent
b48e55c2f7
commit
2d6bfb2dec
|
@ -350,6 +350,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
|
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
|
||||||
|
|
||||||
|
// Complex to complex.
|
||||||
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
|
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
|
||||||
auto dtypeElemType = dtypeComplex.getElementType();
|
auto dtypeElemType = dtypeComplex.getElementType();
|
||||||
|
|
||||||
|
@ -364,6 +366,27 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
|
|
||||||
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Float to complex type.
|
||||||
|
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(scalarType)) {
|
||||||
|
auto complexElementType =
|
||||||
|
cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||||
|
Value realVal;
|
||||||
|
Value imgVal =
|
||||||
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));
|
||||||
|
|
||||||
|
if (complexElementType.getWidth() > dtypeFloat.getWidth()) {
|
||||||
|
realVal = b.create<arith::ExtFOp>(loc, complexElementType, scalar);
|
||||||
|
} else if (complexElementType.getWidth() < dtypeFloat.getWidth()) {
|
||||||
|
realVal = b.create<arith::TruncFOp>(loc, complexElementType, scalar);
|
||||||
|
;
|
||||||
|
} else {
|
||||||
|
realVal = scalar;
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||||
|
}
|
||||||
|
|
||||||
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
||||||
<< scalarType << "(scalar type) -> " << dtype
|
<< scalarType << "(scalar type) -> " << dtype
|
||||||
<< "(dtype)";
|
<< "(dtype)";
|
||||||
|
|
|
@ -1320,6 +1320,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"TensorToFloatZeroRank_basic",
|
"TensorToFloatZeroRank_basic",
|
||||||
"TensorToIntZeroRank_basic",
|
"TensorToIntZeroRank_basic",
|
||||||
"TensorsConcatModule_basic",
|
"TensorsConcatModule_basic",
|
||||||
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
|
"TensorsConcatComplex64FloatModule_basic",
|
||||||
"TensorsConcatNegativeDimModule_basic",
|
"TensorsConcatNegativeDimModule_basic",
|
||||||
"TensorsConcatNegativeDimStaticModule_basic",
|
"TensorsConcatNegativeDimStaticModule_basic",
|
||||||
"TensorsConcatPromoteDTypeModule_basic",
|
"TensorsConcatPromoteDTypeModule_basic",
|
||||||
|
@ -2598,6 +2600,8 @@ ONNX_XFAIL_SET = {
|
||||||
"SubFloatModule_basic",
|
"SubFloatModule_basic",
|
||||||
"SubIntModule_basic",
|
"SubIntModule_basic",
|
||||||
"TanhBackward_basic",
|
"TanhBackward_basic",
|
||||||
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
|
"TensorsConcatComplex64FloatModule_basic",
|
||||||
"TensorToBoolZeroRank_basic",
|
"TensorToBoolZeroRank_basic",
|
||||||
"TensorToBool_basic",
|
"TensorToBool_basic",
|
||||||
"TensorToFloatZeroRank_basic",
|
"TensorToFloatZeroRank_basic",
|
||||||
|
|
|
@ -1011,6 +1011,68 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TensorsConcatComplex64FloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.complex64, True),
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float16, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b, c, d):
|
||||||
|
return torch.cat([a, b, c, d], 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorsConcatComplex64FloatModule())
|
||||||
|
def TensorsConcatComplex64FloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.rand(2, 1, 4, low=1, high=10).to(torch.complex64),
|
||||||
|
tu.rand(2, 3, 4, low=1, high=10).to(torch.float64),
|
||||||
|
tu.rand(2, 3, 4, low=1, high=10).to(torch.float32),
|
||||||
|
tu.rand(2, 3, 4, low=1, high=10).to(torch.float16),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TensorsConcatComplex128FloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.complex128, True),
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float16, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b, c, d):
|
||||||
|
return torch.cat([a, b, c, d], 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TensorsConcatComplex128FloatModule())
|
||||||
|
def TensorsConcatComplex128FloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128),
|
||||||
|
tu.rand(2, 3, 4, low=1, high=10).to(torch.float64),
|
||||||
|
tu.rand(2, 3, 4, low=1, high=10).to(torch.float32),
|
||||||
|
tu.rand(2, 3, 4, low=1, high=10).to(torch.float16),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TensorsConcatNegativeDimModule(torch.nn.Module):
|
class TensorsConcatNegativeDimModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue