mirror of https://github.com/llvm/torch-mlir
Add make_fx_tosa variant to end2end tests (#2240)
* Add make_fx_tosa variant to end2end tests * e2e_testing/xfail_sets.py: Add make_fx_tosa xfail for stablepull/2303/head
parent
91c6454618
commit
f8e75f659d
|
@ -301,6 +301,9 @@ function test_in_tree() {
|
|||
;;
|
||||
esac
|
||||
|
||||
echo ":::: Run make_fx + TOSA e2e integration tests"
|
||||
python -m e2e_testing.main --config=make_fx_tosa -v
|
||||
|
||||
echo ":::: Run TorchDynamo e2e integration tests"
|
||||
python -m e2e_testing.main --config=torchdynamo -v
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsT
|
|||
|
||||
from .xfail_sets import (
|
||||
LINALG_XFAIL_SET,
|
||||
MAKE_FX_TOSA_PASS_SET,
|
||||
STABLEHLO_PASS_SET,
|
||||
TOSA_PASS_SET,
|
||||
LTC_XFAIL_SET,
|
||||
|
@ -42,7 +43,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
|
|||
register_all_tests()
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
||||
parser.add_argument("-c", "--config",
|
||||
choices=config_choices,
|
||||
|
@ -94,6 +95,10 @@ def main():
|
|||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||
crashing_set = set()
|
||||
elif args.config == "make_fx_tosa":
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
|
||||
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
|
||||
crashing_set = set()
|
||||
elif args.config == "stablehlo":
|
||||
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
||||
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# might be used to keep more elaborate sets of testing configurations).
|
||||
|
||||
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
from torch_mlir._version import torch_version_for_comparison, version
|
||||
|
||||
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
|
@ -1113,6 +1114,41 @@ TOSA_PASS_SET = {
|
|||
"ChunkListUnpackUneven_Module_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
"NativeGroupNormBackwardModule_basic",
|
||||
"TensorFloatModule_basic",
|
||||
"TensorIntModule_basic",
|
||||
}) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
|
||||
# failed to lower torch.aten.empty.memory_format
|
||||
"BatchNorm1DModule_basic",
|
||||
"BatchNorm1DWith2DInputModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
"BatchNorm3DModule_basic",
|
||||
"BatchNorm1DStaticShapeModule_basic",
|
||||
|
||||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
|
||||
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
|
||||
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||
MAKE_FX_TOSA_PASS_SET -= {
|
||||
# 'tensor.expand_shape' op expected rank expansion, but found source rank 1 >= result rank 1
|
||||
"ReshapeCollapseModule_basic",
|
||||
}
|
||||
|
||||
LTC_CRASHING_SET = {
|
||||
# https://github.com/llvm/torch-mlir/issues/2186
|
||||
"Add_Module_basic"
|
||||
|
|
|
@ -13,6 +13,8 @@ import tempfile
|
|||
from torch._functorch.compile_utils import strip_overloads
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch_mlir.dynamo import _get_decomposition_table
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from .compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
|
@ -225,8 +227,11 @@ class ExampleArgs:
|
|||
# they know what they are doing and that their trace is
|
||||
# correct for any specific concrete size.
|
||||
shape = [s if s != -1 else 7 for s in arg.shape]
|
||||
example_args_for_trace.append(
|
||||
torch.ones(*shape, dtype=arg.dtype))
|
||||
if len(shape) == 0:
|
||||
example_args_for_trace.append(torch.tensor(1))
|
||||
else:
|
||||
example_args_for_trace.append(
|
||||
torch.ones(*shape, dtype=arg.dtype))
|
||||
else:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
example_args_for_trace.append(arg)
|
||||
|
@ -313,7 +318,8 @@ def compile(model: torch.nn.Module,
|
|||
ignore_traced_shapes=False,
|
||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||
extra_library: Iterable[Callable] = [],
|
||||
verbose: bool = False):
|
||||
verbose: bool = False,
|
||||
use_make_fx: bool = False):
|
||||
"""Convert a PyTorch model to MLIR.
|
||||
|
||||
Args:
|
||||
|
@ -367,6 +373,13 @@ def compile(model: torch.nn.Module,
|
|||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
if use_make_fx:
|
||||
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"]
|
||||
model = make_fx(
|
||||
model,
|
||||
decomposition_table=_get_decomposition_table())(*args)
|
||||
|
||||
|
||||
# For FX-based models, automatically strip overloads.
|
||||
if isinstance(model, torch.fx.GraphModule):
|
||||
strip_overloads(model)
|
||||
|
|
|
@ -23,14 +23,15 @@ class TosaBackendTestConfig(TestConfig):
|
|||
This class handles all the common lowering that torch-mlir does before
|
||||
reaching the linalg-on-tensors abstraction level.
|
||||
"""
|
||||
def __init__(self, backend: TosaBackend):
|
||||
def __init__(self, backend: TosaBackend, use_make_fx: bool = False):
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
self.use_make_fx = use_make_fx
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torch_mlir.compile(
|
||||
program, example_args, output_type="tosa")
|
||||
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
||||
|
|
Loading…
Reference in New Issue