mirror of https://github.com/llvm/torch-mlir
[LINALG] Add complex tensor support for `create[Zero|One]InitTensor` utility (#3777)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3781/head
parent
d49eabb3fc
commit
94f5410913
|
@ -132,9 +132,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|||
Type elemTy) {
|
||||
Value initTensor =
|
||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
|
||||
Value c0 =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
|
||||
|
||||
Type fillValElemTy = elemTy;
|
||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
|
||||
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||
|
||||
Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
|
||||
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
}
|
||||
|
||||
|
@ -142,9 +145,12 @@ Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|||
Type elemTy) {
|
||||
Value initTensor =
|
||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
|
||||
Value c1 =
|
||||
b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
|
||||
|
||||
Type fillValElemTy = elemTy;
|
||||
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
|
||||
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
|
||||
|
||||
Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
|
||||
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
|
||||
}
|
||||
|
||||
|
|
|
@ -1423,6 +1423,7 @@ STABLEHLO_PASS_SET = {
|
|||
"SliceSizeTwoStepModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"SliceWholeTensorModule_basic",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
|
@ -2618,6 +2619,7 @@ ONNX_XFAIL_SET = {
|
|||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
"SliceCopy_Module_basic",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"StdCorrectionLargeInputModule_basic",
|
||||
"TupleModule_basic",
|
||||
"VarCorrectionLargeInputModule_basic",
|
||||
|
@ -3778,6 +3780,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
"SignAndLogarithmOfDeterminantBatchedModule_F32",
|
||||
"SignAndLogarithmOfDeterminantDynamicModule_F32",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
|
@ -4714,6 +4717,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"SliceCopy_Module_basic",
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceModule_basic",
|
||||
"SliceStaticComplexInputModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||
|
|
|
@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceStaticComplexInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([6, 4, 7], torch.complex64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x[0:5:1, 1:3:1, 2:4:1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
|
||||
def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 7).to(torch.complex64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue