Rename torch_mlir.compile APIs and introduce FX based analogs (#2842)

Link to related RFC:
https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646
This commit updates the documentation, tests, CMake files, and API for
the proposed changes in the RFC. There is a new torch_mlir/fx.py for
user level APIs related to importing modules and a corresponding test
for this path can be found at test/python/fx_importer/basic_test.py.

---------

Co-authored-by: MaheshRavishankar <mravisha@amd.com>
int_view_hack
saienduri 2024-02-06 19:07:59 -08:00 committed by GitHub
parent cc06391630
commit bfcf93ea21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 146 additions and 106 deletions

View File

@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow
dynamic shapes for the trace. Additionally, the Python-level data structures and
APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
consider both of those as the same from the perspective of the responsibilities
of the compiler. Both are accessed via the `torch_mlir.compile` Python API.
of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API.
### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript

View File

@ -120,37 +120,50 @@ cmake --build build
### Linux and macOS
```shell
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
```
### Windows PowerShell
```shell
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples"
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
```
## Testing MLIR output in various dialects
To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`.
Make sure you have activated the virtualenv and set the `PYTHONPATH` above
(if running on Windows, modify the environment variable as shown above):
```shell
source mlir_venv/bin/activate
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
python test/python/fx_importer/basic_test.py
```
This will display the basic example in TORCH dialect.
To test the compiler's output to the different MLIR dialects, you can also use the deprecated path
using torchscript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
This path doesn't give access to the current generation work that is being driven via the fx_importer
and may lead to errors.
Same as above, but with different python path and example:
```shell
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
```
This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.
The main functionality is on `torch_mlir.compile()`'s `output_type`.
The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`.
Ex:
```python
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
```
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
## Jupyter

View File

@ -51,6 +51,22 @@ the ecosystem are:
Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.
## Current API Paths
Currently, there are two main API paths for the torch-mlir project:
- The first path is part of the legacy project pt1 code
(torch_mlir.torchscript.compile). This allows users to test the compiler's
output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
`RAW` and `STABLEHLO`). This path is deprecated and doesnt give access to
the current generation work that is being driven via the fx_importer. It is
tied to the old Torchscript path.
- The second path (torch_mlir.fx.export_and_import) allows users to import a
consolidated torch.export.ExportedProgram instance of an arbitrary Python
callable (an nn.Module, a function or a method) and output to torch dialect
mlir module. This path is aligned with PyTorch's roadmap, but the path is
not fully functional yet.
## Roadmap
### Refactoring the frontend

View File

@ -14,7 +14,7 @@ import torch._dynamo as dynamo
import torchvision.models as models
from torchvision import transforms
import torch_mlir
from torch_mlir import torchscript
from torch_mlir.dynamo import make_simple_dynamo_backend
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
@ -71,7 +71,7 @@ labels = load_labels()
@make_simple_dynamo_backend
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
mlir_module = torch_mlir.compile(
mlir_module = torchscript.compile(
fx_graph, example_inputs, output_type="linalg-on-tensors")
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)

View File

@ -12,7 +12,7 @@ import torch
import torchvision.models as models
from torchvision import transforms
import torch_mlir
from torch_mlir import torchscript
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
@ -67,7 +67,7 @@ labels = load_labels()
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

View File

@ -6,15 +6,15 @@
import torch
import torchvision
import torch_mlir
from torch_mlir import torchscript
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18.eval()
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
# TODO: Debug why this is so slow.
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))

View File

@ -184,7 +184,7 @@
"\n",
"# Compile the model with an example input.\n",
"# We lower to the linalg-on-tensors form that the reference backend supports.\n",
"compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
"compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
"# Load it on the reference backend.\n",
"jit_module = compile_and_load_on_refbackend(compiled)\n",
"# Run it!\n",
@ -326,7 +326,7 @@
"source": [
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
"resnet18.eval()\n",
"compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
"compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
"jit_module = compile_and_load_on_refbackend(compiled)"
]
},

View File

@ -1,13 +1,13 @@
import torch
import torchvision.models as models
import torch_mlir
from torch_mlir import torchscript
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

View File

@ -1,5 +1,5 @@
import torch
import torch_mlir
from torch_mlir import torchscript
from transformers import BertForMaskedLM
@ -17,7 +17,7 @@ model.eval()
data = torch.randint(30522, (2, 128))
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

View File

