mirror of https://github.com/llvm/torch-mlir
Rename torch_mlir.compile APIs and introduce FX based analogs (#2842)
Link to related RFC: https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646 This commit updates the documentation, tests, CMake files, and API for the proposed changes in the RFC. There is a new torch_mlir/fx.py for user level APIs related to importing modules and a corresponding test for this path can be found at test/python/fx_importer/basic_test.py. --------- Co-authored-by: MaheshRavishankar <mravisha@amd.com>int_view_hack
parent
cc06391630
commit
bfcf93ea21
|
@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow
|
||||||
dynamic shapes for the trace. Additionally, the Python-level data structures and
|
dynamic shapes for the trace. Additionally, the Python-level data structures and
|
||||||
APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
|
APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we
|
||||||
consider both of those as the same from the perspective of the responsibilities
|
consider both of those as the same from the perspective of the responsibilities
|
||||||
of the compiler. Both are accessed via the `torch_mlir.compile` Python API.
|
of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API.
|
||||||
|
|
||||||
### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript
|
### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript
|
||||||
|
|
||||||
|
|
|
@ -120,37 +120,50 @@ cmake --build build
|
||||||
### Linux and macOS
|
### Linux and macOS
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
|
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
|
||||||
```
|
```
|
||||||
|
|
||||||
### Windows PowerShell
|
### Windows PowerShell
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples"
|
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing MLIR output in various dialects
|
## Testing MLIR output in various dialects
|
||||||
|
|
||||||
To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
|
To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`.
|
||||||
|
|
||||||
Make sure you have activated the virtualenv and set the `PYTHONPATH` above
|
Make sure you have activated the virtualenv and set the `PYTHONPATH` above
|
||||||
(if running on Windows, modify the environment variable as shown above):
|
(if running on Windows, modify the environment variable as shown above):
|
||||||
```shell
|
```shell
|
||||||
source mlir_venv/bin/activate
|
source mlir_venv/bin/activate
|
||||||
|
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer
|
||||||
|
python test/python/fx_importer/basic_test.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will display the basic example in TORCH dialect.
|
||||||
|
|
||||||
|
To test the compiler's output to the different MLIR dialects, you can also use the deprecated path
|
||||||
|
using torchscript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`.
|
||||||
|
This path doesn't give access to the current generation work that is being driven via the fx_importer
|
||||||
|
and may lead to errors.
|
||||||
|
|
||||||
|
Same as above, but with different python path and example:
|
||||||
|
```shell
|
||||||
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
|
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples
|
||||||
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
|
python projects/pt1/examples/torchscript_resnet18_all_output_types.py
|
||||||
```
|
```
|
||||||
|
|
||||||
This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.
|
This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA.
|
||||||
|
|
||||||
The main functionality is on `torch_mlir.compile()`'s `output_type`.
|
The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`.
|
||||||
|
|
||||||
Ex:
|
Ex:
|
||||||
```python
|
```python
|
||||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||||
```
|
```
|
||||||
|
|
||||||
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
|
`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
|
||||||
|
|
||||||
## Jupyter
|
## Jupyter
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,22 @@ the ecosystem are:
|
||||||
Most of this document describes long-term ecosystem changes that will address
|
Most of this document describes long-term ecosystem changes that will address
|
||||||
these, drastically improving Torch-MLIR's ability to meet its goals.
|
these, drastically improving Torch-MLIR's ability to meet its goals.
|
||||||
|
|
||||||
|
## Current API Paths
|
||||||
|
|
||||||
|
Currently, there are two main API paths for the torch-mlir project:
|
||||||
|
|
||||||
|
- The first path is part of the legacy project pt1 code
|
||||||
|
(torch_mlir.torchscript.compile). This allows users to test the compiler's
|
||||||
|
output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
|
||||||
|
`RAW` and `STABLEHLO`). This path is deprecated and doesn’t give access to
|
||||||
|
the current generation work that is being driven via the fx_importer. It is
|
||||||
|
tied to the old Torchscript path.
|
||||||
|
- The second path (torch_mlir.fx.export_and_import) allows users to import a
|
||||||
|
consolidated torch.export.ExportedProgram instance of an arbitrary Python
|
||||||
|
callable (an nn.Module, a function or a method) and output to torch dialect
|
||||||
|
mlir module. This path is aligned with PyTorch's roadmap, but the path is
|
||||||
|
not fully functional yet.
|
||||||
|
|
||||||
## Roadmap
|
## Roadmap
|
||||||
|
|
||||||
### Refactoring the frontend
|
### Refactoring the frontend
|
|
@ -14,7 +14,7 @@ import torch._dynamo as dynamo
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ labels = load_labels()
|
||||||
@make_simple_dynamo_backend
|
@make_simple_dynamo_backend
|
||||||
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||||
example_inputs: List[torch.Tensor]):
|
example_inputs: List[torch.Tensor]):
|
||||||
mlir_module = torch_mlir.compile(
|
mlir_module = torchscript.compile(
|
||||||
fx_graph, example_inputs, output_type="linalg-on-tensors")
|
fx_graph, example_inputs, output_type="linalg-on-tensors")
|
||||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||||
compiled = backend.compile(mlir_module)
|
compiled = backend.compile(mlir_module)
|
||||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ labels = load_labels()
|
||||||
|
|
||||||
resnet18 = models.resnet18(pretrained=True)
|
resnet18 = models.resnet18(pretrained=True)
|
||||||
resnet18.train(False)
|
resnet18.train(False)
|
||||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||||
compiled = backend.compile(module)
|
compiled = backend.compile(module)
|
||||||
jit_module = backend.load(compiled)
|
jit_module = backend.load(compiled)
|
||||||
|
|
|
@ -6,15 +6,15 @@
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
resnet18 = torchvision.models.resnet18(pretrained=True)
|
resnet18 = torchvision.models.resnet18(pretrained=True)
|
||||||
resnet18.eval()
|
resnet18.eval()
|
||||||
|
|
||||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||||
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors")
|
||||||
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||||
# TODO: Debug why this is so slow.
|
# TODO: Debug why this is so slow.
|
||||||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa")
|
||||||
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
|
||||||
|
|
|
@ -184,7 +184,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# Compile the model with an example input.\n",
|
"# Compile the model with an example input.\n",
|
||||||
"# We lower to the linalg-on-tensors form that the reference backend supports.\n",
|
"# We lower to the linalg-on-tensors form that the reference backend supports.\n",
|
||||||
"compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
|
"compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
|
||||||
"# Load it on the reference backend.\n",
|
"# Load it on the reference backend.\n",
|
||||||
"jit_module = compile_and_load_on_refbackend(compiled)\n",
|
"jit_module = compile_and_load_on_refbackend(compiled)\n",
|
||||||
"# Run it!\n",
|
"# Run it!\n",
|
||||||
|
@ -326,7 +326,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
|
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
|
||||||
"resnet18.eval()\n",
|
"resnet18.eval()\n",
|
||||||
"compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
|
"compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n",
|
||||||
"jit_module = compile_and_load_on_refbackend(compiled)"
|
"jit_module = compile_and_load_on_refbackend(compiled)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
import torch
|
import torch
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
model = models.resnet18(pretrained=True)
|
model = models.resnet18(pretrained=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
data = torch.randn(2,3,200,200)
|
data = torch.randn(2,3,200,200)
|
||||||
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
||||||
|
|
||||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
|
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False)
|
||||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||||
outf.write(str(module))
|
outf.write(str(module))
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
from transformers import BertForMaskedLM
|
from transformers import BertForMaskedLM
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ model.eval()
|
||||||
data = torch.randint(30522, (2, 128))
|
data = torch.randint(30522, (2, 128))
|
||||||
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
||||||
|
|
||||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
|
module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True)
|
||||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||||
outf.write(str(module))
|
outf.write(str(module))
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
|
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
|
||||||
SOURCES
|
SOURCES
|
||||||
__init__.py
|
torchscript.py
|
||||||
_dynamo_fx_importer.py
|
_dynamo_fx_importer.py
|
||||||
compiler_utils.py
|
compiler_utils.py
|
||||||
dynamo.py
|
dynamo.py
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class BasicModule(torch.nn.Module):
|
class BasicModule(torch.nn.Module):
|
||||||
|
@ -15,17 +15,17 @@ class BasicModule(torch.nn.Module):
|
||||||
return torch.ops.aten.sin(x)
|
return torch.ops.aten.sin(x)
|
||||||
|
|
||||||
|
|
||||||
example_args = torch_mlir.ExampleArgs()
|
example_args = torchscript.ExampleArgs()
|
||||||
example_args.add_method("sin", torch.ones(2, 3))
|
example_args.add_method("sin", torch.ones(2, 3))
|
||||||
|
|
||||||
scripted = torch.jit.script(BasicModule())
|
scripted = torch.jit.script(BasicModule())
|
||||||
print(torch_mlir.compile(scripted, example_args))
|
print(torchscript.compile(scripted, example_args))
|
||||||
# CHECK: module
|
# CHECK: module
|
||||||
# CHECK-DAG: func.func @sin
|
# CHECK-DAG: func.func @sin
|
||||||
|
|
||||||
scripted = torch.jit.script(BasicModule())
|
scripted = torch.jit.script(BasicModule())
|
||||||
try:
|
try:
|
||||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||||
torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
|
torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3)))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
|
@ -6,23 +6,23 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
class BasicModule(torch.nn.Module):
|
class BasicModule(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.ops.aten.sin(x)
|
return torch.ops.aten.sin(x)
|
||||||
|
|
||||||
example_arg = torch.ones(2, 3)
|
example_arg = torch.ones(2, 3)
|
||||||
example_args = torch_mlir.ExampleArgs.get(example_arg)
|
example_args = torchscript.ExampleArgs.get(example_arg)
|
||||||
|
|
||||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||||
print(torch_mlir.compile(traced, example_args))
|
print(torchscript.compile(traced, example_args))
|
||||||
# CHECK: module
|
# CHECK: module
|
||||||
# CHECK-DAG: func.func @forward
|
# CHECK-DAG: func.func @forward
|
||||||
|
|
||||||
traced = torch.jit.trace(BasicModule(), example_arg)
|
traced = torch.jit.trace(BasicModule(), example_arg)
|
||||||
try:
|
try:
|
||||||
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
# CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition.
|
||||||
torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg))
|
torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
class AddmmModule(torch.nn.Module):
|
class AddmmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -15,9 +15,9 @@ class AddmmModule(torch.nn.Module):
|
||||||
def forward(self, x, y, z):
|
def forward(self, x, y, z):
|
||||||
return torch.ops.aten.addmm(x, y, z)
|
return torch.ops.aten.addmm(x, y, z)
|
||||||
|
|
||||||
example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)]
|
example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)]
|
||||||
|
|
||||||
print(torch_mlir.compile(AddmmModule(), example_args,
|
print(torchscript.compile(AddmmModule(), example_args,
|
||||||
output_type="torch", backend_legal_ops=["aten.addmm"]))
|
output_type="torch", backend_legal_ops=["aten.addmm"]))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.addmm
|
# CHECK: torch.aten.addmm
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
class TanhModule(torch.nn.Module):
|
class TanhModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -18,24 +18,24 @@ class TanhModule(torch.nn.Module):
|
||||||
tanh_example_input = torch.ones(2, 3)
|
tanh_example_input = torch.ones(2, 3)
|
||||||
|
|
||||||
# Simplest case: One example argument.
|
# Simplest case: One example argument.
|
||||||
print(torch_mlir.compile(TanhModule(), tanh_example_input))
|
print(torchscript.compile(TanhModule(), tanh_example_input))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
|
|
||||||
# Use a TensorPlaceholder to represent dynamic axes.
|
# Use a TensorPlaceholder to represent dynamic axes.
|
||||||
placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
|
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1])
|
||||||
print(torch_mlir.compile(TanhModule(), placeholder))
|
print(torchscript.compile(TanhModule(), placeholder))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||||
|
|
||||||
# Explicitly construct a TensorPlaceholder.
|
# Explicitly construct a TensorPlaceholder.
|
||||||
placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32)
|
placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32)
|
||||||
print(torch_mlir.compile(TanhModule(), placeholder))
|
print(torchscript.compile(TanhModule(), placeholder))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
||||||
|
|
||||||
# Basic smoke test for the raw output type.
|
# Basic smoke test for the raw output type.
|
||||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
|
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW))
|
||||||
# CHECK: torch.nn_module {
|
# CHECK: torch.nn_module {
|
||||||
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
|
# CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule">
|
||||||
|
|
||||||
|
@ -47,12 +47,12 @@ class MmModule(torch.nn.Module):
|
||||||
|
|
||||||
# N > 1 inputs.
|
# N > 1 inputs.
|
||||||
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
|
mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)]
|
||||||
print(torch_mlir.compile(MmModule(), mm_example_inputs))
|
print(torchscript.compile(MmModule(), mm_example_inputs))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
|
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
|
||||||
|
|
||||||
# Mixes Tensor's and TensorPlaceholder's.
|
# Mixes Tensor's and TensorPlaceholder's.
|
||||||
mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
|
mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])]
|
||||||
print(torch_mlir.compile(MmModule(), mm_dynamic_inputs))
|
print(torchscript.compile(MmModule(), mm_dynamic_inputs))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
|
# CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
import functorch
|
import functorch
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
def simple(x):
|
def simple(x):
|
||||||
return x * x
|
return x * x
|
||||||
|
@ -17,6 +17,6 @@ example_input = torch.randn(1,)
|
||||||
graph = functorch.make_fx(simple)(torch.randn(1,))
|
graph = functorch.make_fx(simple)(torch.randn(1,))
|
||||||
|
|
||||||
# Simplest case: One example argument.
|
# Simplest case: One example argument.
|
||||||
print(torch_mlir.compile(graph, example_input))
|
print(torchscript.compile(graph, example_input))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
|
# CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32>
|
|
@ -6,7 +6,7 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class TwoMethodsModule(torch.nn.Module):
|
class TwoMethodsModule(torch.nn.Module):
|
||||||
|
@ -17,14 +17,14 @@ class TwoMethodsModule(torch.nn.Module):
|
||||||
return torch.ops.aten.cos(x)
|
return torch.ops.aten.cos(x)
|
||||||
|
|
||||||
|
|
||||||
example_args = torch_mlir.ExampleArgs()
|
example_args = torchscript.ExampleArgs()
|
||||||
example_args.add_method("sin", torch.ones(2, 3))
|
example_args.add_method("sin", torch.ones(2, 3))
|
||||||
example_args.add_method("cos", torch.ones(2, 4))
|
example_args.add_method("cos", torch.ones(2, 4))
|
||||||
|
|
||||||
# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
|
# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
|
||||||
# check the `use_tracing` case first.
|
# check the `use_tracing` case first.
|
||||||
|
|
||||||
print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
||||||
# CHECK: module
|
# CHECK: module
|
||||||
# CHECK-DAG: func.func @sin
|
# CHECK-DAG: func.func @sin
|
||||||
# CHECK-DAG: func.func @cos
|
# CHECK-DAG: func.func @cos
|
||||||
|
@ -34,8 +34,8 @@ print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
||||||
# Otherwise the user would have to do this manually, which is tedious. This
|
# Otherwise the user would have to do this manually, which is tedious. This
|
||||||
# technically mutates the user input model which is not great but probably okay
|
# technically mutates the user input model which is not great but probably okay
|
||||||
# for this kind of API sugar. Users can always take full control of the process
|
# for this kind of API sugar. Users can always take full control of the process
|
||||||
# by scripting the model themselves before passing it to `torch_mlir.compile`.
|
# by scripting the model themselves before passing it to `torchscript.compile`.
|
||||||
print(torch_mlir.compile(TwoMethodsModule(), example_args))
|
print(torchscript.compile(TwoMethodsModule(), example_args))
|
||||||
# CHECK: module
|
# CHECK: module
|
||||||
# CHECK-DAG: func.func @sin
|
# CHECK-DAG: func.func @sin
|
||||||
# CHECK-DAG: func.func @cos
|
# CHECK-DAG: func.func @cos
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
class TanhModule(torch.nn.Module):
|
class TanhModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -17,9 +17,9 @@ class TanhModule(torch.nn.Module):
|
||||||
|
|
||||||
tanh_example_input = torch.ones(2, 3)
|
tanh_example_input = torch.ones(2, 3)
|
||||||
|
|
||||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH))
|
print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch"))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
class TanhModule(torch.nn.Module):
|
class TanhModule(torch.nn.Module):
|
||||||
|
@ -17,38 +17,38 @@ class TanhModule(torch.nn.Module):
|
||||||
tanh_example_input = torch.ones(2, 3)
|
tanh_example_input = torch.ones(2, 3)
|
||||||
|
|
||||||
# Simplest case: One example argument.
|
# Simplest case: One example argument.
|
||||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True))
|
print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
|
|
||||||
# Simplest case: Passed as a tuple.
|
# Simplest case: Passed as a tuple.
|
||||||
print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
|
print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
|
|
||||||
# Simplest case: Passed as a list.
|
# Simplest case: Passed as a list.
|
||||||
print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True))
|
print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
|
||||||
|
|
||||||
# TensorPlaceholder support.
|
# TensorPlaceholder support.
|
||||||
placeholder = torch_mlir.TensorPlaceholder.like(
|
placeholder = torchscript.TensorPlaceholder.like(
|
||||||
tanh_example_input, dynamic_axes=[1])
|
tanh_example_input, dynamic_axes=[1])
|
||||||
print(torch_mlir.compile(TanhModule(), [placeholder],
|
print(torchscript.compile(TanhModule(), [placeholder],
|
||||||
use_tracing=True, ignore_traced_shapes=True))
|
use_tracing=True, ignore_traced_shapes=True))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32>
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CHECK: `ignore_traced_shapes` requires `use_tracing`
|
# CHECK: `ignore_traced_shapes` requires `use_tracing`
|
||||||
torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
|
torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`
|
# CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`
|
||||||
torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True)
|
torchscript.compile(TanhModule(), [placeholder], use_tracing=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
@ -60,13 +60,13 @@ class DictModule(torch.nn.Module):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||||
torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
# CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||||
torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
|
@ -125,7 +125,7 @@ def make_simple_dynamo_backend(user_backend):
|
||||||
Args:
|
Args:
|
||||||
user_backend: A function with the signature used by ordinary
|
user_backend: A function with the signature used by ordinary
|
||||||
TorchDynamo backends. But the torch.fx.GraphModule passed to it
|
TorchDynamo backends. But the torch.fx.GraphModule passed to it
|
||||||
will be normalized for consumption by `torch_mlir.compile`.
|
will be normalized for consumption by `torchscript.compile`.
|
||||||
Returns:
|
Returns:
|
||||||
A function with the signature used by TorchDynamo backends.
|
A function with the signature used by TorchDynamo backends.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -22,7 +22,7 @@ from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_li
|
||||||
|
|
||||||
|
|
||||||
class OutputType(Enum):
|
class OutputType(Enum):
|
||||||
"""The kind of output that `torch_mlir.compile` can produce.
|
"""The kind of output that `torchscript.compile` can produce.
|
||||||
|
|
||||||
In MLIR terminology, this describes the mix of dialects that will be
|
In MLIR terminology, this describes the mix of dialects that will be
|
||||||
produced by the conversion process.
|
produced by the conversion process.
|
||||||
|
@ -392,13 +392,13 @@ def compile(model: torch.nn.Module,
|
||||||
strip_overloads(model)
|
strip_overloads(model)
|
||||||
|
|
||||||
# Get the model as JIT IR (TorchScript) for import.
|
# Get the model as JIT IR (TorchScript) for import.
|
||||||
# TODO: Longer-term, we probably need to split `torch_mlir.compile`.
|
# TODO: Longer-term, we probably need to split `torchscript.compile`.
|
||||||
# There should be an "acquisition" step that does
|
# There should be an "acquisition" step that does
|
||||||
# tracing/scripting/importing from FX/using torchdynamo.export/etc.
|
# tracing/scripting/importing from FX/using torchdynamo.export/etc.
|
||||||
# + any lowering to the backend contract. Then there should be a
|
# + any lowering to the backend contract. Then there should be a
|
||||||
# "backend lowering" step that does the actual lowering to each
|
# "backend lowering" step that does the actual lowering to each
|
||||||
# backend. This separation should be visible at the Python API level, and
|
# backend. This separation should be visible at the Python API level, and
|
||||||
# we can implement a deliberately simplified API like `torch_mlir.compile`
|
# we can implement a deliberately simplified API like `torchscript.compile`
|
||||||
# on top of those building blocks.
|
# on top of those building blocks.
|
||||||
if isinstance(model, torch.jit.ScriptModule):
|
if isinstance(model, torch.jit.ScriptModule):
|
||||||
# If the user already converted the model to JIT IR themselves, just
|
# If the user already converted the model to JIT IR themselves, just
|
|
@ -6,7 +6,7 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
|
||||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
@ -30,7 +30,7 @@ class LinalgOnTensorsBackendTestConfig(TestConfig):
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
example_args = convert_annotations_to_placeholders(program.forward)
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
module = torch_mlir.compile(
|
module = torchscript.compile(
|
||||||
program, example_args, output_type="linalg-on-tensors")
|
program, example_args, output_type="linalg-on-tensors")
|
||||||
|
|
||||||
return self.backend.compile(module)
|
return self.backend.compile(module)
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
|
from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend
|
||||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
@ -30,7 +30,7 @@ class StablehloBackendTestConfig(TestConfig):
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
example_args = convert_annotations_to_placeholders(program.forward)
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
module = torch_mlir.compile(program, example_args, output_type="stablehlo")
|
module = torchscript.compile(program, example_args, output_type="stablehlo")
|
||||||
|
|
||||||
return self.backend.compile(module)
|
return self.backend.compile(module)
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from torch._functorch.aot_autograd import (
|
||||||
|
|
||||||
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
||||||
from torch_mlir.dynamo import _get_decomposition_table
|
from torch_mlir.dynamo import _get_decomposition_table
|
||||||
from torch_mlir import (
|
from torch_mlir.torchscript import (
|
||||||
_example_args,
|
_example_args,
|
||||||
OutputType,
|
OutputType,
|
||||||
BACKEND_LEGAL_OPS,
|
BACKEND_LEGAL_OPS,
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
|
from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
|
||||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
@ -30,7 +30,7 @@ class TosaBackendTestConfig(TestConfig):
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
example_args = convert_annotations_to_placeholders(program.forward)
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
module = torch_mlir.compile(
|
module = torchscript.compile(
|
||||||
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
|
program, example_args, output_type="tosa", use_make_fx=self.use_make_fx)
|
||||||
|
|
||||||
return self.backend.compile(module)
|
return self.backend.compile(module)
|
||||||
|
|
|
@ -3,13 +3,13 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
from torch_mlir import TensorPlaceholder
|
from torch_mlir.torchscript import TensorPlaceholder
|
||||||
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||||
|
|
||||||
def convert_annotations_to_placeholders(forward_method):
|
def convert_annotations_to_placeholders(forward_method):
|
||||||
"""Converts the annotations on a forward method into tensor placeholders.
|
"""Converts the annotations on a forward method into tensor placeholders.
|
||||||
|
|
||||||
These placeholders are suitable for being passed to `torch_mlir.compile`.
|
These placeholders are suitable for being passed to `torchscript.compile`.
|
||||||
"""
|
"""
|
||||||
annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
|
annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
|
||||||
placeholders = []
|
placeholders = []
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import List, Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
from torch_mlir_e2e_test.annotations import export, annotate_args
|
from torch_mlir_e2e_test.annotations import export, annotate_args
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ def run():
|
||||||
mod = CustomOpExampleModule()
|
mod = CustomOpExampleModule()
|
||||||
mod.eval()
|
mod.eval()
|
||||||
|
|
||||||
module = torch_mlir.compile(
|
module = torchscript.compile(
|
||||||
mod,
|
mod,
|
||||||
torch.ones(3, 4),
|
torch.ones(3, 4),
|
||||||
output_type="torch",
|
output_type="torch",
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||||
|
|
||||||
|
@ -39,6 +39,6 @@ class Model(torch.nn.Module):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return data
|
return data
|
||||||
|
|
||||||
output_type = torch_mlir.OutputType.RAW
|
output_type = torchscript.OutputType.RAW
|
||||||
mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
|
mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type)
|
||||||
print(mod)
|
print(mod)
|
||||||
|
|
|
@ -39,6 +39,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
|
||||||
extras/onnx_importer.py
|
extras/onnx_importer.py
|
||||||
)
|
)
|
||||||
|
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
|
||||||
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonSources
|
||||||
|
SOURCES
|
||||||
|
fx.py
|
||||||
|
)
|
||||||
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources.Tools
|
declare_mlir_python_sources(TorchMLIRPythonSources.Tools
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
ADD_TO_PARENT TorchMLIRPythonSources
|
ADD_TO_PARENT TorchMLIRPythonSources
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.export
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from torch_mlir.extras.fx_importer import FxImporter
|
||||||
|
from torch_mlir import ir
|
||||||
|
from torch_mlir.dialects import torch as torch_d
|
||||||
|
|
||||||
|
def export_and_import(
|
||||||
|
f,
|
||||||
|
*args,
|
||||||
|
fx_importer: Optional[FxImporter] = None,
|
||||||
|
constraints: Optional[torch.export.Constraint] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
context = ir.Context()
|
||||||
|
torch_d.register_dialect(context)
|
||||||
|
|
||||||
|
if fx_importer is None:
|
||||||
|
fx_importer = FxImporter(context=context)
|
||||||
|
prog = torch.export.export(f, args, kwargs, constraints=constraints)
|
||||||
|
fx_importer.import_frozen_exported_program(prog)
|
||||||
|
return fx_importer.module_op
|
|
@ -3,7 +3,7 @@
|
||||||
import gc
|
import gc
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
from torch_mlir import torchscript
|
||||||
|
|
||||||
|
|
||||||
def run_test(f):
|
def run_test(f):
|
||||||
|
@ -26,7 +26,7 @@ class TinyModel(torch.nn.Module):
|
||||||
# CHECK-LABEL: TEST: test_enable_ir_printing
|
# CHECK-LABEL: TEST: test_enable_ir_printing
|
||||||
@run_test
|
@run_test
|
||||||
def test_enable_ir_printing():
|
def test_enable_ir_printing():
|
||||||
torch_mlir.compile(TinyModel(),
|
torchscript.compile(TinyModel(),
|
||||||
torch.ones(1, 3, 20, 20),
|
torch.ones(1, 3, 20, 20),
|
||||||
output_type="linalg-on-tensors",
|
output_type="linalg-on-tensors",
|
||||||
enable_ir_printing=True)
|
enable_ir_printing=True)
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
# Copyright 2023 Advanced Micro Devices, Inc
|
|
||||||
#
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
@ -13,26 +11,7 @@ import torch
|
||||||
import torch.export
|
import torch.export
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from torch_mlir.extras.fx_importer import FxImporter
|
from torch_mlir import fx
|
||||||
from torch_mlir import ir
|
|
||||||
from torch_mlir.dialects import torch as torch_d
|
|
||||||
|
|
||||||
|
|
||||||
def export_and_import(
|
|
||||||
f,
|
|
||||||
*args,
|
|
||||||
fx_importer: Optional[FxImporter] = None,
|
|
||||||
constraints: Optional[torch.export.Constraint] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
context = ir.Context()
|
|
||||||
torch_d.register_dialect(context)
|
|
||||||
|
|
||||||
if fx_importer is None:
|
|
||||||
fx_importer = FxImporter(context=context)
|
|
||||||
prog = torch.export.export(f, args, kwargs, constraints=constraints)
|
|
||||||
fx_importer.import_frozen_exported_program(prog)
|
|
||||||
return fx_importer.module_op
|
|
||||||
|
|
||||||
|
|
||||||
def run(f):
|
def run(f):
|
||||||
|
@ -75,5 +54,5 @@ def test_import_frozen_exported_program():
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.tanh(x) * get_a() * self.b * self.p
|
return torch.tanh(x) * get_a() * self.b * self.p
|
||||||
|
|
||||||
m = export_and_import(Basic(), torch.randn(3, 4))
|
m = fx.export_and_import(Basic(), torch.randn(3, 4))
|
||||||
print(m)
|
print(m)
|
||||||
|
|
Loading…
Reference in New Issue