diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 4af9709fd..99ea66bea 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -350,6 +350,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
   }
 
   if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
+
+    // Complex to complex.
     if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
       auto dtypeElemType = dtypeComplex.getElementType();
 
@@ -364,6 +366,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
       return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
     }
+
+    // Float to complex type.
+    if (auto dtypeFloat = dyn_cast<mlir::FloatType>(scalarType)) {
+      auto complexElementType =
+          cast<mlir::FloatType>(dtypeComplex.getElementType());
+      Value realVal;
+      Value imgVal =
+          b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));
+
+      if (complexElementType.getWidth() > dtypeFloat.getWidth()) {
+        realVal = b.create<arith::ExtFOp>(loc, complexElementType, scalar);
+      } else if (complexElementType.getWidth() < dtypeFloat.getWidth()) {
+        realVal = b.create<arith::TruncFOp>(loc, complexElementType, scalar);
+      } else {
+        realVal = scalar;
+      }
+
+      return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
+    }
+
     mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
                          << scalarType << "(scalar type) -> " << dtype
                          << "(dtype)";
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index a24840b29..7276d4435 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1320,6 +1320,8 @@ STABLEHLO_PASS_SET = {
     "TensorToFloatZeroRank_basic",
     "TensorToIntZeroRank_basic",
     "TensorsConcatModule_basic",
+    "TensorsConcatComplex128FloatModule_basic",
+    "TensorsConcatComplex64FloatModule_basic",
     "TensorsConcatNegativeDimModule_basic",
     "TensorsConcatNegativeDimStaticModule_basic",
     "TensorsConcatPromoteDTypeModule_basic",
@@ -2598,6 +2600,8 @@ ONNX_XFAIL_SET = {
     "SubFloatModule_basic",
     "SubIntModule_basic",
     "TanhBackward_basic",
+    "TensorsConcatComplex128FloatModule_basic",
+    "TensorsConcatComplex64FloatModule_basic",
     "TensorToBoolZeroRank_basic",
     "TensorToBool_basic",
     "TensorToFloatZeroRank_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 082223631..e5b4f3147 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1011,6 +1011,68 @@ def TensorsConcatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorsConcatComplex64FloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.complex64, True),
+            ([-1, -1, -1], torch.float64, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, a, b, c, d):
+        return torch.cat([a, b, c, d], 1)
+
+
+@register_test_case(module_factory=lambda: TensorsConcatComplex64FloatModule())
+def TensorsConcatComplex64FloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 1, 4, low=1, high=10).to(torch.complex64),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.float64),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.float32),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.float16),
+    )
+
+
+# ==============================================================================
+
+
+class TensorsConcatComplex128FloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.complex128, True),
+            ([-1, -1, -1], torch.float64, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, a, b, c, d):
+        return torch.cat([a, b, c, d], 1)
+
+
+@register_test_case(module_factory=lambda: TensorsConcatComplex128FloatModule())
+def TensorsConcatComplex128FloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.float64),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.float32),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.float16),
+    )
+
+
+# ==============================================================================
+
+
 class TensorsConcatNegativeDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
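
For context, a minimal plain-PyTorch sketch of the promotion behavior the new tests exercise (variable names here are illustrative only): torch.cat type-promotes mixed operands, and a float input combined with a complex input promotes to a complex result (complex64 with float64 widens to complex128). Lowering that promotion is exactly the float-to-complex case convertScalarToDtype gains above.

    import torch

    a = torch.rand(2, 1, 4).to(torch.complex64)   # complex operand
    b = torch.rand(2, 3, 4, dtype=torch.float64)  # float64 operand
    c = torch.rand(2, 3, 4, dtype=torch.float32)  # float32 operand

    out = torch.cat([a, b, c], dim=1)
    # complex64 combined with float64 promotes to complex128, so every
    # float operand is converted to the complex result dtype.
    assert out.dtype == torch.complex128
    # cat only requires matching sizes outside the concat dimension.
    assert out.shape == (2, 7, 4)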