@ -18,7 +18,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
SOURCES
__init__.py
torchscript.py
_dynamo_fx_importer.py
compiler_utils.py
dynamo.py

View File

@ -6,7 +6,7 @@
# RUN: %PYTHON %s | FileCheck %s
import torch
import torch_mlir
from torch_mlir import torchscript
class BasicModule(torch.nn.Module):
@ -15,17 +15,17 @@ class BasicModule(torch.nn.Module):
return torch.ops.aten.sin(x)
example_args = torch_mlir.ExampleArgs()
example_args = torchscript.ExampleArgs()
example_args.add_method("sin", torch.ones(2, 3))
scripted = torch.jit.script(BasicModule())
print(torch_mlir.compile(scripted, example_args))
print(torchscript.compile(scripted, example_args))
# CHECK: module
# CHECK-DAG: func.func @sin
scripted = torch.jit.script(BasicModule())
try:
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
except Exception as e:
print(e)

View File

@ -6,23 +6,23 @@
# RUN: %PYTHON %s | FileCheck %s
import torch
import torch_mlir
from torch_mlir import torchscript
class BasicModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.sin(x)
example_arg = torch.ones(2, 3)
example_args = torch_mlir.ExampleArgs.get(example_arg)
example_args = torchscript.ExampleArgs.get(example_arg)
traced = torch.jit.trace(BasicModule(), example_arg)
print(torch_mlir.compile(traced, example_args))
print(torchscript.compile(traced, example_args))
# CHECK: module
# CHECK-DAG: func.func @forward
traced = torch.jit.trace(BasicModule(), example_arg)
try:
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
except Exception as e:
print(e)

View File

@ -7,7 +7,7 @@
import torch
import torch_mlir
from torch_mlir import torchscript
class AddmmModule(torch.nn.Module):
def __init__(self):
@ -15,9 +15,9 @@ class AddmmModule(torch.nn.Module):
def forward(self, x, y, z):
return torch.ops.aten.addmm(x, y, z)
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
print(torch_mlir.compile(AddmmModule(), example_args,
print(torchscript.compile(AddmmModule(), example_args,
output_type="torch", backend_legal_ops=["aten.addmm"]))
# CHECK-LABEL: @forward
# CHECK: torch.aten.addmm

View File

@ -7,7 +7,7 @@
import torch
import torch_mlir
from torch_mlir import torchscript
class TanhModule(torch.nn.Module):
def __init__(self):
@ -18,24 +18,24 @@ class TanhModule(torch.nn.Module):
tanh_example_input = torch.ones(2, 3)
# Simplest case: One example argument.
print(torch_mlir.compile(TanhModule(), tanh_example_input))
print(torchscript.compile(TanhModule(), tanh_example_input))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
# Use a TensorPlaceholder to represent dynamic axes.
placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
print(torch_mlir.compile(TanhModule(), placeholder))
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
print(torchscript.compile(TanhModule(), placeholder))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
# Explicitly construct a TensorPlaceholder.
placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32)
print(torch_mlir.compile(TanhModule(), placeholder))
placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32)
print(torchscript.compile(TanhModule(), placeholder))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
# Basic smoke test for the raw output type.
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
# CHECK: torch.nn_module {
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
@ -47,12 +47,12 @@ class MmModule(torch.nn.Module):
# N > 1 inputs.
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
print(torch_mlir.compile(MmModule(), mm_example_inputs))
print(torchscript.compile(MmModule(), mm_example_inputs))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
# Mixes Tensor's and TensorPlaceholder's.
mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
print(torch_mlir.compile(MmModule(), mm_dynamic_inputs))
mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>

View File

@ -8,7 +8,7 @@
import functorch
import torch
import torch_mlir
from torch_mlir import torchscript
def simple(x):
return x * x
@ -17,6 +17,6 @@ example_input = torch.randn(1,)
graph = functorch.make_fx(simple)(torch.randn(1,))
# Simplest case: One example argument.
print(torch_mlir.compile(graph, example_input))
print(torchscript.compile(graph, example_input))
# CHECK-LABEL: @forward
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>

View File

@ -6,7 +6,7 @@
# RUN: %PYTHON %s | FileCheck %s
import torch
import torch_mlir
from torch_mlir import torchscript
class TwoMethodsModule(torch.nn.Module):
@ -17,14 +17,14 @@ class TwoMethodsModule(torch.nn.Module):
return torch.ops.aten.cos(x)
example_args = torch_mlir.ExampleArgs()
example_args = torchscript.ExampleArgs()
example_args.add_method("sin", torch.ones(2, 3))
example_args.add_method("cos", torch.ones(2, 4))
# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
# check the `use_tracing` case first.
print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True))
# CHECK: module
# CHECK-DAG: func.func @sin
# CHECK-DAG: func.func @cos
@ -34,8 +34,8 @@ print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
# Otherwise the user would have to do this manually, which is tedious. This
# technically mutates the user input model which is not great but probably okay
# for this kind of API sugar. Users can always take full control of the process
# by scripting the model themselves before passing it to `torch_mlir.compile`.
print(torch_mlir.compile(TwoMethodsModule(), example_args))
# by scripting the model themselves before passing it to `torchscript.compile`.
print(torchscript.compile(TwoMethodsModule(), example_args))
# CHECK: module
# CHECK-DAG: func.func @sin
# CHECK-DAG: func.func @cos

View File

@ -7,7 +7,7 @@
import torch
import torch_mlir
from torch_mlir import torchscript
class TanhModule(torch.nn.Module):
def __init__(self):
@ -17,9 +17,9 @@ class TanhModule(torch.nn.Module):
tanh_example_input = torch.ones(2, 3)
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH))
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch"))
print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>

View File

