[FxImporter] Fix fx importer test config and clean xfail set (#3176)

pull/3180/head
penguin_wwy 2024-04-17 13:36:07 +08:00 committed by GitHub
parent 398aeeec87
commit e4b11a0ab4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 164 additions and 318 deletions

View File

@ -387,320 +387,148 @@ TORCHDYNAMO_CRASHING_SET = {
} }
FX_IMPORT_XFAIL_SET = { FX_IMPORT_XFAIL_SET = {
"AddIntModule_basic", 'AddIntModule_basic',
"AllBoolFalseModule_basic", 'AllBoolFalseModule_basic',
"AllBoolTrueModule_basic", 'AllBoolTrueModule_basic',
"AnyBoolFalseModule_basic", 'AnyBoolFalseModule_basic',
"AnyBoolTrueModule_basic", 'AnyBoolTrueModule_basic',
"ArangeStartOutViewModule_basic", 'ArangeStartOutViewModule_basic',
"AtenEmbeddingBagStaticModule_basic", 'AtenEmbeddingBagStaticModule_basic',
"AtenEmbeddingBagSumExample_basic", 'AtenEmbeddingBagSumExample_basic',
"AtenFloatScalarModule_basic", 'AtenFloatScalarModule_basic',
"AtenIntBoolOpConstFalseModule_basic", 'AtenIntBoolOpConstFalseModule_basic',
"AtenIntBoolOpConstTrueModule_basic", 'AtenIntBoolOpConstTrueModule_basic',
"AtenIntBoolOpModule_basic", 'AtenIntBoolOpModule_basic',
"AtenIntTensorByteDtypeModule_basic", 'AtenIntTensorByteDtypeModule_basic',
"AtenIntTensorCharDtypeModule_basic", 'AtenIntTensorCharDtypeModule_basic',
"AtenItemFpOpModule_basic", 'AtenItemFpOpModule_basic',
"AtenItemIntOpModule_basic", 'AtenItemIntOpModule_basic',
"AtenMmQint8_basic", 'AtenMatmulQMixedSigni8Transpose_basic',
"AtenMatmulQint8_basic", 'AtenMatmulQMixedSigni8_basic',
"AtenMatmulQint8MV_basic", 'AtenMatmulQint8MV_basic',
"AtenMmQMixedSigni8_basic", 'AtenMatmulQint8_basic',
"AtenMatmulQMixedSigni8_basic", 'AtenMatmulQint8VM_basic',
"AtenMatmulQMixedSigni8Transpose_basic", 'AtenMatmulQint8VV_basic',
"AtenMmQuint8_basic", 'AtenMmQMixedSigni8_basic',
"AtenSubFloatModule_basic", 'AtenMmQint8_basic',
"BincountMinlengthModule_basic", 'AtenMmQuint8_basic',
"BincountModule_basic", 'AtenSubFloatModule_basic',
"BincountStaticSizeModule_basic", 'BincountMinlengthModule_basic',
"BoolFloatConstantModule_basic", 'BincountModule_basic',
"BoolFloatFalseModule_basic", 'BincountStaticSizeModule_basic',
"BoolFloatTrueModule_basic", 'BoolFloatConstantModule_basic',
"BoolIntConstantModule_basic", 'BoolFloatFalseModule_basic',
"BoolIntFalseModule_basic", 'BoolFloatTrueModule_basic',
"BoolIntTrueModule_basic", 'BoolIntConstantModule_basic',
"BroadcastDynamicDimModule_basic", 'BoolIntFalseModule_basic',
"CeilFloatModule_basic", 'BoolIntTrueModule_basic',
"ConstantBoolParameterModule_basic", 'BroadcastDynamicDimModule_basic',
"ContainsIntList_False", 'CeilFloatModule_basic',
"ContainsIntList_True", 'ConstantBoolParameterModule_basic',
"Conv2dQInt8Module_basic", 'ContainsIntList_False',
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", 'ContainsIntList_True',
"ConvTbcModule_basic", 'Conv2dQInt8Module_basic',
"ConvolutionBackwardModule2DPadded_basic", 'Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier',
"ConvolutionBackwardModule2DStrided_basic", 'ConvTbcModule_basic',
"ConvolutionBackwardModule2D_basic", 'ConvolutionBackwardModule2DPadded_basic',
"DivFloatModule_basic", 'ConvolutionBackwardModule2DStrided_basic',
"DivIntModule_basic", 'ConvolutionBackwardModule2D_basic',
"ElementwiseAddScalar_NumToTensorFloat_Module_basic", 'CumsumModule_basic',
"ElementwiseDequantizePerChannelModule_basic", 'DivFloatModule_basic',
"ElementwiseDequantizePerTensorModule_basic", 'DivIntModule_basic',
"ElementwiseQuantizePerTensorModule_basic", 'ElementwiseAddScalar_NumToTensorFloat_Module_basic',
"ElementwiseQuantizePerTensorUIntModule_basic", 'ElementwiseDequantizePerChannelModule_basic',
"ElementwiseToDtypeI64ToUI8Module_basic", 'ElementwiseDequantizePerTensorModule_basic',
"EqIntModule_basic", 'ElementwiseQuantizePerTensorModule_basic',
"FakeQuantizePerTensorAffineDynamicShapeModule_basic", 'ElementwiseQuantizePerTensorUIntModule_basic',
"FakeQuantizePerTensorAffineModule_basic", 'ElementwiseToDtypeI64ToUI8Module_basic',
"FakeQuantizePerTensorAffineRoundToEvenModule_basic", 'EqIntModule_basic',
"FloatImplicitModule_basic", 'FakeQuantizePerTensorAffineDynamicShapeModule_basic',
"GeFloatIntModule_basic", 'FakeQuantizePerTensorAffineModule_basic',
"GeFloatModule_basic", 'FakeQuantizePerTensorAffineRoundToEvenModule_basic',
"GeIntModule_basic", 'FloatImplicitModule_basic',
"GtFloatIntModule_basic", 'GeFloatIntModule_basic',
"GtIntModule_basic", 'GeFloatModule_basic',
"IntFloatModule_basic", 'GeIntModule_basic',
"IntImplicitModule_basic", 'GtFloatIntModule_basic',
"IsFloatingPointFloat_True", 'GtIntModule_basic',
"IsFloatingPointInt_False", 'IntFloatModule_basic',
"LenStrModule_basic", 'IntImplicitModule_basic',
"MulFloatModule_basic", 'IsFloatingPointFloat_True',
"MulIntModule_basic", 'IsFloatingPointInt_False',
"NeFloatIntModule_basic", 'LenStrModule_basic',
"NeIntModule_basic", 'MaxPool3dCeilModeTrueModule_basic',
"NllLossModuleBackward1DMeanWeight_basic", 'MaxPool3dEmptyStrideStaticModule_basic',
"NllLossModuleBackward1DMean_basic", 'MaxPool3dLargeDatadModule_basic',
"NllLossModuleBackward1DSumWeight_basic", 'MaxPool3dModuleRandomSimple_basic',
"NllLossModuleBackward1DSum_basic", 'MaxPool3dModule_basic',
"NllLossModuleBackward1DWeight_basic", 'MaxPool3dStaticCeilModeTrueModule_basic',
"NllLossModuleBackward1D_basic", 'MaxPool3dStaticModule_basic',
"NumToTensorFloatModule_basic", 'MulFloatModule_basic',
"NumToTensorIntModule_basic", 'MulIntModule_basic',
"NumelModule_basic", 'NativeGroupNormBackwardModule_basic',
"NumelZeroRankModule_basic", 'NeFloatIntModule_basic',
"PowIntFloatModule_basic", 'NeIntModule_basic',
"PrimMaxIntModule_basic", 'NllLossModuleBackward1DMeanWeight_basic',
"PrimMinIntDynamicModule_basic", 'NllLossModuleBackward1DMean_basic',
"PrimMinIntModule_basic", 'NllLossModuleBackward1DSumWeight_basic',
"PrimsSqueezeEmptyDimensionsModule_basic", 'NllLossModuleBackward1DSum_basic',
"PrimsSqueezeModule_basic", 'NllLossModuleBackward1DWeight_basic',
"PrimsViewOfModule_basic", 'NllLossModuleBackward1D_basic',
"PrimsViewOfZeroRankModule_basic", 'NumToTensorFloatModule_basic',
"QuantizedMLP_basic", 'NumToTensorIntModule_basic',
"QuantizedNoLayer_basic", 'NumelModule_basic',
"QuantizedSingleLayer_basic", 'NumelZeroRankModule_basic',
"QuantizedBatchedInputSingleLayer_basic", 'PowIntFloatModule_basic',
"ReduceMaxAlongDimUnsignedInt_basic", 'PrimMaxIntModule_basic',
"ReduceMinAlongDimUnsignedInt_basic", 'PrimMinIntDynamicModule_basic',
"RsubInt0d_NumToTensor_Module_basic", 'PrimMinIntModule_basic',
"ScalarConstantTupleModule_basic", 'PrimsSqueezeEmptyDimensionsModule_basic',
"ScalarImplicitFloatModule_basic", 'PrimsSqueezeModule_basic',
"ScalarImplicitIntModule_basic", 'PrimsViewOfModule_basic',
"SortIntListReverse_basic", 'PrimsViewOfZeroRankModule_basic',
"SortIntList_basic", 'QuantizedBatchedInputSingleLayer_basic',
"SplitDimDynamicModule_basic", 'QuantizedMLP_basic',
"SplitDimStaticModule_basic", 'QuantizedNoLayer_basic',
"SqrtIntConstantModule_basic", 'QuantizedSingleLayer_basic',
"SqrtIntModule_basic", 'ReduceMaxAlongDimUnsignedInt_basic',
"SubFloatModule_basic", 'ReduceMinAlongDimUnsignedInt_basic',
"SubIntModule_basic", 'RsubInt0d_NumToTensor_Module_basic',
"TModuleRank0_basic", 'ScalarConstantTupleModule_basic',
"TensorToBoolZeroRank_basic", 'ScalarImplicitFloatModule_basic',
"TensorToBool_basic", 'ScalarImplicitIntModule_basic',
"TensorToFloatZeroRank_basic", 'ScatterValueFloatModule_basic',
"TensorToFloat_basic", 'ScatterValueIntModule_basic',
"TensorToIntZeroRank_basic", 'SortIntListReverse_basic',
"TensorToInt_basic", 'SortIntList_basic',
"ThresholdBackward2dMixedModule_basic", 'SplitDimDynamicModule_basic',
"TorchPrimLoopForLikeModule_basic", 'SplitDimStaticModule_basic',
"TorchPrimLoopWhileLikeModule_basic", 'SqrtIntConstantModule_basic',
"UnbindIntGetItem_Module_basic", 'SqrtIntModule_basic',
"UnbindIntListUnpack_Module_basic", 'SubFloatModule_basic',
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", 'SubIntModule_basic',
"UpSampleNearest2dDynamicFactor_basic", 'TModuleRank0_basic',
"ViewCollapseDynamicWithAtenSizeIntModule_basic", 'TensorToBoolZeroRank_basic',
"ViewSizeFromOtherTensor_basic", 'TensorToBool_basic',
'TensorToFloatZeroRank_basic',
'TensorToFloat_basic',
'TensorToIntZeroRank_basic',
'TensorToInt_basic',
'TestMultipleTensorAndPrimitiveTypesReturn_basic',
'ThresholdBackward2dMixedModule_basic',
'TorchPrimLoopForLikeModule_basic',
'TorchPrimLoopWhileLikeModule_basic',
'UnbindIntGetItem_Module_basic',
'UnbindIntListUnpack_Module_basic',
'UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic',
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
'UpSampleNearest2dDynamicFactor_basic',
'ViewCollapseDynamicWithAtenSizeIntModule_basic',
'ViewSizeFromOtherTensor_basic',
} }
FX_IMPORTER_CRASHING_SET = { FX_IMPORTER_CRASHING_SET = {
'TransposeIntModule_basic', "HBC_basic",
'PermuteModule_basic',
'PermuteNegativeIndexModule_basic',
'TransposeIntNegDimsModule_basic',
'Add_MixPModule_basic',
'EmbeddingModuleI64_basic',
'EmbeddingModuleI32_basic',
'EmbeddingModuleF16_basic',
'EmbeddingModuleI32Static_basic',
'EmbeddingModule1DIndices_basic',
'BroadcastToSameRankStaticModule_basic',
'BroadcastZeroRankInputStaticModule_basic',
'BroadcastListConstructWithMinusOneModule_basic',
'ExpandModule_basic',
'ReturnThreeTensorFloat32_basic',
'AddCDivModule_basic',
'TensorIntModule_basic',
'TensorFloatModule_basic',
'BoolTensorReturnFalseModule_basic',
'BoolTensorReturnTrueModule_basic',
'BoolTensorReturnMixedModule_basic',
'TModuleRank2_basic',
'ReturnTwoTensorF32I64_basic',
'IndexTensorMultiInputNonContiguousDynamic_basic',
'ExpandAsFloatModule_basic',
'ExpandAsIntModule_basic',
'CopyModule_basic',
'CopyWithDifferentSizesModule_basic',
'CopyWithDifferentDTypesModule_basic',
'CopyWithDifferentDTypesAndSizesModule_basic',
'ToCopyModule_basic',
'NumpyTRankNStaticModule_basic',
'NumpyTRankNDynamicModule_basic',
'NumpyTRank2Module_basic',
'Aten_EmbeddingBagExample_basic',
'CumsumModule_basic',
'MoveDimIntModule_basic',
'MoveDimIntNegativeIndexModule_basic',
'ScaledDotProductAttentionDifferentModule_basic',
'AtenComplex64Module_basic',
'Add_Module_basic',
'ResNet18Module_basic',
'ResNet18StaticModule_basic',
'MobilenetV3Module_basic',
'Mlp1LayerModule_basic',
'Mlp2LayerModule_basic',
'Mlp2LayerModuleNoBias_basic',
'BatchMlpLayerModule_basic',
'Conv2dNoPaddingModule_basic',
'Conv2dBiasNoPaddingModule_basic',
'Conv2dWithPaddingModule_basic',
'Conv2dWithPaddingDilationStrideModule_basic',
'Conv2dWithPaddingDilationStrideStaticModule_basic',
'Conv2dWithPaddingDilationStrideStaticModule_depthwise',
'Conv2dWithPaddingDilationStrideStaticModule_grouped',
'Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier',
'BatchNorm1DModule_basic',
'BatchNorm1DWith2DInputModule_basic',
'BatchNorm2DModule_basic',
'BatchNorm3DModule_basic',
'BatchNorm1DStaticShapeModule_basic',
'NativeBatchNorm1DModule_basic',
'NativeBatchNorm2DModule_basic',
'NativeBatchNorm3DModule_basic',
'NativeBatchNormNoneWeightModule_basic',
'NativeGroupNormBackwardModule_basic',
'LayerNormModule_basic',
'LayerNormLastDimModule_basic',
'LayerNormNormalizeOverAllDimsModule_basic',
'AtenInstanceNormModule_basic',
'ElementwiseUnsqueezeBroadcastModule_basic',
'ElementwiseFlattenBroadcastModule_basic',
'ElementwiseDivScalarModule_basic',
'ElementwiseAtenDivIntScalarModule_basic',
'ElementwiseDivRoundingModeTruncModule_basic',
'ElementwiseDivRoundingModeFloorModule_basic',
'ElementwiseDivRoundingModeTruncStaticModule_basic',
'ElementwiseDivRoundingModeFloorStaticModule_basic',
'ElementwiseDivRoundingModeTruncIntStaticModule_basic',
'ElementwiseDivRoundingModeFloorIntStaticModule_basic',
'ElementwiseSubScalarFloatModule_basic',
'ElementwiseAddScalar_TensorLiteralInt32_Module_basic',
'ElementwiseCloneModule_basic',
'ElementwiseCloneContiguousModule_basic',
'ElementwiseCloneChannelsLastMemoryFormatModule_basic',
'Fill_TensorFloat64WithFloat32_basic',
'Fill_TensorFloat64WithFloat32Static_basic',
'Fill_TensorFloat64WithFloat64_basic',
'Fill_TensorFloat64WithInt64_basic',
'Fill_TensorFloat64WithInt64Static_basic',
'Fill_TensorFloat32WithFloat32_basic',
'Fill_TensorFloat32WithFloat64_basic',
'Fill_TensorFloat32WithInt64_basic',
'TupleModule_basic',
'Matmul_dot',
'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
'SqueezeModule_broadcast',
'SelectIntNegativeDimAndIndexStaticModule_basic',
'NarrowVerticalTest_basic',
'NarrowVerticalTest2_basic',
'SliceCopy_Module_basic',
'SliceCopyNegative_Module_basic',
'SliceCopyStartGreaterThanDimSize_Module_basic',
'SliceCopyEndGreaterThanDimSize_Module_basic',
'SliceCopyNonZeroDim_Module_basic',
'SplitTensorGetItem_Module_basic',
'SplitTensorListUnpackModule_basic',
'SplitTensorLastSmallerModule_basic',
'SplitTensorNegativeDimModule_basic',
'SplitWithSizesListUnpackModule_basic',
'ChunkListUnpack_Module_basic',
'ChunkListUnpackUneven_Module_basic',
'ChunkListUnpackDynamic_Module_basic',
'ChunkListUnpackUnevenDynamic_Module_basic',
'SplitWithSizes_Module_basic',
'ArangeStartOutModule_basic',
'ArangeStartOutDtypeModule_basic',
'FullModuleInt3D_basic',
'ZeroFloat32Module_basic',
'ZeroInt32Module_basic',
'ZeroInt64Module_basic',
'ThresholdBackward1dIntModule_basic',
'ThresholdBackward2dIntModule_basic',
'ThresholdBackward3dIntModule_basic',
'HBC_basic',
'UniformModule_basic',
'UniformStaticShapeModule_basic',
'UniformNoCorrelationModule_basic',
'ExponentialModule_basic',
'BernoulliFloatModule_basic',
'BernoulliTensorModule_basic',
'RandnGeneratorModule_basic',
'RandnGeneratorF64Module_basic',
'IndexPutImpl1DFloatNonAccumulateModule_basic',
'IndexPutImpl2DFloatNonAccumulateModule_basic',
'IndexPutImpl2DImplicitModule_basic',
'IndexPutImpl2DNoneIndexStaticModule_basic',
'IndexPutImpl3DFloatNonAccumulateModule_basic',
'IndexPutImpl1DIntNonAccumulateModule_basic',
'IndexPutImpl1DFloatAccumulateModule_basic',
'IndexPutImpl2DFloatAccumulateModule_basic',
'IndexPutImpl3DFloatAccumulateModule_basic',
'IndexPutImpl1DIntAccumulateModule_basic',
'IndexPut1DFloatNonAccumulateModule_basic',
'IndexPut2DFloatNonAccumulateModule_basic',
'IndexPut3DFloatNonAccumulateModule_basic',
'IndexPut1DIntNonAccumulateModule_basic',
'IndexPut2DIntNonAccumulateModule_basic',
'IndexPut3DIntNonAccumulateModule_basic',
'IndexPut1DFloatAccumulateModule_basic',
'IndexPut2DFloatAccumulateModule_basic',
'IndexPut3DFloatAccumulateModule_basic',
'IndexPut1DIntAccumulateModule_basic',
'IndexPut2DIntAccumulateModule_basic',
'IndexPut3DIntAccumulateModule_basic',
'IndexPutHackedTwin1DFloatNonAccumulateModule_basic',
'IndexPutHackedTwin2DFloatNonAccumulateModule_basic',
'IndexPutHackedTwin3DFloatNonAccumulateModule_basic',
'IndexPutHackedTwin1DIntNonAccumulateModule_basic',
'IndexPutHackedTwin2DIntNonAccumulateModule_basic',
'IndexPutHackedTwin3DIntNonAccumulateModule_basic',
'IndexPutHackedTwin1DFloatAccumulateModule_basic',
'IndexPutHackedTwin2DFloatAccumulateModule_basic',
'IndexPutHackedTwin3DFloatAccumulateModule_basic',
'IndexPutHackedTwin1DIntAccumulateModule_basic',
'IndexPutHackedTwin2DIntAccumulateModule_basic',
'IndexPutHackedTwin3DIntAccumulateModule_basic',
'ScatterValueFloatModule_basic',
'ScatterValueIntModule_basic',
'IndexPutImpl2DIndexModule_basic',
'IndexPutImplIndexWithNoneModule_basic',
'AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic',
'AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic',
'AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic',
'AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic',
'MaxPool3dModule_basic',
'MaxPool3dModuleRandomSimple_basic',
'MaxPool3dLargeDatadModule_basic',
'MaxPool3dEmptyStrideStaticModule_basic',
'MaxPool3dStaticModule_basic',
'MaxPool3dStaticCeilModeTrueModule_basic',
'MaxPool3dCeilModeTrueModule_basic',
'AdaptiveAvgPool1dStaticEvenMultiple_basic',
'AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic',
'AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic',
'TestMultipleTensorReturn_basic',
'TestMultipleTensorAndPrimitiveTypesReturn_basic',
'TestF16Return_basic',
} }
STABLEHLO_PASS_SET = { STABLEHLO_PASS_SET = {

View File

@ -8,6 +8,8 @@ from typing import Union, Optional, Sequence
import numpy as np import numpy as np
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram
from torch_mlir import fx from torch_mlir import fx
from torch_mlir.torchscript import ( from torch_mlir.torchscript import (
@ -37,8 +39,8 @@ def refine_result_type(_result):
def jit( def jit(
model: torch.nn.Module, prog: ExportedProgram,
example_args: _example_args, func_name: str,
output_type: Union[str, "OutputType"] = OutputType.TORCH, output_type: Union[str, "OutputType"] = OutputType.TORCH,
backend_legal_ops: Optional[Sequence[str]] = None, backend_legal_ops: Optional[Sequence[str]] = None,
extra_library=None, extra_library=None,
@ -61,7 +63,7 @@ def jit(
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) + option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
" extra-library=" + extra_library_file_name + "}") " extra-library=" + extra_library_file_name + "}")
mlir_module = fx.export_and_import(model, *example_args, func_name=model.__class__.__name__) mlir_module = fx.export_and_import(prog, func_name=func_name)
assert mlir_module is not None assert mlir_module is not None
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
mlir_module, mlir_module,
@ -90,8 +92,9 @@ class FxImporterTestConfig(TestConfig):
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = [] result: Trace = []
for item in trace: for item in trace:
module = jit(artifact, prog = torch.export.export(artifact, tuple(item.inputs))
item.inputs, module = jit(prog,
func_name=artifact.__class__.__name__,
output_type="linalg-on-tensors") output_type="linalg-on-tensors")
module = self.backend.compile(module) module = self.backend.compile(module)
backend_module = self.backend.load(module) backend_module = self.backend.load(module)
@ -107,6 +110,13 @@ class FxImporterTestConfig(TestConfig):
outputs = getattr(backend_module, outputs = getattr(backend_module,
artifact.__class__.__name__)(*numpy_inputs) artifact.__class__.__name__)(*numpy_inputs)
output = refine_result_type(outputs) output = refine_result_type(outputs)
if isinstance(output, (tuple, list)):
user_output = []
out_spec: OutputSpec
for val, out_spec in zip(output, prog.graph_signature.output_specs):
if out_spec.kind == OutputKind.USER_OUTPUT:
user_output.append(val)
output = tuple(user_output)
result.append( result.append(
TraceItem(symbol=item.symbol, TraceItem(symbol=item.symbol,
inputs=item.inputs, inputs=item.inputs,

View File

@ -74,6 +74,10 @@ class ValueReport:
def _evaluate_outcome(self): def _evaluate_outcome(self):
value, golden = self.value, self.golden_value value, golden = self.value, self.golden_value
if isinstance(value, tuple) and len(value) == 1:
value = value[0]
if isinstance(golden, tuple) and len(golden) == 1:
golden = golden[0]
if isinstance(golden, float): if isinstance(golden, float):
if not isinstance(value, float): if not isinstance(value, float):
return self._record_mismatch_type_failure('float', value) return self._record_mismatch_type_failure('float', value)

View File

@ -10,6 +10,7 @@ import warnings
import torch import torch
import torch.export import torch.export
import torch.nn as nn import torch.nn as nn
from torch.export import ExportedProgram
from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
from torch_mlir import ir from torch_mlir import ir
@ -18,7 +19,7 @@ from torch_mlir.extras.fx_decomp_util import get_decomposition_table
def export_and_import( def export_and_import(
f, f: Union[nn.Module, ExportedProgram],
*args, *args,
fx_importer: Optional[FxImporter] = None, fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@ -34,6 +35,9 @@ def export_and_import(
if fx_importer is None: if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks) fx_importer = FxImporter(context=context, hooks=hooks)
if isinstance(f, ExportedProgram):
prog = f
else:
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes) prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
if decomposition_table is None: if decomposition_table is None:
decomposition_table = get_decomposition_table() decomposition_table = get_decomposition_table()