Rename torch_mlir.compile APIs and introduce FX based analogs (#2842)

Link to related RFC:
https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646
This commit updates the documentation, tests, CMake files, and API for
the changes proposed in the RFC. There is a new torch_mlir/fx.py for
user-level APIs related to importing modules; a corresponding test for
this path can be found at test/python/fx_importer/basic_test.py.
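
In short, the new FX-based path can be exercised like this (a minimal sketch mirroring the added basic_test.py):

```python
import torch
from torch_mlir import fx

class Basic(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)

# Import via torch.export and print the resulting torch dialect MLIR module.
m = fx.export_and_import(Basic(), torch.randn(3, 4))
print(m)
```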

---------

Co-authored-by: MaheshRavishankar <mravisha@amd.com>
saienduri 2024-02-06 19:07:59 -08:00 committed by GitHub
parent cc06391630
commit bfcf93ea21
31 changed files with 146 additions and 106 deletions


@@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow
 dynamic shapes for the trace. Additionally, the Python-level data structures and
 APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
 consider both of those as the same from the perspective of the responsibilities
-of the compiler. Both are accessed via the `torch_mlir.compile` Python API.
+of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API.
 ### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript


@@ -120,37 +120,50 @@ cmake --build build
 ### Linux and macOS
 ```shell
-export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
 ```
 ### Windows PowerShell
 ```shell
-$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples"
+$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
 ```
 ## Testing MLIR output in various dialects
-To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
+To test the MLIR output to the torch dialect, you can use `test/python/fx_importer/basic_test.py`.
 Make sure you have activated the virtualenv and set the `PYTHONPATH` above
 (if running on Windows, modify the environment variable as shown above):
 ```shell
 source mlir_venv/bin/activate
+export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
+python test/python/fx_importer/basic_test.py
+```
+This will display the basic example in TORCH dialect.
+To test the compiler's output to the different MLIR dialects, you can also use the deprecated path
+using TorchScript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
+This path doesn't give access to the current generation work that is being driven via the fx_importer
+and may lead to errors.
+Same as above, but with a different PYTHONPATH and example:
+```shell
 export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
 python projects/pt1/examples/torchscript_resnet18_all_output_types.py
 ```
 This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.
-The main functionality is on `torch_mlir.compile()`'s `output_type`.
+The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`.
 Ex:
 ```python
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
+module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
 ```
-Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
+`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
 ## Jupyter


@@ -51,6 +51,22 @@ the ecosystem are:
 Most of this document describes long-term ecosystem changes that will address
 these, drastically improving Torch-MLIR's ability to meet its goals.
+## Current API Paths
+Currently, there are two main API paths for the torch-mlir project:
+- The first path is part of the legacy project pt1 code
+  (torch_mlir.torchscript.compile). This allows users to test the compiler's
+  output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
+  `RAW` and `STABLEHLO`). This path is deprecated and doesn't give access to
+  the current generation work that is being driven via the fx_importer. It is
+  tied to the old TorchScript path.
+- The second path (torch_mlir.fx.export_and_import) allows users to import a
+  consolidated torch.export.ExportedProgram instance of an arbitrary Python
+  callable (an nn.Module, a function, or a method) and produce a torch dialect
+  MLIR module. This path is aligned with PyTorch's roadmap, but it is
+  not fully functional yet.
 ## Roadmap
 ### Refactoring the frontend
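
To make the two paths concrete, here is a minimal sketch based on the examples and the new fx.py elsewhere in this commit:

```python
import torch
from torch_mlir import torchscript  # path 1: legacy pt1 API (deprecated)
from torch_mlir import fx           # path 2: FX-based API

class TanhModule(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)

# Path 1: TorchScript-based compile; `output_type` selects the dialect mix.
module = torchscript.compile(TanhModule(), torch.ones(2, 3), output_type="torch")

# Path 2: torch.export-based import; emits a torch dialect MLIR module.
module_op = fx.export_and_import(TanhModule(), torch.ones(2, 3))
```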


@@ -14,7 +14,7 @@ import torch._dynamo as dynamo
 import torchvision.models as models
 from torchvision import transforms
-import torch_mlir
+from torch_mlir import torchscript
 from torch_mlir.dynamo import make_simple_dynamo_backend
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
@@ -71,7 +71,7 @@ labels = load_labels()
 @make_simple_dynamo_backend
 def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
                                    example_inputs: List[torch.Tensor]):
-    mlir_module = torch_mlir.compile(
+    mlir_module = torchscript.compile(
         fx_graph, example_inputs, output_type="linalg-on-tensors")
     backend = refbackend.RefBackendLinalgOnTensorsBackend()
     compiled = backend.compile(mlir_module)


@@ -12,7 +12,7 @@ import torch
 import torchvision.models as models
 from torchvision import transforms
-import torch_mlir
+from torch_mlir import torchscript
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
@@ -67,7 +67,7 @@ labels = load_labels()
 resnet18 = models.resnet18(pretrained=True)
 resnet18.train(False)
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
 backend = refbackend.RefBackendLinalgOnTensorsBackend()
 compiled = backend.compile(module)
 jit_module = backend.load(compiled)


@@ -6,15 +6,15 @@
 import torch
 import torchvision
-import torch_mlir
+from torch_mlir import torchscript
 resnet18 = torchvision.models.resnet18(pretrained=True)
 resnet18.eval()
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
 print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
 print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
 # TODO: Debug why this is so slow.
-module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
+module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
 print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))


@@ -184,7 +184,7 @@
 "\n",
 "# Compile the model with an example input.\n",
 "# We lower to the linalg-on-tensors form that the reference backend supports.\n",
-"compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
+"compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
 "# Load it on the reference backend.\n",
 "jit_module = compile_and_load_on_refbackend(compiled)\n",
 "# Run it!\n",
@@ -326,7 +326,7 @@
 "source": [
 "resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
 "resnet18.eval()\n",
-"compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
+"compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
 "jit_module = compile_and_load_on_refbackend(compiled)"
 ]
 },


@@ -1,13 +1,13 @@
 import torch
 import torchvision.models as models
-import torch_mlir
+from torch_mlir import torchscript
 model = models.resnet18(pretrained=True)
 model.eval()
 data = torch.randn(2,3,200,200)
 out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
-module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
+module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
 with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
     outf.write(str(module))


@@ -1,5 +1,5 @@
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 from transformers import BertForMaskedLM
@@ -17,7 +17,7 @@ model.eval()
 data = torch.randint(30522, (2, 128))
 out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
-module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
+module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
 with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
     outf.write(str(module))


@@ -18,7 +18,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
   ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
   SOURCES
-    __init__.py
+    torchscript.py
     _dynamo_fx_importer.py
     compiler_utils.py
     dynamo.py


@@ -6,7 +6,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class BasicModule(torch.nn.Module):
@@ -15,17 +15,17 @@ class BasicModule(torch.nn.Module):
         return torch.ops.aten.sin(x)
-example_args = torch_mlir.ExampleArgs()
+example_args = torchscript.ExampleArgs()
 example_args.add_method("sin", torch.ones(2, 3))
 scripted = torch.jit.script(BasicModule())
-print(torch_mlir.compile(scripted, example_args))
+print(torchscript.compile(scripted, example_args))
 # CHECK: module
 # CHECK-DAG: func.func @sin
 scripted = torch.jit.script(BasicModule())
 try:
     # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
-    torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
+    torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
 except Exception as e:
     print(e)


@@ -6,23 +6,23 @@
 # RUN: %PYTHON %s | FileCheck %s
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class BasicModule(torch.nn.Module):
     def forward(self, x):
         return torch.ops.aten.sin(x)
 example_arg = torch.ones(2, 3)
-example_args = torch_mlir.ExampleArgs.get(example_arg)
+example_args = torchscript.ExampleArgs.get(example_arg)
 traced = torch.jit.trace(BasicModule(), example_arg)
-print(torch_mlir.compile(traced, example_args))
+print(torchscript.compile(traced, example_args))
 # CHECK: module
 # CHECK-DAG: func.func @forward
 traced = torch.jit.trace(BasicModule(), example_arg)
 try:
     # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
-    torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
+    torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
 except Exception as e:
     print(e)


@@ -7,7 +7,7 @@
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class AddmmModule(torch.nn.Module):
     def __init__(self):
@@ -15,9 +15,9 @@ class AddmmModule(torch.nn.Module):
     def forward(self, x, y, z):
         return torch.ops.aten.addmm(x, y, z)
-example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
-print(torch_mlir.compile(AddmmModule(), example_args,
+example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
+print(torchscript.compile(AddmmModule(), example_args,
                          output_type="torch", backend_legal_ops=["aten.addmm"]))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.addmm


@@ -7,7 +7,7 @@
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class TanhModule(torch.nn.Module):
     def __init__(self):
@@ -18,24 +18,24 @@ class TanhModule(torch.nn.Module):
 tanh_example_input = torch.ones(2, 3)
 # Simplest case: One example argument.
-print(torch_mlir.compile(TanhModule(), tanh_example_input))
+print(torchscript.compile(TanhModule(), tanh_example_input))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
 # Use a TensorPlaceholder to represent dynamic axes.
-placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
-print(torch_mlir.compile(TanhModule(), placeholder))
+placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
+print(torchscript.compile(TanhModule(), placeholder))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
 # Explicitly construct a TensorPlaceholder.
-placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32)
-print(torch_mlir.compile(TanhModule(), placeholder))
+placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32)
+print(torchscript.compile(TanhModule(), placeholder))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
 # Basic smoke test for the raw output type.
-print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
+print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
 # CHECK: torch.nn_module {
 # CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
@@ -47,12 +47,12 @@ class MmModule(torch.nn.Module):
 # N > 1 inputs.
 mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
-print(torch_mlir.compile(MmModule(), mm_example_inputs))
+print(torchscript.compile(MmModule(), mm_example_inputs))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
 # Mixes Tensor's and TensorPlaceholder's.
-mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
-print(torch_mlir.compile(MmModule(), mm_dynamic_inputs))
+mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
+print(torchscript.compile(MmModule(), mm_dynamic_inputs))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>


@@ -8,7 +8,7 @@
 import functorch
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 def simple(x):
     return x * x
@@ -17,6 +17,6 @@ example_input = torch.randn(1,)
 graph = functorch.make_fx(simple)(torch.randn(1,))
 # Simplest case: One example argument.
-print(torch_mlir.compile(graph, example_input))
+print(torchscript.compile(graph, example_input))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>


@@ -6,7 +6,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class TwoMethodsModule(torch.nn.Module):
@@ -17,14 +17,14 @@ class TwoMethodsModule(torch.nn.Module):
         return torch.ops.aten.cos(x)
-example_args = torch_mlir.ExampleArgs()
+example_args = torchscript.ExampleArgs()
 example_args.add_method("sin", torch.ones(2, 3))
 example_args.add_method("cos", torch.ones(2, 4))
 # Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
 # check the `use_tracing` case first.
-print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
+print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True))
 # CHECK: module
 # CHECK-DAG: func.func @sin
 # CHECK-DAG: func.func @cos
@@ -34,8 +34,8 @@ print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
 # Otherwise the user would have to do this manually, which is tedious. This
 # technically mutates the user input model which is not great but probably okay
 # for this kind of API sugar. Users can always take full control of the process
-# by scripting the model themselves before passing it to `torch_mlir.compile`.
-print(torch_mlir.compile(TwoMethodsModule(), example_args))
+# by scripting the model themselves before passing it to `torchscript.compile`.
+print(torchscript.compile(TwoMethodsModule(), example_args))
 # CHECK: module
 # CHECK-DAG: func.func @sin
 # CHECK-DAG: func.func @cos


@@ -7,7 +7,7 @@
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class TanhModule(torch.nn.Module):
     def __init__(self):
@@ -17,9 +17,9 @@ class TanhModule(torch.nn.Module):
 tanh_example_input = torch.ones(2, 3)
-print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH))
+print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
-print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch"))
+print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>


@@ -7,7 +7,7 @@
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 class TanhModule(torch.nn.Module):
@@ -17,38 +17,38 @@ class TanhModule(torch.nn.Module):
 tanh_example_input = torch.ones(2, 3)
 # Simplest case: One example argument.
-print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True))
+print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
 # Simplest case: Passed as a tuple.
-print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
+print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
 # Simplest case: Passed as a list.
-print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True))
+print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
 # TensorPlaceholder support.
-placeholder = torch_mlir.TensorPlaceholder.like(
+placeholder = torchscript.TensorPlaceholder.like(
     tanh_example_input, dynamic_axes=[1])
-print(torch_mlir.compile(TanhModule(), [placeholder],
+print(torchscript.compile(TanhModule(), [placeholder],
                          use_tracing=True, ignore_traced_shapes=True))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
 try:
     # CHECK: `ignore_traced_shapes` requires `use_tracing`
-    torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
+    torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
 except Exception as e:
     print(e)
 try:
     # CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`
-    torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
+    torchscript.compile(TanhModule(), [placeholder], use_tracing=True)
 except Exception as e:
     print(e)
@@ -60,13 +60,13 @@ class DictModule(torch.nn.Module):
 try:
     # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
-    torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
+    torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
 except Exception as e:
     print(e)
 try:
     # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
-    torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
+    torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
 except Exception as e:
     print(e)


@@ -125,7 +125,7 @@ def make_simple_dynamo_backend(user_backend):
     Args:
         user_backend: A function with the signature used by ordinary
             TorchDynamo backends. But the torch.fx.GraphModule passed to it
-            will be normalized for consumption by `torch_mlir.compile`.
+            will be normalized for consumption by `torchscript.compile`.
     Returns:
         A function with the signature used by TorchDynamo backends.
     """


@@ -22,7 +22,7 @@ from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_li
 class OutputType(Enum):
-    """The kind of output that `torch_mlir.compile` can produce.
+    """The kind of output that `torchscript.compile` can produce.
     In MLIR terminology, this describes the mix of dialects that will be
     produced by the conversion process.
@@ -392,13 +392,13 @@ def compile(model: torch.nn.Module,
         strip_overloads(model)
     # Get the model as JIT IR (TorchScript) for import.
-    # TODO: Longer-term, we probably need to split `torch_mlir.compile`.
+    # TODO: Longer-term, we probably need to split `torchscript.compile`.
     # There should be an "acquisition" step that does
     # tracing/scripting/importing from FX/using torchdynamo.export/etc.
     # + any lowering to the backend contract. Then there should be a
     # "backend lowering" step that does the actual lowering to each
     # backend. This separation should be visible at the Python API level, and
-    # we can implement a deliberately simplified API like `torch_mlir.compile`
+    # we can implement a deliberately simplified API like `torchscript.compile`
     # on top of those building blocks.
     if isinstance(model, torch.jit.ScriptModule):
         # If the user already converted the model to JIT IR themselves, just


@@ -6,7 +6,7 @@
 from typing import Any
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
 from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@@ -30,7 +30,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
     def compile(self, program: torch.nn.Module) -> Any:
         example_args = convert_annotations_to_placeholders(program.forward)
-        module = torch_mlir.compile(
+        module = torchscript.compile(
             program, example_args, output_type="linalg-on-tensors")
         return self.backend.compile(module)


@@ -6,7 +6,7 @@
 from typing import Any
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
 from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@@ -30,7 +30,7 @@ class StablehloBackendTestConfig(TestConfig):
     def compile(self, program: torch.nn.Module) -> Any:
         example_args = convert_annotations_to_placeholders(program.forward)
-        module = torch_mlir.compile(program, example_args, output_type="stablehlo")
+        module = torchscript.compile(program, example_args, output_type="stablehlo")
         return self.backend.compile(module)


@@ -17,7 +17,7 @@ from torch._functorch.aot_autograd import (
 from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
 from torch_mlir.dynamo import _get_decomposition_table
-from torch_mlir import (
+from torch_mlir.torchscript import (
     _example_args,
     OutputType,
     BACKEND_LEGAL_OPS,


@@ -6,7 +6,7 @@
 from typing import Any
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
 from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
@@ -30,7 +30,7 @@ class TosaBackendTestConfig(TestConfig):
     def compile(self, program: torch.nn.Module) -> Any:
         example_args = convert_annotations_to_placeholders(program.forward)
-        module = torch_mlir.compile(
+        module = torchscript.compile(
             program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
         return self.backend.compile(module)


@@ -3,13 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
-from torch_mlir import TensorPlaceholder
+from torch_mlir.torchscript import TensorPlaceholder
 from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
 def convert_annotations_to_placeholders(forward_method):
     """Converts the annotations on a forward method into tensor placeholders.
-    These placeholders are suitable for being passed to `torch_mlir.compile`.
+    These placeholders are suitable for being passed to `torchscript.compile`.
     """
     annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
     placeholders = []


@@ -5,7 +5,7 @@ from typing import List, Tuple
 import torch
 import torch.multiprocessing as mp
 import torch.utils.cpp_extension
-import torch_mlir
+from torch_mlir import torchscript
 from torch_mlir_e2e_test.annotations import export, annotate_args
@@ -56,7 +56,7 @@ def run():
     mod = CustomOpExampleModule()
     mod.eval()
-    module = torch_mlir.compile(
+    module = torchscript.compile(
         mod,
         torch.ones(3, 4),
         output_type="torch",


@@ -1,5 +1,5 @@
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
@@ -39,6 +39,6 @@ class Model(torch.nn.Module):
         with torch.no_grad():
             return data
-output_type = torch_mlir.OutputType.RAW
-mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
+output_type = torchscript.OutputType.RAW
+mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
 print(mod)


@@ -39,6 +39,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
     extras/onnx_importer.py
 )
+declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
+  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
+  ADD_TO_PARENT TorchMLIRPythonSources
+  SOURCES
+    fx.py
+)
 declare_mlir_python_sources(TorchMLIRPythonSources.Tools
   ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT TorchMLIRPythonSources


@@ -0,0 +1,25 @@
+from typing import Optional
+
+import torch
+import torch.export
+import torch.nn as nn
+
+from torch_mlir.extras.fx_importer import FxImporter
+from torch_mlir import ir
+from torch_mlir.dialects import torch as torch_d
+
+
+def export_and_import(
+    f,
+    *args,
+    fx_importer: Optional[FxImporter] = None,
+    constraints: Optional[torch.export.Constraint] = None,
+    **kwargs,
+):
+    context = ir.Context()
+    torch_d.register_dialect(context)
+    if fx_importer is None:
+        fx_importer = FxImporter(context=context)
+    prog = torch.export.export(f, args, kwargs, constraints=constraints)
+    fx_importer.import_frozen_exported_program(prog)
+    return fx_importer.module_op


@@ -3,7 +3,7 @@
 import gc
 import sys
 import torch
-import torch_mlir
+from torch_mlir import torchscript
 def run_test(f):
@@ -26,7 +26,7 @@ class TinyModel(torch.nn.Module):
 # CHECK-LABEL: TEST: test_enable_ir_printing
 @run_test
 def test_enable_ir_printing():
-    torch_mlir.compile(TinyModel(),
+    torchscript.compile(TinyModel(),
                        torch.ones(1, 3, 20, 20),
                        output_type="linalg-on-tensors",
                        enable_ir_printing=True)


@@ -1,5 +1,3 @@
-# Copyright 2023 Advanced Micro Devices, Inc
-#
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -13,26 +11,7 @@ import torch
 import torch.export
 import torch.nn as nn
-from torch_mlir.extras.fx_importer import FxImporter
-from torch_mlir import ir
-from torch_mlir.dialects import torch as torch_d
-
-
-def export_and_import(
-    f,
-    *args,
-    fx_importer: Optional[FxImporter] = None,
-    constraints: Optional[torch.export.Constraint] = None,
-    **kwargs,
-):
-    context = ir.Context()
-    torch_d.register_dialect(context)
-    if fx_importer is None:
-        fx_importer = FxImporter(context=context)
-    prog = torch.export.export(f, args, kwargs, constraints=constraints)
-    fx_importer.import_frozen_exported_program(prog)
-    return fx_importer.module_op
+from torch_mlir import fx
 def run(f):
@@ -75,5 +54,5 @@ def test_import_frozen_exported_program():
     def forward(self, x):
         return torch.tanh(x) * get_a() * self.b * self.p
-m = export_and_import(Basic(), torch.randn(3, 4))
+m = fx.export_and_import(Basic(), torch.randn(3, 4))
 print(m)