mirror of https://github.com/llvm/torch-mlir
Added support for integer to complex conversion (#3604)
parent
cb6a499460
commit
da877a781e
|
@ -379,7 +379,6 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
realVal = b.create<arith::ExtFOp>(loc, complexElementType, scalar);
|
||||
} else if (complexElementType.getWidth() < dtypeFloat.getWidth()) {
|
||||
realVal = b.create<arith::TruncFOp>(loc, complexElementType, scalar);
|
||||
;
|
||||
} else {
|
||||
realVal = scalar;
|
||||
}
|
||||
|
@ -387,6 +386,19 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||
}
|
||||
|
||||
// Int to complex type.
|
||||
if (auto dtypeInt = dyn_cast<mlir::IntegerType>(scalarType)) {
|
||||
auto complexElementType =
|
||||
cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||
|
||||
Value realVal =
|
||||
b.create<arith::SIToFPOp>(loc, complexElementType, scalar);
|
||||
Value imgVal =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));
|
||||
|
||||
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
||||
}
|
||||
|
||||
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype
|
||||
<< "(dtype)";
|
||||
|
|
|
@ -1188,6 +1188,7 @@ STABLEHLO_PASS_SET = {
|
|||
"MoveDimIntModule_basic",
|
||||
"MoveDimIntNegativeIndexModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
"MulIntModule_basic",
|
||||
"Mv_basic",
|
||||
"NarrowHorizontalTest2_basic",
|
||||
|
@ -1362,6 +1363,7 @@ STABLEHLO_PASS_SET = {
|
|||
"TensorsConcatModule_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
"TensorsConcatComplex64FloatModule_basic",
|
||||
"TensorsConcatComplex128IntModule_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"TensorsConcatNegativeDimStaticModule_basic",
|
||||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
|
@ -2683,6 +2685,7 @@ ONNX_XFAIL_SET = {
|
|||
"TanhBackward_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
"TensorsConcatComplex64FloatModule_basic",
|
||||
"TensorsConcatComplex128IntModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
|
|
|
@ -1073,6 +1073,35 @@ def TensorsConcatComplex128FloatModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorsConcatComplex128IntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.complex128, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, a, b, c):
|
||||
return torch.cat([a, b, c], 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorsConcatComplex128IntModule())
|
||||
def TensorsConcatComplex128IntModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128),
|
||||
tu.rand(2, 3, 4, low=1, high=10).to(torch.int64),
|
||||
tu.rand(2, 3, 4, low=1, high=10).to(torch.int32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorsConcatNegativeDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue