Automatically strip overloads for FX-based models.

pull/1663/head
Daniel Ellis 2022-11-29 22:19:09 -05:00 committed by GitHub
parent a8cbfff95b
commit e2de20575f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 1 deletions

View File

@ -0,0 +1,22 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import functorch
import torch
import torch_mlir
def simple(x):
return x * x
example_input = torch.randn(1,)
graph = functorch.make_fx(simple)(torch.randn(1,))
# Simplest case: One example argument.
print(torch_mlir.compile(graph, example_input))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>

View File

@ -9,6 +9,7 @@ from enum import Enum
import sys
from io import StringIO
from functorch._src.compile_utils import strip_overloads
import torch
from torch_mlir.passmanager import PassManager
@ -300,6 +301,10 @@ def compile(model: torch.nn.Module,
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
# For FX-based models, automatically strip overloads.
if isinstance(model, torch.fx.GraphModule):
strip_overloads(model)
# Get the model as JIT IR (TorchScript) for import.
# TODO: Longer-term, we probably need to split `torch_mlir.compile`.
# There should be an "acquisition" step that does

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import functorch
import torch
from torch_mlir_e2e_test.framework import TestUtils
@ -3097,4 +3098,4 @@ class SortIntListReverse(torch.nn.Module):
@register_test_case(module_factory=lambda: SortIntListReverse())
def SortIntListReverse_basic(module, tu: TestUtils):
module.forward()
module.forward()