mirror of https://github.com/llvm/torch-mlir
Rename torch_mlir.compile APIs and introduce FX based analogs (#2842)
Link to related RFC: https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646 This commit updates the documentation, tests, CMake files, and API for the proposed changes in the RFC. There is a new torch_mlir/fx.py for user level APIs related to importing modules and a corresponding test for this path can be found at test/python/fx_importer/basic_test.py. --------- Co-authored-by: MaheshRavishankar <mravisha@amd.com>int_view_hack
parent
cc06391630
commit
bfcf93ea21
|
@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow
|
|||
dynamic shapes for the trace. Additionally, the Python-level data structures and
|
||||
APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
|
||||
consider both of those as the same from the perspective of the responsibilities
|
||||
of the compiler. Both are accessed via the `torch_mlir.compile` Python API.
|
||||
of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API.
|
||||
|
||||
### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript
|
||||
|
||||
|
|
|
@ -120,37 +120,50 @@ cmake --build build
|
|||
### Linux and macOS
|
||||
|
||||
```shell
|
||||
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
|
||||
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
|
||||
```
|
||||
|
||||
### Windows PowerShell
|
||||
|
||||
```shell
|
||||
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples"
|
||||
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
|
||||
```
|
||||
|
||||
## Testing MLIR output in various dialects
|
||||
|
||||
To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
|
||||
To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`.
|
||||
|
||||
Make sure you have activated the virtualenv and set the `PYTHONPATH` above
|
||||
(if running on Windows, modify the environment variable as shown above):
|
||||
```shell
|
||||
source mlir_venv/bin/activate
|
||||
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
|
||||
python test/python/fx_importer/basic_test.py
|
||||
```
|
||||
|
||||
This will display the basic example in TORCH dialect.
|
||||
|
||||
To test the compiler's output to the different MLIR dialects, you can also use the deprecated path
|
||||
using torchscript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
|
||||
This path doesn't give access to the current generation work that is being driven via the fx_importer
|
||||
and may lead to errors.
|
||||
|
||||
Same as above, but with different python path and example:
|
||||
```shell
|
||||
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
|
||||
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
|
||||
```
|
||||
|
||||
This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.
|
||||
|
||||
The main functionality is on `torch_mlir.compile()`'s `output_type`.
|
||||
The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`.
|
||||
|
||||
Ex:
|
||||
```python
|
||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
```
|
||||
|
||||
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
|
||||
`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
|
||||
|
||||
## Jupyter
|
||||
|
||||
|
|
|
@ -51,6 +51,22 @@ the ecosystem are:
|
|||
Most of this document describes long-term ecosystem changes that will address
|
||||
these, drastically improving Torch-MLIR's ability to meet its goals.
|
||||
|
||||
## Current API Paths
|
||||
|
||||
Currently, there are two main API paths for the torch-mlir project:
|
||||
|
||||
- The first path is part of the legacy project pt1 code
|
||||
(torch_mlir.torchscript.compile). This allows users to test the compiler's
|
||||
output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
|
||||
`RAW` and `STABLEHLO`). This path is deprecated and doesn’t give access to
|
||||
the current generation work that is being driven via the fx_importer. It is
|
||||
tied to the old Torchscript path.
|
||||
- The second path (torch_mlir.fx.export_and_import) allows users to import a
|
||||
consolidated torch.export.ExportedProgram instance of an arbitrary Python
|
||||
callable (an nn.Module, a function or a method) and output to torch dialect
|
||||
mlir module. This path is aligned with PyTorch's roadmap, but the path is
|
||||
not fully functional yet.
|
||||
|
||||
## Roadmap
|
||||
|
||||
### Refactoring the frontend
|
|
@ -14,7 +14,7 @@ import torch._dynamo as dynamo
|
|||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
@ -71,7 +71,7 @@ labels = load_labels()
|
|||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
mlir_module = torch_mlir.compile(
|
||||
mlir_module = torchscript.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors")
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(mlir_module)
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
|
@ -67,7 +67,7 @@ labels = load_labels()
|
|||
|
||||
resnet18 = models.resnet18(pretrained=True)
|
||||
resnet18.train(False)
|
||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(module)
|
||||
jit_module = backend.load(compiled)
|
||||
|
|
|
@ -6,15 +6,15 @@
|
|||
import torch
|
||||
import torchvision
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
resnet18 = torchvision.models.resnet18(pretrained=True)
|
||||
resnet18.eval()
|
||||
|
||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||
# TODO: Debug why this is so slow.
|
||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
||||
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
||||
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||
|
|
|
@ -184,7 +184,7 @@
|
|||
"\n",
|
||||
"# Compile the model with an example input.\n",
|
||||
"# We lower to the linalg-on-tensors form that the reference backend supports.\n",
|
||||
"compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
|
||||
"compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
|
||||
"# Load it on the reference backend.\n",
|
||||
"jit_module = compile_and_load_on_refbackend(compiled)\n",
|
||||
"# Run it!\n",
|
||||
|
@ -326,7 +326,7 @@
|
|||
"source": [
|
||||
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
|
||||
"resnet18.eval()\n",
|
||||
"compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
|
||||
"compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
|
||||
"jit_module = compile_and_load_on_refbackend(compiled)"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.eval()
|
||||
data = torch.randn(2,3,200,200)
|
||||
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
|
||||
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
|
@ -17,7 +17,7 @@ model.eval()
|
|||
data = torch.randint(30522, (2, 128))
|
||||
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
|
||||
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
|||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
|
||||
SOURCES
|
||||
__init__.py
|
||||
torchscript.py
|
||||
_dynamo_fx_importer.py
|
||||
compiler_utils.py
|
||||
dynamo.py
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class BasicModule(torch.nn.Module):
|
||||
|
@ -15,17 +15,17 @@ class BasicModule(torch.nn.Module):
|
|||
return torch.ops.aten.sin(x)
|
||||
|
||||
|
||||
example_args = torch_mlir.ExampleArgs()
|
||||
example_args = torchscript.ExampleArgs()
|
||||
example_args.add_method("sin", torch.ones(2, 3))
|
||||
|
||||
scripted = torch.jit.script(BasicModule())
|
||||
print(torch_mlir.compile(scripted, example_args))
|
||||
print(torchscript.compile(scripted, example_args))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @sin
|
||||
|
||||
scripted = torch.jit.script(BasicModule())
|
||||
try:
|
||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||
torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
|
||||
torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
@ -6,23 +6,23 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
class BasicModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.sin(x)
|
||||
|
||||
example_arg = torch.ones(2, 3)
|
||||
example_args = torch_mlir.ExampleArgs.get(example_arg)
|
||||
example_args = torchscript.ExampleArgs.get(example_arg)
|
||||
|
||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||
print(torch_mlir.compile(traced, example_args))
|
||||
print(torchscript.compile(traced, example_args))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @forward
|
||||
|
||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||
try:
|
||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||
torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
|
||||
torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
class AddmmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -15,9 +15,9 @@ class AddmmModule(torch.nn.Module):
|
|||
def forward(self, x, y, z):
|
||||
return torch.ops.aten.addmm(x, y, z)
|
||||
|
||||
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
|
||||
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
|
||||
|
||||
print(torch_mlir.compile(AddmmModule(), example_args,
|
||||
print(torchscript.compile(AddmmModule(), example_args,
|
||||
output_type="torch", backend_legal_ops=["aten.addmm"]))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.addmm
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -18,24 +18,24 @@ class TanhModule(torch.nn.Module):
|
|||
tanh_example_input = torch.ones(2, 3)
|
||||
|
||||
# Simplest case: One example argument.
|
||||
print(torch_mlir.compile(TanhModule(), tanh_example_input))
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
|
||||
# Use a TensorPlaceholder to represent dynamic axes.
|
||||
placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
|
||||
print(torch_mlir.compile(TanhModule(), placeholder))
|
||||
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
|
||||
print(torchscript.compile(TanhModule(), placeholder))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||
|
||||
# Explicitly construct a TensorPlaceholder.
|
||||
placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32)
|
||||
print(torch_mlir.compile(TanhModule(), placeholder))
|
||||
placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32)
|
||||
print(torchscript.compile(TanhModule(), placeholder))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
||||
|
||||
# Basic smoke test for the raw output type.
|
||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
|
||||
|
||||
|
@ -47,12 +47,12 @@ class MmModule(torch.nn.Module):
|
|||
|
||||
# N > 1 inputs.
|
||||
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
|
||||
print(torch_mlir.compile(MmModule(), mm_example_inputs))
|
||||
print(torchscript.compile(MmModule(), mm_example_inputs))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
|
||||
|
||||
# Mixes Tensor's and TensorPlaceholder's.
|
||||
mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
|
||||
print(torch_mlir.compile(MmModule(), mm_dynamic_inputs))
|
||||
mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
|
||||
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
import functorch
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
def simple(x):
|
||||
return x * x
|
||||
|
@ -17,6 +17,6 @@ example_input = torch.randn(1,)
|
|||
graph = functorch.make_fx(simple)(torch.randn(1,))
|
||||
|
||||
# Simplest case: One example argument.
|
||||
print(torch_mlir.compile(graph, example_input))
|
||||
print(torchscript.compile(graph, example_input))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
|
|
@ -6,7 +6,7 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class TwoMethodsModule(torch.nn.Module):
|
||||
|
@ -17,14 +17,14 @@ class TwoMethodsModule(torch.nn.Module):
|
|||
return torch.ops.aten.cos(x)
|
||||
|
||||
|
||||
example_args = torch_mlir.ExampleArgs()
|
||||
example_args = torchscript.ExampleArgs()
|
||||
example_args.add_method("sin", torch.ones(2, 3))
|
||||
example_args.add_method("cos", torch.ones(2, 4))
|
||||
|
||||
# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
|
||||
# check the `use_tracing` case first.
|
||||
|
||||
print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
||||
print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @sin
|
||||
# CHECK-DAG: func.func @cos
|
||||
|
@ -34,8 +34,8 @@ print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
|||
# Otherwise the user would have to do this manually, which is tedious. This
|
||||
# technically mutates the user input model which is not great but probably okay
|
||||
# for this kind of API sugar. Users can always take full control of the process
|
||||
# by scripting the model themselves before passing it to `torch_mlir.compile`.
|
||||
print(torch_mlir.compile(TwoMethodsModule(), example_args))
|
||||
# by scripting the model themselves before passing it to `torchscript.compile`.
|
||||
print(torchscript.compile(TwoMethodsModule(), example_args))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @sin
|
||||
# CHECK-DAG: func.func @cos
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -17,9 +17,9 @@ class TanhModule(torch.nn.Module):
|
|||
|
||||
tanh_example_input = torch.ones(2, 3)
|
||||
|
||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH))
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
|
@ -17,38 +17,38 @@ class TanhModule(torch.nn.Module):
|
|||
tanh_example_input = torch.ones(2, 3)
|
||||
|
||||
# Simplest case: One example argument.
|
||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True))
|
||||
print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
|
||||
# Simplest case: Passed as a tuple.
|
||||
print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
|
||||
print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
|
||||
# Simplest case: Passed as a list.
|
||||
print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True))
|
||||
print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||
|
||||
# TensorPlaceholder support.
|
||||
placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
placeholder = torchscript.TensorPlaceholder.like(
|
||||
tanh_example_input, dynamic_axes=[1])
|
||||
print(torch_mlir.compile(TanhModule(), [placeholder],
|
||||
print(torchscript.compile(TanhModule(), [placeholder],
|
||||
use_tracing=True, ignore_traced_shapes=True))
|
||||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||
|
||||
try:
|
||||
# CHECK: `ignore_traced_shapes` requires `use_tracing`
|
||||
torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
|
||||
torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`
|
||||
torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
|
||||
torchscript.compile(TanhModule(), [placeholder], use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
@ -60,13 +60,13 @@ class DictModule(torch.nn.Module):
|
|||
|
||||
try:
|
||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||
torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
||||
torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||
torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
||||
torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
|
|
@ -125,7 +125,7 @@ def make_simple_dynamo_backend(user_backend):
|
|||
Args:
|
||||
user_backend: A function with the signature used by ordinary
|
||||
TorchDynamo backends. But the torch.fx.GraphModule passed to it
|
||||
will be normalized for consumption by `torch_mlir.compile`.
|
||||
will be normalized for consumption by `torchscript.compile`.
|
||||
Returns:
|
||||
A function with the signature used by TorchDynamo backends.
|
||||
"""
|
||||
|
|
|
@ -22,7 +22,7 @@ from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_li
|
|||
|
||||
|
||||
class OutputType(Enum):
|
||||
"""The kind of output that `torch_mlir.compile` can produce.
|
||||
"""The kind of output that `torchscript.compile` can produce.
|
||||
|
||||
In MLIR terminology, this describes the mix of dialects that will be
|
||||
produced by the conversion process.
|
||||
|
@ -392,13 +392,13 @@ def compile(model: torch.nn.Module,
|
|||
strip_overloads(model)
|
||||
|
||||
# Get the model as JIT IR (TorchScript) for import.
|
||||
# TODO: Longer-term, we probably need to split `torch_mlir.compile`.
|
||||
# TODO: Longer-term, we probably need to split `torchscript.compile`.
|
||||
# There should be an "acquisition" step that does
|
||||
# tracing/scripting/importing from FX/using torchdynamo.export/etc.
|
||||
# + any lowering to the backend contract. Then there should be a
|
||||
# "backend lowering" step that does the actual lowering to each
|
||||
# backend. This separation should be visible at the Python API level, and
|
||||
# we can implement a deliberately simplified API like `torch_mlir.compile`
|
||||
# we can implement a deliberately simplified API like `torchscript.compile`
|
||||
# on top of those building blocks.
|
||||
if isinstance(model, torch.jit.ScriptModule):
|
||||
# If the user already converted the model to JIT IR themselves, just
|
|
@ -6,7 +6,7 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
@ -30,7 +30,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
|||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torch_mlir.compile(
|
||||
module = torchscript.compile(
|
||||
program, example_args, output_type="linalg-on-tensors")
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
@ -30,7 +30,7 @@ class StablehloBackendTestConfig(TestConfig):
|
|||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torch_mlir.compile(program, example_args, output_type="stablehlo")
|
||||
module = torchscript.compile(program, example_args, output_type="stablehlo")
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ from torch._functorch.aot_autograd import (
|
|||
|
||||
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
||||
from torch_mlir.dynamo import _get_decomposition_table
|
||||
from torch_mlir import (
|
||||
from torch_mlir.torchscript import (
|
||||
_example_args,
|
||||
OutputType,
|
||||
BACKEND_LEGAL_OPS,
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
@ -30,7 +30,7 @@ class TosaBackendTestConfig(TestConfig):
|
|||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
example_args = convert_annotations_to_placeholders(program.forward)
|
||||
module = torch_mlir.compile(
|
||||
module = torchscript.compile(
|
||||
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
|
||||
|
||||
return self.backend.compile(module)
|
||||
|
|
|
@ -3,13 +3,13 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from torch_mlir.torchscript import TensorPlaceholder
|
||||
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||
|
||||
def convert_annotations_to_placeholders(forward_method):
|
||||
"""Converts the annotations on a forward method into tensor placeholders.
|
||||
|
||||
These placeholders are suitable for being passed to `torch_mlir.compile`.
|
||||
These placeholders are suitable for being passed to `torchscript.compile`.
|
||||
"""
|
||||
annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
|
||||
placeholders = []
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import List, Tuple
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.utils.cpp_extension
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
from torch_mlir_e2e_test.annotations import export, annotate_args
|
||||
|
||||
|
||||
|
@ -56,7 +56,7 @@ def run():
|
|||
mod = CustomOpExampleModule()
|
||||
mod.eval()
|
||||
|
||||
module = torch_mlir.compile(
|
||||
module = torchscript.compile(
|
||||
mod,
|
||||
torch.ones(3, 4),
|
||||
output_type="torch",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
|
@ -39,6 +39,6 @@ class Model(torch.nn.Module):
|
|||
with torch.no_grad():
|
||||
return data
|
||||
|
||||
output_type = torch_mlir.OutputType.RAW
|
||||
mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
|
||||
output_type = torchscript.OutputType.RAW
|
||||
mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
|
||||
print(mod)
|
||||
|
|
|
@ -39,6 +39,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
|
|||
extras/onnx_importer.py
|
||||
)
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES
|
||||
fx.py
|
||||
)
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.Tools
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir.extras.fx_importer import FxImporter
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import torch as torch_d
|
||||
|
||||
def export_and_import(
|
||||
f,
|
||||
*args,
|
||||
fx_importer: Optional[FxImporter] = None,
|
||||
constraints: Optional[torch.export.Constraint] = None,
|
||||
**kwargs,
|
||||
):
|
||||
context = ir.Context()
|
||||
torch_d.register_dialect(context)
|
||||
|
||||
if fx_importer is None:
|
||||
fx_importer = FxImporter(context=context)
|
||||
prog = torch.export.export(f, args, kwargs, constraints=constraints)
|
||||
fx_importer.import_frozen_exported_program(prog)
|
||||
return fx_importer.module_op
|
|
@ -3,7 +3,7 @@
|
|||
import gc
|
||||
import sys
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import torchscript
|
||||
|
||||
|
||||
def run_test(f):
|
||||
|
@ -26,7 +26,7 @@ class TinyModel(torch.nn.Module):
|
|||
# CHECK-LABEL: TEST: test_enable_ir_printing
|
||||
@run_test
|
||||
def test_enable_ir_printing():
|
||||
torch_mlir.compile(TinyModel(),
|
||||
torchscript.compile(TinyModel(),
|
||||
torch.ones(1, 3, 20, 20),
|
||||
output_type="linalg-on-tensors",
|
||||
enable_ir_printing=True)
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
# Copyright 2023 Advanced Micro Devices, Inc
|
||||
#
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
@ -13,26 +11,7 @@ import torch
|
|||
import torch.export
|
||||
import torch.nn as nn
|
||||
|
||||
from torch_mlir.extras.fx_importer import FxImporter
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import torch as torch_d
|
||||
|
||||
|
||||
def export_and_import(
|
||||
f,
|
||||
*args,
|
||||
fx_importer: Optional[FxImporter] = None,
|
||||
constraints: Optional[torch.export.Constraint] = None,
|
||||
**kwargs,
|
||||
):
|
||||
context = ir.Context()
|
||||
torch_d.register_dialect(context)
|
||||
|
||||
if fx_importer is None:
|
||||
fx_importer = FxImporter(context=context)
|
||||
prog = torch.export.export(f, args, kwargs, constraints=constraints)
|
||||
fx_importer.import_frozen_exported_program(prog)
|
||||
return fx_importer.module_op
|
||||
from torch_mlir import fx
|
||||
|
||||
|
||||
def run(f):
|
||||
|
@ -75,5 +54,5 @@ def test_import_frozen_exported_program():
|
|||
def forward(self, x):
|
||||
return torch.tanh(x) * get_a() * self.b * self.p
|
||||
|
||||
m = export_and_import(Basic(), torch.randn(3, 4))
|
||||
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||
print(m)
|
||||
|
|
Loading…
Reference in New Issue