mirror of https://github.com/llvm/torch-mlir
[FxImporter] Add FxImporter config in e2e-test (#3151)
parent 859f5d280f
commit 45eaeaaf36
@@ -22,6 +22,7 @@ from torch_mlir_e2e_test.configs import (
    TorchScriptTestConfig,
    TosaBackendTestConfig,
    TorchDynamoTestConfig,
    FxImporterTestConfig,
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
@@ -41,6 +42,8 @@ from .xfail_sets import (
    TORCHDYNAMO_CRASHING_SET,
    ONNX_CRASHING_SET,
    ONNX_XFAIL_SET,
    FX_IMPORT_XFAIL_SET,
    FX_IMPORTER_CRASHING_SET,
)

# Import tests to register them in the global registry.
@@ -48,7 +51,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
    config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx"]
    config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx", "fx_importer"]
    parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
    parser.add_argument("-c", "--config",
                        choices=config_choices,
@@ -121,6 +124,10 @@ def main():
        config = LazyTensorCoreTestConfig()
        xfail_set = LTC_XFAIL_SET
        crashing_set = LTC_CRASHING_SET
    elif args.config == "fx_importer":
        config = FxImporterTestConfig(RefBackendLinalgOnTensorsBackend())
        xfail_set = FX_IMPORT_XFAIL_SET
        crashing_set = FX_IMPORTER_CRASHING_SET
    elif args.config == "torchdynamo":
        config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
        xfail_set = TORCHDYNAMO_XFAIL_SET
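(Not part of the diff.) For orientation, once "fx_importer" is selected the runner follows the same flow as the other configs: tests in the crashing set are dropped before execution and tests in the xfail set are reported as expected failures. A minimal sketch of that flow, under the assumption that run_tests/report_results and GLOBAL_TEST_REGISTRY keep their existing shapes in torch_mlir_e2e_test, and that the usual CLI entry point is invoked as something like `python -m e2e_testing.main -c fx_importer`:

# Hedged sketch only; it mirrors what main() does for the other configs.
from torch_mlir_e2e_test.configs import FxImporterTestConfig
from torch_mlir_e2e_test.framework import run_tests
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.reporting import report_results
from torch_mlir_e2e_test.test_suite import register_all_tests

# Assumes this runs next to xfail_sets.py (as main.py does via a relative import).
from xfail_sets import FX_IMPORT_XFAIL_SET, FX_IMPORTER_CRASHING_SET

register_all_tests()
config = FxImporterTestConfig(RefBackendLinalgOnTensorsBackend())
# Drop tests known to crash this config before running anything.
tests = [t for t in GLOBAL_TEST_REGISTRY
         if t.unique_name not in FX_IMPORTER_CRASHING_SET]
results = run_tests(tests, config)
# Entries in the xfail set are treated as expected failures, not errors.
report_results(results, FX_IMPORT_XFAIL_SET, verbose=True)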
@@ -369,6 +369,319 @@ TORCHDYNAMO_CRASHING_SET = {
    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
}

FX_IMPORT_XFAIL_SET = {
    "AddIntModule_basic",
    "AllBoolFalseModule_basic",
    "AllBoolTrueModule_basic",
    "AnyBoolFalseModule_basic",
    "AnyBoolTrueModule_basic",
    "ArangeStartOutViewModule_basic",
    "AtenEmbeddingBagStaticModule_basic",
    "AtenEmbeddingBagSumExample_basic",
    "AtenFloatScalarModule_basic",
    "AtenIntBoolOpConstFalseModule_basic",
    "AtenIntBoolOpConstTrueModule_basic",
    "AtenIntBoolOpModule_basic",
    "AtenIntTensorByteDtypeModule_basic",
    "AtenIntTensorCharDtypeModule_basic",
    "AtenItemFpOpModule_basic",
    "AtenItemIntOpModule_basic",
    "AtenMmQuint8_basic",
    "AtenSubFloatModule_basic",
    "BincountMinlengthModule_basic",
    "BincountModule_basic",
    "BincountStaticSizeModule_basic",
    "BoolFloatConstantModule_basic",
    "BoolFloatFalseModule_basic",
    "BoolFloatTrueModule_basic",
    "BoolIntConstantModule_basic",
    "BoolIntFalseModule_basic",
    "BoolIntTrueModule_basic",
    "BroadcastDynamicDimModule_basic",
    "CeilFloatModule_basic",
    "ConstantBoolParameterModule_basic",
    "ContainsIntList_False",
    "ContainsIntList_True",
    "Conv2dQInt8Module_basic",
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    "ConvTbcModule_basic",
    "ConvolutionBackwardModule2DPadded_basic",
    "ConvolutionBackwardModule2DStrided_basic",
    "ConvolutionBackwardModule2D_basic",
    "DivFloatModule_basic",
    "DivIntModule_basic",
    "ElementwiseAddScalarFloatModule_basic",
    "ElementwiseAddScalarInt8Module_basic",
    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
    "ElementwiseDequantizePerChannelModule_basic",
    "ElementwiseDequantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorUIntModule_basic",
    "ElementwiseSubScalarIntModule_basic",
    "ElementwiseToDtypeI64ToUI8Module_basic",
    "EqIntModule_basic",
    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
    "FakeQuantizePerTensorAffineModule_basic",
    "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
    "FloatImplicitModule_basic",
    "GeFloatIntModule_basic",
    "GeFloatModule_basic",
    "GeIntModule_basic",
    "GtFloatIntModule_basic",
    "GtIntModule_basic",
    "IntFloatModule_basic",
    "IntImplicitModule_basic",
    "IsFloatingPointFloat_True",
    "IsFloatingPointInt_False",
    "LenStrModule_basic",
    "MulFloatModule_basic",
    "MulIntModule_basic",
    "NeFloatIntModule_basic",
    "NeIntModule_basic",
    "NllLossModuleBackward1DMeanWeight_basic",
    "NllLossModuleBackward1DMean_basic",
    "NllLossModuleBackward1DSumWeight_basic",
    "NllLossModuleBackward1DSum_basic",
    "NllLossModuleBackward1DWeight_basic",
    "NllLossModuleBackward1D_basic",
    "NumToTensorFloatModule_basic",
    "NumToTensorIntModule_basic",
    "NumelModule_basic",
    "NumelZeroRankModule_basic",
    "PowIntFloatModule_basic",
    "PrimMaxIntModule_basic",
    "PrimMinIntDynamicModule_basic",
    "PrimMinIntModule_basic",
    "PrimsSqueezeEmptyDimensionsModule_basic",
    "PrimsSqueezeModule_basic",
    "PrimsViewOfModule_basic",
    "PrimsViewOfZeroRankModule_basic",
    "QuantizedMLP_basic",
    "QuantizedNoLayer_basic",
    "QuantizedSingleLayer_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    "RsubInt0d_NumToTensor_Module_basic",
    "ScalarConstantTupleModule_basic",
    "ScalarImplicitFloatModule_basic",
    "ScalarImplicitIntModule_basic",
    "SortIntListReverse_basic",
    "SortIntList_basic",
    "SplitDimDynamicModule_basic",
    "SplitDimStaticModule_basic",
    "SqrtIntConstantModule_basic",
    "SqrtIntModule_basic",
    "SubFloatModule_basic",
    "SubIntModule_basic",
    "TModuleRank0_basic",
    "TensorToBoolZeroRank_basic",
    "TensorToBool_basic",
    "TensorToFloatZeroRank_basic",
    "TensorToFloat_basic",
    "TensorToIntZeroRank_basic",
    "TensorToInt_basic",
    "ThresholdBackward2dMixedModule_basic",
    "TorchPrimLoopForLikeModule_basic",
    "TorchPrimLoopWhileLikeModule_basic",
    "UnbindIntGetItem_Module_basic",
    "UnbindIntListUnpack_Module_basic",
    "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
    "UpSampleNearest2dDynamicFactor_basic",
    "ViewCollapseDynamicWithAtenSizeIntModule_basic",
    "ViewSizeFromOtherTensor_basic",
}

FX_IMPORTER_CRASHING_SET = {
    'TransposeIntModule_basic',
    'PermuteModule_basic',
    'PermuteNegativeIndexModule_basic',
    'TransposeIntNegDimsModule_basic',
    'Add_MixPModule_basic',
    'EmbeddingModuleI64_basic',
    'EmbeddingModuleI32_basic',
    'EmbeddingModuleF16_basic',
    'EmbeddingModuleI32Static_basic',
    'EmbeddingModule1DIndices_basic',
    'BroadcastToSameRankStaticModule_basic',
    'BroadcastZeroRankInputStaticModule_basic',
    'BroadcastListConstructWithMinusOneModule_basic',
    'ExpandModule_basic',
    'ReturnThreeTensorFloat32_basic',
    'AddCDivModule_basic',
    'TensorIntModule_basic',
    'TensorFloatModule_basic',
    'BoolTensorReturnFalseModule_basic',
    'BoolTensorReturnTrueModule_basic',
    'BoolTensorReturnMixedModule_basic',
    'TModuleRank2_basic',
    'ReturnTwoTensorF32I64_basic',
    'IndexTensorMultiInputNonContiguousDynamic_basic',
    'ExpandAsFloatModule_basic',
    'ExpandAsIntModule_basic',
    'CopyModule_basic',
    'CopyWithDifferentSizesModule_basic',
    'CopyWithDifferentDTypesModule_basic',
    'CopyWithDifferentDTypesAndSizesModule_basic',
    'ToCopyModule_basic',
    'NumpyTRankNStaticModule_basic',
    'NumpyTRankNDynamicModule_basic',
    'NumpyTRank2Module_basic',
    'Aten_EmbeddingBagExample_basic',
    'CumsumModule_basic',
    'MoveDimIntModule_basic',
    'MoveDimIntNegativeIndexModule_basic',
    'ScaledDotProductAttentionDifferentModule_basic',
    'AtenComplex64Module_basic',
    'Add_Module_basic',
    'ResNet18Module_basic',
    'ResNet18StaticModule_basic',
    'MobilenetV3Module_basic',
    'Mlp1LayerModule_basic',
    'Mlp2LayerModule_basic',
    'Mlp2LayerModuleNoBias_basic',
    'BatchMlpLayerModule_basic',
    'Conv2dNoPaddingModule_basic',
    'Conv2dBiasNoPaddingModule_basic',
    'Conv2dWithPaddingModule_basic',
    'Conv2dWithPaddingDilationStrideModule_basic',
    'Conv2dWithPaddingDilationStrideStaticModule_basic',
    'Conv2dWithPaddingDilationStrideStaticModule_depthwise',
    'Conv2dWithPaddingDilationStrideStaticModule_grouped',
    'Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier',
    'BatchNorm1DModule_basic',
    'BatchNorm1DWith2DInputModule_basic',
    'BatchNorm2DModule_basic',
    'BatchNorm3DModule_basic',
    'BatchNorm1DStaticShapeModule_basic',
    'NativeBatchNorm1DModule_basic',
    'NativeBatchNorm2DModule_basic',
    'NativeBatchNorm3DModule_basic',
    'NativeBatchNormNoneWeightModule_basic',
    'NativeGroupNormBackwardModule_basic',
    'LayerNormModule_basic',
    'LayerNormLastDimModule_basic',
    'LayerNormNormalizeOverAllDimsModule_basic',
    'AtenInstanceNormModule_basic',
    'ElementwiseUnsqueezeBroadcastModule_basic',
    'ElementwiseFlattenBroadcastModule_basic',
    'ElementwiseDivScalarModule_basic',
    'ElementwiseAtenDivIntScalarModule_basic',
    'ElementwiseDivRoundingModeTruncModule_basic',
    'ElementwiseDivRoundingModeFloorModule_basic',
    'ElementwiseDivRoundingModeTruncStaticModule_basic',
    'ElementwiseDivRoundingModeFloorStaticModule_basic',
    'ElementwiseDivRoundingModeTruncIntStaticModule_basic',
    'ElementwiseDivRoundingModeFloorIntStaticModule_basic',
    'ElementwiseSubScalarFloatModule_basic',
    'ElementwiseAddScalar_TensorLiteralInt32_Module_basic',
    'ElementwiseCloneModule_basic',
    'ElementwiseCloneContiguousModule_basic',
    'ElementwiseCloneChannelsLastMemoryFormatModule_basic',
    'Fill_TensorFloat64WithFloat32_basic',
    'Fill_TensorFloat64WithFloat32Static_basic',
    'Fill_TensorFloat64WithFloat64_basic',
    'Fill_TensorFloat64WithInt64_basic',
    'Fill_TensorFloat64WithInt64Static_basic',
    'Fill_TensorFloat32WithFloat32_basic',
    'Fill_TensorFloat32WithFloat64_basic',
    'Fill_TensorFloat32WithInt64_basic',
    'TupleModule_basic',
    'Matmul_dot',
    'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic',
    'SqueezeModule_broadcast',
    'SelectIntNegativeDimAndIndexStaticModule_basic',
    'NarrowVerticalTest_basic',
    'NarrowVerticalTest2_basic',
    'SliceCopy_Module_basic',
    'SliceCopyNegative_Module_basic',
    'SliceCopyStartGreaterThanDimSize_Module_basic',
    'SliceCopyEndGreaterThanDimSize_Module_basic',
    'SliceCopyNonZeroDim_Module_basic',
    'SplitTensorGetItem_Module_basic',
    'SplitTensorListUnpackModule_basic',
    'SplitTensorLastSmallerModule_basic',
    'SplitTensorNegativeDimModule_basic',
    'SplitWithSizesListUnpackModule_basic',
    'ChunkListUnpack_Module_basic',
    'ChunkListUnpackUneven_Module_basic',
    'ChunkListUnpackDynamic_Module_basic',
    'ChunkListUnpackUnevenDynamic_Module_basic',
    'SplitWithSizes_Module_basic',
    'ArangeStartOutModule_basic',
    'ArangeStartOutDtypeModule_basic',
    'FullModuleInt3D_basic',
    'ZeroFloat32Module_basic',
    'ZeroInt32Module_basic',
    'ZeroInt64Module_basic',
    'ThresholdBackward1dIntModule_basic',
    'ThresholdBackward2dIntModule_basic',
    'ThresholdBackward3dIntModule_basic',
    'HBC_basic',
    'UniformModule_basic',
    'UniformStaticShapeModule_basic',
    'UniformNoCorrelationModule_basic',
    'ExponentialModule_basic',
    'BernoulliFloatModule_basic',
    'BernoulliTensorModule_basic',
    'RandnGeneratorModule_basic',
    'RandnGeneratorF64Module_basic',
    'IndexPutImpl1DFloatNonAccumulateModule_basic',
    'IndexPutImpl2DFloatNonAccumulateModule_basic',
    'IndexPutImpl2DImplicitModule_basic',
    'IndexPutImpl2DNoneIndexStaticModule_basic',
    'IndexPutImpl3DFloatNonAccumulateModule_basic',
    'IndexPutImpl1DIntNonAccumulateModule_basic',
    'IndexPutImpl1DFloatAccumulateModule_basic',
    'IndexPutImpl2DFloatAccumulateModule_basic',
    'IndexPutImpl3DFloatAccumulateModule_basic',
    'IndexPutImpl1DIntAccumulateModule_basic',
    'IndexPut1DFloatNonAccumulateModule_basic',
    'IndexPut2DFloatNonAccumulateModule_basic',
    'IndexPut3DFloatNonAccumulateModule_basic',
    'IndexPut1DIntNonAccumulateModule_basic',
    'IndexPut2DIntNonAccumulateModule_basic',
    'IndexPut3DIntNonAccumulateModule_basic',
    'IndexPut1DFloatAccumulateModule_basic',
    'IndexPut2DFloatAccumulateModule_basic',
    'IndexPut3DFloatAccumulateModule_basic',
    'IndexPut1DIntAccumulateModule_basic',
    'IndexPut2DIntAccumulateModule_basic',
    'IndexPut3DIntAccumulateModule_basic',
    'IndexPutHackedTwin1DFloatNonAccumulateModule_basic',
    'IndexPutHackedTwin2DFloatNonAccumulateModule_basic',
    'IndexPutHackedTwin3DFloatNonAccumulateModule_basic',
    'IndexPutHackedTwin1DIntNonAccumulateModule_basic',
    'IndexPutHackedTwin2DIntNonAccumulateModule_basic',
    'IndexPutHackedTwin3DIntNonAccumulateModule_basic',
    'IndexPutHackedTwin1DFloatAccumulateModule_basic',
    'IndexPutHackedTwin2DFloatAccumulateModule_basic',
    'IndexPutHackedTwin3DFloatAccumulateModule_basic',
    'IndexPutHackedTwin1DIntAccumulateModule_basic',
    'IndexPutHackedTwin2DIntAccumulateModule_basic',
    'IndexPutHackedTwin3DIntAccumulateModule_basic',
    'ScatterValueFloatModule_basic',
    'ScatterValueIntModule_basic',
    'IndexPutImpl2DIndexModule_basic',
    'IndexPutImplIndexWithNoneModule_basic',
    'AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic',
    'AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic',
    'AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic',
    'AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic',
    'MaxPool3dModule_basic',
    'MaxPool3dModuleRandomSimple_basic',
    'MaxPool3dLargeDatadModule_basic',
    'MaxPool3dEmptyStrideStaticModule_basic',
    'MaxPool3dStaticModule_basic',
    'MaxPool3dStaticCeilModeTrueModule_basic',
    'MaxPool3dCeilModeTrueModule_basic',
    'AdaptiveAvgPool1dStaticEvenMultiple_basic',
    'AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic',
    'AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic',
    'TestMultipleTensorReturn_basic',
    'TestMultipleTensorAndPrimitiveTypesReturn_basic',
    'TestF16Return_basic',
}

STABLEHLO_PASS_SET = {
    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
@@ -11,3 +11,4 @@ from .torchscript import TorchScriptTestConfig
from .stablehlo_backend import StablehloBackendTestConfig
from .tosa_backend import TosaBackendTestConfig
from .torchdynamo import TorchDynamoTestConfig
from .fx_importer_backend import FxImporterTestConfig
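(Not part of the diff.) With the re-export above, drivers such as main.py can import the new config from the package root rather than the submodule. A minimal check, assuming a development install that puts torch_mlir_e2e_test on the Python path:

# Assumes torch_mlir_e2e_test is importable (e.g. a torch-mlir dev setup).
from torch_mlir_e2e_test.configs import FxImporterTestConfig
print(FxImporterTestConfig.__module__)  # expected: torch_mlir_e2e_test.configs.fx_importer_backend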
@@ -0,0 +1,114 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Union, Optional, Sequence

import numpy as np
import torch
import torch.utils._pytree as pytree

from torch_mlir import fx
from torch_mlir.torchscript import (
    _example_args,
    OutputType,
    BACKEND_LEGAL_OPS,
    run_pipeline_with_repro_report,
    _lower_mlir_module,
    _canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
    recursively_convert_to_numpy,
    recursively_convert_from_numpy,
)
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem


def refine_result_type(_result):
    if isinstance(_result, tuple):
        return tuple(refine_result_type(x) for x in _result)
    elif isinstance(_result, np.ndarray):
        return torch.from_numpy(_result)
    elif isinstance(_result, (bool, int, float)):
        return _result
    else:
        raise ValueError(f"Unhandled return type {type(_result)}")


def jit(
    model: torch.nn.Module,
    example_args: _example_args,
    output_type: Union[str, "OutputType"] = OutputType.TORCH,
    backend_legal_ops: Optional[Sequence[str]] = None,
    extra_library=None,
    verbose: bool = False,
):
    if extra_library is None:
        extra_library = []
    mlir_module = None

    extra_library_file_name = _canon_extra_library(extra_library)
    output_type = OutputType.get(output_type)
    if backend_legal_ops is not None:
        if output_type != OutputType.TORCH:
            raise Exception("`backend_legal_ops` is only valid with the "
                            "`torch` output type")
        backend_legal_ops = list(sorted(set(backend_legal_ops)))
    else:
        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

    option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
                     " extra-library=" + extra_library_file_name + "}")

    mlir_module = fx.export_and_import(model, *example_args, func_name=model.__class__.__name__)
    assert mlir_module is not None
    run_pipeline_with_repro_report(
        mlir_module,
        f"builtin.module(torch-simplification-pipeline)",
        "Simplification pipeline for torch dialect",
    )
    run_pipeline_with_repro_report(
        mlir_module,
        f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
        "Lowering TorchFX IR -> Torch Backend IR",
    )

    return _lower_mlir_module(verbose, output_type, mlir_module)


class FxImporterTestConfig(TestConfig):
    """TestConfig that runs the torch.nn.Module with Fx Importer"""

    def __init__(self, backend):
        super().__init__()
        self.backend = backend

    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
        return program

    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
        result: Trace = []
        for item in trace:
            module = jit(artifact,
                         item.inputs,
                         output_type="linalg-on-tensors")
            module = self.backend.compile(module)
            backend_module = self.backend.load(module)
            params = {
                # **dict(artifact.named_parameters(remove_duplicate=False)),
                **dict(artifact.named_buffers(remove_duplicate=False)),
            }
            params_flat, params_spec = pytree.tree_flatten(params)
            params_flat = list(params_flat)
            with torch.no_grad():
                numpy_inputs = recursively_convert_to_numpy(params_flat +
                                                            item.inputs)
            outputs = getattr(backend_module,
                              artifact.__class__.__name__)(*numpy_inputs)
            output = refine_result_type(outputs)
            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
                          output=output))
        return result
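(Not part of the new file.) As a usage illustration, the helper above can also be exercised outside the harness by lowering a toy module and invoking it through the reference backend. The AddOne module below is hypothetical, and the import path of the new file is assumed; only jit, the utils helpers, and RefBackendLinalgOnTensorsBackend come from code shown in this commit.

# Hedged sketch: lower a toy module with jit() and run it on the reference backend.
import torch
from torch_mlir_e2e_test.configs.fx_importer_backend import jit  # module path assumed
from torch_mlir_e2e_test.configs.utils import recursively_convert_to_numpy
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend


class AddOne(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x + 1.0


example_inputs = [torch.ones(2, 3)]
# fx.export_and_import traces the module, then the two pipelines in jit() lower
# it to linalg-on-tensors, exactly as FxImporterTestConfig.run() does per item.
module = jit(AddOne(), example_inputs, output_type="linalg-on-tensors")
backend = RefBackendLinalgOnTensorsBackend()
loaded = backend.load(backend.compile(module))
# The compiled entry point is named after the module class (func_name in jit()).
result = getattr(loaded, "AddOne")(*recursively_convert_to_numpy(example_inputs))
print(result)  # numpy array filled with 2.0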