From da877a781e5a7f024d9501be35d98859be08f3f4 Mon Sep 17 00:00:00 2001
From: Branko Trifkovic <88882867+BaneTrifa@users.noreply.github.com>
Date: Wed, 14 Aug 2024 14:43:00 +0200
Subject: [PATCH] Added support for integer to complex conversion (#3604)

---
 lib/Conversion/Utils/Utils.cpp                | 14 ++++++++-
 projects/pt1/e2e_testing/xfail_sets.py        |  2 ++
 .../torch_mlir_e2e_test/test_suite/basic.py   | 29 +++++++++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 99ea66bea..5ef0ab169 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -379,7 +379,6 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
       realVal = b.create<arith::ExtFOp>(loc, complexElementType, scalar);
     } else if (complexElementType.getWidth() < dtypeFloat.getWidth()) {
       realVal = b.create<arith::TruncFOp>(loc, complexElementType, scalar);
-      ;
     } else {
       realVal = scalar;
     }
@@ -387,6 +386,19 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
     return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
   }
 
+  // Int to complex type.
+  if (auto dtypeInt = dyn_cast<mlir::IntegerType>(scalarType)) {
+    auto complexElementType =
+        cast<mlir::FloatType>(dtypeComplex.getElementType());
+
+    Value realVal =
+        b.create<arith::SIToFPOp>(loc, complexElementType, scalar);
+    Value imgVal =
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));
+
+    return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
+  }
+
   mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
                        << scalarType << "(scalar type) -> " << dtype
                        << "(dtype)";
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index da5f19c63..38b97074e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1362,6 +1362,7 @@ STABLEHLO_PASS_SET = {
     "TensorsConcatModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex64FloatModule_basic",
+    "TensorsConcatComplex128IntModule_basic",
     "TensorsConcatNegativeDimModule_basic",
     "TensorsConcatNegativeDimStaticModule_basic",
     "TensorsConcatPromoteDTypeModule_basic",
@@ -2683,6 +2684,7 @@ ONNX_XFAIL_SET = {
     "TanhBackward_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex64FloatModule_basic",
+    "TensorsConcatComplex128IntModule_basic",
     "TensorToBoolZeroRank_basic",
     "TensorToBool_basic",
     "TensorToFloatZeroRank_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index e5b4f3147..2bda11410 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1073,6 +1073,35 @@ def TensorsConcatComplex128FloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorsConcatComplex128IntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.complex128, True),
+            ([-1, -1, -1], torch.int64, True),
+            ([-1, -1, -1], torch.int32, True),
+        ]
+    )
+    def forward(self, a, b, c):
+        return torch.cat([a, b, c], 1)
+
+
+@register_test_case(module_factory=lambda: TensorsConcatComplex128IntModule())
+def TensorsConcatComplex128IntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.int64),
+        tu.rand(2, 3, 4, low=1, high=10).to(torch.int32),
+    )
+
+
+# ==============================================================================
+
+
 class TensorsConcatNegativeDimModule(torch.nn.Module):
     def __init__(self):
         super().__init__()