Add make_fx_tosa variant to end2end tests (#2240)

* Add make_fx_tosa variant to end2end tests

* e2e_testing/xfail_sets.py: Add make_fx_tosa xfail for stable
pull/2303/head
Matthias Gehre 2023-07-13 15:07:54 +02:00 committed by GitHub
parent 91c6454618
commit f8e75f659d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 7 deletions

View File

@ -300,7 +300,10 @@ function test_in_tree() {
exit 1
;;
esac
echo ":::: Run make_fx + TOSA e2e integration tests"
python -m e2e_testing.main --config=make_fx_tosa -v
echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v

View File

@ -29,6 +29,7 @@ from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsT
from .xfail_sets import (
LINALG_XFAIL_SET,
MAKE_FX_TOSA_PASS_SET,
STABLEHLO_PASS_SET,
TOSA_PASS_SET,
LTC_XFAIL_SET,
@ -42,7 +43,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
@ -94,6 +95,10 @@ def main():
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
crashing_set = set()
elif args.config == "make_fx_tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True)
xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET
crashing_set = set()
elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET

View File

@ -11,6 +11,7 @@
# might be used to keep more elaborate sets of testing configurations).
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
from torch_mlir._version import torch_version_for_comparison, version
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
@ -1113,6 +1114,41 @@ TOSA_PASS_SET = {
"ChunkListUnpackUneven_Module_basic",
}
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
### Tests additionally passing in make_fx_tosa
"NativeGroupNormBackwardModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa
# failed to lower torch.aten.empty.memory_format
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
}
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
MAKE_FX_TOSA_PASS_SET -= {
# 'tensor.expand_shape' op expected rank expansion, but found source rank 1 >= result rank 1
"ReshapeCollapseModule_basic",
}
LTC_CRASHING_SET = {
# https://github.com/llvm/torch-mlir/issues/2186
"Add_Module_basic"

View File

@ -13,6 +13,8 @@ import tempfile
from torch._functorch.compile_utils import strip_overloads
import torch
import torch.fx
from torch_mlir.dynamo import _get_decomposition_table
from torch.fx.experimental.proxy_tensor import make_fx
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
@ -225,8 +227,11 @@ class ExampleArgs:
# they know what they are doing and that their trace is
# correct for any specific concrete size.
shape = [s if s != -1 else 7 for s in arg.shape]
example_args_for_trace.append(
torch.ones(*shape, dtype=arg.dtype))
if len(shape) == 0:
example_args_for_trace.append(torch.tensor(1))
else:
example_args_for_trace.append(
torch.ones(*shape, dtype=arg.dtype))
else:
assert isinstance(arg, torch.Tensor)
example_args_for_trace.append(arg)
@ -313,7 +318,8 @@ def compile(model: torch.nn.Module,
ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library: Iterable[Callable] = [],
verbose: bool = False):
verbose: bool = False,
use_make_fx: bool = False):
"""Convert a PyTorch model to MLIR.
Args:
@ -367,6 +373,13 @@ def compile(model: torch.nn.Module,
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
if use_make_fx:
args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"]
model = make_fx(
model,
decomposition_table=_get_decomposition_table())(*args)
# For FX-based models, automatically strip overloads.
if isinstance(model, torch.fx.GraphModule):
strip_overloads(model)

View File

@ -23,14 +23,15 @@ class TosaBackendTestConfig(TestConfig):
This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level.
"""
def __init__(self, backend: TosaBackend):
def __init__(self, backend: TosaBackend, use_make_fx: bool = False):
super().__init__()
self.backend = backend
self.use_make_fx = use_make_fx
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(
program, example_args, output_type="tosa")
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
return self.backend.compile(module)