@ -7,7 +7,7 @@
import torch
import torch_mlir
from torch_mlir import torchscript
class TanhModule(torch.nn.Module):
@ -17,38 +17,38 @@ class TanhModule(torch.nn.Module):
tanh_example_input = torch.ones(2, 3)
# Simplest case: One example argument.
print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True))
print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
# Simplest case: Passed as a tuple.
print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
# Simplest case: Passed as a list.
print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True))
print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
# TensorPlaceholder support.
placeholder = torch_mlir.TensorPlaceholder.like(
placeholder = torchscript.TensorPlaceholder.like(
tanh_example_input, dynamic_axes=[1])
print(torch_mlir.compile(TanhModule(), [placeholder],
print(torchscript.compile(TanhModule(), [placeholder],
use_tracing=True, ignore_traced_shapes=True))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
try:
# CHECK: `ignore_traced_shapes` requires `use_tracing`
torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
except Exception as e:
print(e)
try:
# CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`
torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
torchscript.compile(TanhModule(), [placeholder], use_tracing=True)
except Exception as e:
print(e)
@ -60,13 +60,13 @@ class DictModule(torch.nn.Module):
try:
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
except Exception as e:
print(e)
try:
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
except Exception as e:
print(e)

View File

@ -125,7 +125,7 @@ def make_simple_dynamo_backend(user_backend):
Args:
user_backend: A function with the signature used by ordinary
TorchDynamo backends. But the torch.fx.GraphModule passed to it
will be normalized for consumption by `torch_mlir.compile`.
will be normalized for consumption by `torchscript.compile`.
Returns:
A function with the signature used by TorchDynamo backends.
"""

View File

@ -22,7 +22,7 @@ from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_li
class OutputType(Enum):
"""The kind of output that `torch_mlir.compile` can produce.
"""The kind of output that `torchscript.compile` can produce.
In MLIR terminology, this describes the mix of dialects that will be
produced by the conversion process.
@ -392,13 +392,13 @@ def compile(model: torch.nn.Module,
strip_overloads(model)
# Get the model as JIT IR (TorchScript) for import.
# TODO: Longer-term, we probably need to split `torch_mlir.compile`.
# TODO: Longer-term, we probably need to split `torchscript.compile`.
# There should be an "acquisition" step that does
# tracing/scripting/importing from FX/using torchdynamo.export/etc.
# + any lowering to the backend contract. Then there should be a
# "backend lowering" step that does the actual lowering to each
# backend. This separation should be visible at the Python API level, and
# we can implement a deliberately simplified API like `torch_mlir.compile`
# we can implement a deliberately simplified API like `torchscript.compile`
# on top of those building blocks.
if isinstance(model, torch.jit.ScriptModule):
# If the user already converted the model to JIT IR themselves, just

View File

@ -6,7 +6,7 @@
from typing import Any
import torch
import torch_mlir
from torch_mlir import torchscript
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@ -30,7 +30,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(
module = torchscript.compile(
program, example_args, output_type="linalg-on-tensors")
return self.backend.compile(module)

View File

@ -6,7 +6,7 @@
from typing import Any
import torch
import torch_mlir
from torch_mlir import torchscript
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@ -30,7 +30,7 @@ class StablehloBackendTestConfig(TestConfig):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(program, example_args, output_type="stablehlo")
module = torchscript.compile(program, example_args, output_type="stablehlo")
return self.backend.compile(module)

View File

@ -17,7 +17,7 @@ from torch._functorch.aot_autograd import (
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir.dynamo import _get_decomposition_table
from torch_mlir import (
from torch_mlir.torchscript import (
_example_args,
OutputType,
BACKEND_LEGAL_OPS,

View File

@ -6,7 +6,7 @@
from typing import Any
import torch
import torch_mlir
from torch_mlir import torchscript
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@ -30,7 +30,7 @@ class TosaBackendTestConfig(TestConfig):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torch_mlir.compile(
module = torchscript.compile(
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
return self.backend.compile(module)

View File

@ -3,13 +3,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from torch_mlir import TensorPlaceholder
from torch_mlir.torchscript import TensorPlaceholder
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
def convert_annotations_to_placeholders(forward_method):
"""Converts the annotations on a forward method into tensor placeholders.
These placeholders are suitable for being passed to `torch_mlir.compile`.
These placeholders are suitable for being passed to `torchscript.compile`.
"""
annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
placeholders = []

View File

@ -5,7 +5,7 @@ from typing import List, Tuple
import torch
import torch.multiprocessing as mp
import torch.utils.cpp_extension
import torch_mlir
from torch_mlir import torchscript
from torch_mlir_e2e_test.annotations import export, annotate_args
@ -56,7 +56,7 @@ def run():
mod = CustomOpExampleModule()
mod.eval()
module = torch_mlir.compile(
module = torchscript.compile(
mod,
torch.ones(3, 4),
output_type="torch",

View File

@ -1,5 +1,5 @@
import torch
import torch_mlir
from torch_mlir import torchscript
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
@ -39,6 +39,6 @@ class Model(torch.nn.Module):
with torch.no_grad():
return data
output_type = torch_mlir.OutputType.RAW
mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
output_type = torchscript.OutputType.RAW
mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
print(mod)

View File

@ -39,6 +39,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
extras/onnx_importer.py
)
declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
fx.py
)
declare_mlir_python_sources(TorchMLIRPythonSources.Tools
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources

View File

@ -0,0 +1,25 @@
from typing import Optional
import torch
import torch.export
import torch.nn as nn
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
def export_and_import(
f,
*args,
fx_importer: Optional[FxImporter] = None,
constraints: Optional[torch.export.Constraint] = None,
**kwargs,
):
context = ir.Context()
torch_d.register_dialect(context)
if fx_importer is None:
fx_importer = FxImporter(context=context)
prog = torch.export.export(f, args, kwargs, constraints=constraints)
fx_importer.import_frozen_exported_program(prog)
return fx_importer.module_op

View File

@ -3,7 +3,7 @@
import gc
import sys
import torch
import torch_mlir
from torch_mlir import torchscript
def run_test(f):
@ -26,7 +26,7 @@ class TinyModel(torch.nn.Module):
# CHECK-LABEL: TEST: test_enable_ir_printing
@run_test
def test_enable_ir_printing():
torch_mlir.compile(TinyModel(),
torchscript.compile(TinyModel(),
torch.ones(1, 3, 20, 20),
output_type="linalg-on-tensors",
enable_ir_printing=True)

View File

@ -1,5 +1,3 @@
# Copyright 2023 Advanced Micro Devices, Inc
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@ -13,26 +11,7 @@ import torch
import torch.export
import torch.nn as nn
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
def export_and_import(
f,
*args,
fx_importer: Optional[FxImporter] = None,
constraints: Optional[torch.export.Constraint] = None,
**kwargs,
):
context = ir.Context()
torch_d.register_dialect(context)
if fx_importer is None:
fx_importer = FxImporter(context=context)
prog = torch.export.export(f, args, kwargs, constraints=constraints)
fx_importer.import_frozen_exported_program(prog)
return fx_importer.module_op
from torch_mlir import fx
def run(f):
@ -75,5 +54,5 @@ def test_import_frozen_exported_program():
def forward(self, x):
return torch.tanh(x) * get_a() * self.b * self.p
m = export_and_import(Basic(), torch.randn(3, 4))
m = fx.export_and_import(Basic(), torch.randn(3, 4))
print(m)