Fix failing TorchDynamo test.

pull/2120/head
Prashant Kumar 2023-05-11 12:05:01 +00:00
parent 8eb0c7e656
commit c47d3aab01
3 changed files with 8 additions and 5 deletions


@@ -202,6 +202,7 @@ TORCHDYNAMO_XFAIL_SET = {
     'TorchPrimLoopForLikeModule_basic',
     'TorchPrimLoopWhileLikeModule_basic',
+    # Forming aten.view_as_real and aten.view_as_imag instead of aten.real and aten.imag op.
     # Complex ops
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
@@ -638,6 +639,7 @@ STABLEHLO_PASS_SET = {
     "ConvolutionBackwardModule2DStrided_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "AtenComplex64Module_basic",
 }

 # Write the TOSA set as a "passing" set as it is very early in development
@@ -918,6 +920,7 @@ TOSA_PASS_SET = {
     "DetachModule_basic",
     "TensorsConcatStaticModule_basic",
     "TensorsConcatNegativeDimStaticModule_basic",
+    "AtenComplex64Module_basic",
 }

 LTC_XFAIL_SET = {
@@ -1096,4 +1099,5 @@ LTC_XFAIL_SET = {
     "VarMeanDimBiasedModule_basic",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
+    "AtenComplexViewModule_basic"
 }
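For reference on the comment added to TORCHDYNAMO_XFAIL_SET above ("Forming aten.view_as_real and aten.view_as_imag instead of aten.real and aten.imag op."): a minimal eager-mode sketch, independent of this commit, showing that view_as_real carries the same real/imaginary information that aten.real/aten.imag extract. Only torch.view_as_real is exercised below; view_as_imag is taken from the comment and is not assumed to have a Python binding.

import torch

# view_as_real exposes a complex tensor as a real tensor with a trailing
# dimension of size 2: [..., 0] is the real part, [..., 1] the imaginary part.
x = torch.randn(3, dtype=torch.complex64)
as_real = torch.view_as_real(x)          # shape (3, 2), dtype float32
assert torch.equal(as_real[..., 0], torch.real(x))
assert torch.equal(as_real[..., 1], torch.imag(x))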


@@ -1350,8 +1350,7 @@ public:
     auto input = adaptor.getSelf();
     RankedTensorType resultType =
-        typeConverter->convertType(op->getResult(0).getType())
-            .cast<RankedTensorType>();
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     auto elementType = resultType.getElementType();
     SmallVector<Value> resultShape;
@@ -1382,7 +1381,7 @@ public:
     auto complexVar =
         rewriter
             .create<linalg::GenericOp>(
-                loc, resultType, ValueRange{}, outTensor, indexingMaps,
+                loc, outTensor.getType(), ValueRange{}, outTensor, indexingMaps,
                 iteratorTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
                   SmallVector<Value> indicesZero;
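The hunks above make the linalg.generic take its result type from the init tensor (outTensor.getType()) rather than from the separately converted op result type. The surrounding file is not shown in this diff, so as a behavioral reference only, and assuming these patterns back the complex tests named elsewhere in this commit (AtenComplexViewModule_basic, AtenComplex64Module_basic): in eager PyTorch, view_as_complex is the construction-side counterpart of view_as_real.

import torch

# view_as_complex folds a trailing size-2 float dimension into complex values;
# round-tripping through view_as_real recovers the original pairs.
pairs = torch.randn(3, 2, dtype=torch.float32)
z = torch.view_as_complex(pairs)          # shape (3,), dtype complex64
assert z.dtype == torch.complex64
assert torch.equal(torch.view_as_real(z), pairs)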


@@ -3821,7 +3821,7 @@ class AtenComplexImagModule(torch.nn.Module):
         ([-1], torch.complex64, True),
     ])
     def forward(self, x):
-        return x.imag
+        return torch.ops.aten.imag(x)


 @register_test_case(module_factory=lambda: AtenComplexImagModule())
@@ -3840,7 +3840,7 @@ class AtenComplexRealModule(torch.nn.Module):
         ([-1], torch.complex64, True),
     ])
     def forward(self, x):
-        return x.real
+        return torch.ops.aten.real(x)


 @register_test_case(module_factory=lambda: AtenComplexRealModule())
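Both test tweaks above swap the .imag/.real tensor properties for direct torch.ops.aten.imag/torch.ops.aten.real calls. A quick eager-mode sanity check, independent of the commit, that the two spellings agree on a complex64 input:

import torch

# Property accessors and the corresponding aten overload packets
# return the same real/imaginary parts.
x = torch.randn(4, dtype=torch.complex64)
assert torch.equal(torch.ops.aten.imag(x), x.imag)
assert torch.equal(torch.ops.aten.real(x), x.real)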