mirror of https://github.com/llvm/torch-mlir
Fix torchdynamo fail test.
parent
8eb0c7e656
commit
c47d3aab01
|
@ -202,6 +202,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
'TorchPrimLoopForLikeModule_basic',
|
||||
'TorchPrimLoopWhileLikeModule_basic',
|
||||
|
||||
# Forming aten.view_as_real and aten.view_as_imag instead of aten.real and aten.imag op.
|
||||
# Complex ops
|
||||
"AtenComplexImagModule_basic",
|
||||
"AtenComplexRealModule_basic",
|
||||
|
@ -638,6 +639,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ConvolutionBackwardModule2DStrided_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
"PrimsViewOfZeroRankModule_basic",
|
||||
"AtenComplex64Module_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
|
@ -918,6 +920,7 @@ TOSA_PASS_SET = {
|
|||
"DetachModule_basic",
|
||||
"TensorsConcatStaticModule_basic",
|
||||
"TensorsConcatNegativeDimStaticModule_basic",
|
||||
"AtenComplex64Module_basic",
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
@ -1096,4 +1099,5 @@ LTC_XFAIL_SET = {
|
|||
"VarMeanDimBiasedModule_basic",
|
||||
"AtenComplexImagModule_basic",
|
||||
"AtenComplexRealModule_basic",
|
||||
"AtenComplexViewModule_basic"
|
||||
}
|
||||
|
|
|
@ -1350,8 +1350,7 @@ public:
|
|||
auto input = adaptor.getSelf();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
auto elementType = resultType.getElementType();
|
||||
SmallVector<Value> resultShape;
|
||||
|
@ -1382,7 +1381,7 @@ public:
|
|||
auto complexVar =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, resultType, ValueRange{}, outTensor, indexingMaps,
|
||||
loc, outTensor.getType(), ValueRange{}, outTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
SmallVector<Value> indicesZero;
|
||||
|
|
|
@ -3821,7 +3821,7 @@ class AtenComplexImagModule(torch.nn.Module):
|
|||
([-1], torch.complex64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.imag
|
||||
return torch.ops.aten.imag(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenComplexImagModule())
|
||||
|
@ -3840,7 +3840,7 @@ class AtenComplexRealModule(torch.nn.Module):
|
|||
([-1], torch.complex64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.real
|
||||
return torch.ops.aten.real(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenComplexRealModule())
|
||||
|
|
Loading…
Reference in New Issue