mirror of https://github.com/llvm/torch-mlir
Automatically strip overloads for FX-based models.
parent
a8cbfff95b
commit
e2de20575f
|
@ -0,0 +1,22 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import functorch
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
|
||||
def simple(x):
|
||||
return x * x
|
||||
|
||||
example_input = torch.randn(1,)
|
||||
graph = functorch.make_fx(simple)(torch.randn(1,))
|
||||
|
||||
# Simplest case: One example argument.
|
||||
print(torch_mlir.compile(graph, example_input))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
|
|
@ -9,6 +9,7 @@ from enum import Enum
|
|||
import sys
|
||||
from io import StringIO
|
||||
|
||||
from functorch._src.compile_utils import strip_overloads
|
||||
import torch
|
||||
|
||||
from torch_mlir.passmanager import PassManager
|
||||
|
@ -300,6 +301,10 @@ def compile(model: torch.nn.Module,
|
|||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
# For FX-based models, automatically strip overloads.
|
||||
if isinstance(model, torch.fx.GraphModule):
|
||||
strip_overloads(model)
|
||||
|
||||
# Get the model as JIT IR (TorchScript) for import.
|
||||
# TODO: Longer-term, we probably need to split `torch_mlir.compile`.
|
||||
# There should be an "acquisition" step that does
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import functorch
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
|
@ -3097,4 +3098,4 @@ class SortIntListReverse(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SortIntListReverse())
|
||||
def SortIntListReverse_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
module.forward()
|
Loading…
Reference in New Issue