From e2de20575f92264f1131039a4d1393833a893f4c Mon Sep 17 00:00:00 2001
From: Daniel Ellis <1346302+dellis23@users.noreply.github.com>
Date: Tue, 29 Nov 2022 22:19:09 -0500
Subject: [PATCH] Automatically strip overloads for FX-based models.

---
 python/test/compile_api/make_fx.py             | 22 +++++++++++++++++++
 python/torch_mlir/__init__.py                  |  5 +++++
 .../torch_mlir_e2e_test/test_suite/basic.py    |  3 ++-
 3 files changed, 29 insertions(+), 1 deletion(-)
 create mode 100644 python/test/compile_api/make_fx.py

diff --git a/python/test/compile_api/make_fx.py b/python/test/compile_api/make_fx.py
new file mode 100644
index 000000000..62add20a5
--- /dev/null
+++ b/python/test/compile_api/make_fx.py
@@ -0,0 +1,22 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import functorch
+import torch
+
+import torch_mlir
+
+def simple(x):
+    return x * x
+
+example_input = torch.randn(1,)
+graph = functorch.make_fx(simple)(torch.randn(1,))
+
+# Simplest case: One example argument.
+print(torch_mlir.compile(graph, example_input))
+# CHECK-LABEL: @forward
+# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
\ No newline at end of file
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index 9b135ad24..ff80d7d9b 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -9,6 +9,7 @@ from enum import Enum
 import sys
 from io import StringIO
 
+from functorch._src.compile_utils import strip_overloads
 import torch
 from torch_mlir.passmanager import PassManager
 
@@ -300,6 +301,10 @@ def compile(model: torch.nn.Module,
     else:
         backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
 
+    # For FX-based models, automatically strip overloads.
+    if isinstance(model, torch.fx.GraphModule):
+        strip_overloads(model)
+
     # Get the model as JIT IR (TorchScript) for import.
     # TODO: Longer-term, we probably need to split `torch_mlir.compile`.
     # There should be an "acquisition" step that does
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 3ced7b46b..9e04ab2d5 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
+import functorch
 import torch
 
 from torch_mlir_e2e_test.framework import TestUtils
@@ -3097,4 +3098,4 @@ class SortIntListReverse(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: SortIntListReverse())
 def SortIntListReverse_basic(module, tu: TestUtils):
-    module.forward()
+    module.forward()
\ No newline at end of file