[LINALG] Add complex tensor support for `create[Zero|One]InitTensor` utility (#3777)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3781/head
Vivek Khandelwal 2024-10-09 16:15:08 +05:30 committed by GitHub
parent d49eabb3fc
commit 94f5410913
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 6 deletions

View File

@ -132,9 +132,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) {
Value initTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
Value c0 =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
Type fillValElemTy = elemTy;
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
}
@ -142,9 +145,12 @@ Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) {
Value initTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
Value c1 =
b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
Type fillValElemTy = elemTy;
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
}

View File

@ -1423,6 +1423,7 @@ STABLEHLO_PASS_SET = {
"SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic",
"SliceStaticModule_basic",
"SliceStaticComplexInputModule_basic",
"SliceWholeTensorModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
@ -2618,6 +2619,7 @@ ONNX_XFAIL_SET = {
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",
"SliceCopy_Module_basic",
"SliceStaticComplexInputModule_basic",
"StdCorrectionLargeInputModule_basic",
"TupleModule_basic",
"VarCorrectionLargeInputModule_basic",
@ -3778,6 +3780,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"SignAndLogarithmOfDeterminantModule_F32",
"SignAndLogarithmOfDeterminantBatchedModule_F32",
"SignAndLogarithmOfDeterminantDynamicModule_F32",
"SliceStaticComplexInputModule_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",
@ -4714,6 +4717,7 @@ ONNX_TOSA_XFAIL_SET = {
"SliceCopy_Module_basic",
"SliceEndSleStartModule_basic",
"SliceModule_basic",
"SliceStaticComplexInputModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",

View File

@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class SliceStaticComplexInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([6, 4, 7], torch.complex64, True),
]
)
def forward(self, x):
return x[0:5:1, 1:3:1, 2:4:1]
@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7).to(torch.complex64))
# ==============================================================================
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()