mirror of https://github.com/llvm/torch-mlir
Breakup python pytorch deps (#2582)
This lifts the core of the jit_ir_importer and ltc out of the pt1 project, making them peers to it. As a side-effect of this layering, now the "MLIR bits" (dialects, etc) are not commingled with the various parts of the pt1 project, allowing pt1 and ltc to overlay cleanly onto a more fundamental "just MLIR" Python core. Prior to this, the Python namespace was polluted to the point that this could not happen. That "just MLIR" Python core will be introduced in a followup, which will create the space to upstream the FX and ONNX pure Python importers. This primary non-NFC change to the API is: * `torch_mlir.dialects.torch.importer.jit_ir` -> `torch_mlir.jit_ir_importer`. The rest is source code layering so that we can make the pt1 project optional without losing the other features. Progress on #2546.pull/2585/head
parent
facbe5d96b
commit
5eae0adff1
|
@ -27,15 +27,12 @@ jobs:
|
||||||
matrix:
|
matrix:
|
||||||
os-arch: [ubuntu-x86_64, macos-arm64, windows-x86_64]
|
os-arch: [ubuntu-x86_64, macos-arm64, windows-x86_64]
|
||||||
llvm-build: [in-tree, out-of-tree]
|
llvm-build: [in-tree, out-of-tree]
|
||||||
torch-binary: [ON, OFF]
|
torch-binary: [ON]
|
||||||
torch-version: [nightly, stable]
|
torch-version: [nightly, stable]
|
||||||
exclude:
|
exclude:
|
||||||
# Exclude llvm in-tree and pytorch source
|
# Exclude llvm out-of-tree and pytorch stable (to save resources)
|
||||||
- llvm-build: in-tree
|
|
||||||
torch-binary: OFF
|
|
||||||
# Exclude llvm out-of-tree and pytorch binary
|
|
||||||
- llvm-build: out-of-tree
|
- llvm-build: out-of-tree
|
||||||
torch-binary: ON
|
torch-version: stable
|
||||||
# Exclude macos-arm64 and llvm out-of-tree altogether
|
# Exclude macos-arm64 and llvm out-of-tree altogether
|
||||||
- os-arch: macos-arm64
|
- os-arch: macos-arm64
|
||||||
llvm-build: out-of-tree
|
llvm-build: out-of-tree
|
||||||
|
@ -45,9 +42,6 @@ jobs:
|
||||||
llvm-build: out-of-tree
|
llvm-build: out-of-tree
|
||||||
- os-arch: windows-x86_64
|
- os-arch: windows-x86_64
|
||||||
torch-version: stable
|
torch-version: stable
|
||||||
# For PyTorch stable builds, we don't build PyTorch from source
|
|
||||||
- torch-version: stable
|
|
||||||
torch-binary: OFF
|
|
||||||
include:
|
include:
|
||||||
# Specify OS versions
|
# Specify OS versions
|
||||||
- os-arch: ubuntu-x86_64
|
- os-arch: ubuntu-x86_64
|
||||||
|
|
|
@ -26,7 +26,7 @@ __pycache__
|
||||||
bazel-*
|
bazel-*
|
||||||
|
|
||||||
# Autogenerated files
|
# Autogenerated files
|
||||||
/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated
|
/projects/ltc/csrc/base_lazy_backend/generated
|
||||||
|
|
||||||
#Docker builds
|
#Docker builds
|
||||||
build_oot/
|
build_oot/
|
||||||
|
|
|
@ -149,10 +149,12 @@ endfunction()
|
||||||
# Configure CMake.
|
# Configure CMake.
|
||||||
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
|
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
|
||||||
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
|
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
|
||||||
|
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/build_tools/cmake)
|
||||||
|
|
||||||
include(TableGen)
|
include(TableGen)
|
||||||
include(AddLLVM)
|
include(AddLLVM)
|
||||||
include(AddMLIR)
|
include(AddMLIR)
|
||||||
|
include(AddMLIRPython)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Setup python.
|
# Setup python.
|
||||||
|
@ -231,6 +233,4 @@ endif()
|
||||||
# Sub-projects
|
# Sub-projects
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_PROJECT_PT1)
|
add_subdirectory(projects)
|
||||||
add_subdirectory(projects/pt1)
|
|
||||||
endif()
|
|
||||||
|
|
|
@ -29,7 +29,6 @@ if not TORCH_INCLUDE_DIR.is_dir():
|
||||||
TORCH_INCLUDE_DIR = TORCH_DIR
|
TORCH_INCLUDE_DIR = TORCH_DIR
|
||||||
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
||||||
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
||||||
TORCH_MLIR_PT1_DIR = TORCH_MLIR_DIR / "projects" / "pt1"
|
|
||||||
|
|
||||||
def reindent(text, prefix=""):
|
def reindent(text, prefix=""):
|
||||||
return indent(dedent(text), prefix)
|
return indent(dedent(text), prefix)
|
||||||
|
@ -114,12 +113,12 @@ class GenTorchMlirLTC:
|
||||||
self.binary_dir = Path(binary_dir)
|
self.binary_dir = Path(binary_dir)
|
||||||
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
|
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
|
||||||
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
||||||
self.backend_path = TORCH_MLIR_PT1_DIR.joinpath(
|
self.backend_path = TORCH_MLIR_DIR.joinpath(
|
||||||
"python", "torch_mlir", "csrc", "base_lazy_backend"
|
"projects", "ltc", "csrc", "base_lazy_backend"
|
||||||
)
|
)
|
||||||
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
|
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
|
||||||
self.generated_path = self.binary_dir.joinpath(
|
self.generated_path = self.binary_dir.joinpath(
|
||||||
"projects", "pt1", "python", "torch_mlir", "csrc", "base_lazy_backend", "generated"
|
"projects", "ltc", "csrc", "base_lazy_backend", "generated"
|
||||||
)
|
)
|
||||||
self.generated_path.mkdir(parents=True, exist_ok=True)
|
self.generated_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -415,7 +414,7 @@ class GenTorchMlirLTC:
|
||||||
// for ops that dont have a corresponding structured kernel or shape definition
|
// for ops that dont have a corresponding structured kernel or shape definition
|
||||||
|
|
||||||
#include "shape_inference.h"
|
#include "shape_inference.h"
|
||||||
#include "torch_mlir/csrc/base_lazy_backend/utils/exception.h"
|
#include "base_lazy_backend/utils/exception.h"
|
||||||
namespace torch {{
|
namespace torch {{
|
||||||
namespace lazy {{
|
namespace lazy {{
|
||||||
{}
|
{}
|
||||||
|
@ -467,7 +466,7 @@ class GenTorchMlirLTC:
|
||||||
node_base="torch::lazy::TorchMlirNode",
|
node_base="torch::lazy::TorchMlirNode",
|
||||||
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
|
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
|
||||||
tensor_class=self.tensor_class,
|
tensor_class=self.tensor_class,
|
||||||
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
|
tensor_class_hdr="base_lazy_backend/tensor.h",
|
||||||
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
|
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
|
||||||
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
|
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
|
||||||
lazy_ir_generator=GenMlirLazyIr,
|
lazy_ir_generator=GenMlirLazyIr,
|
||||||
|
|
|
@ -364,9 +364,9 @@ function setup_venv() {
|
||||||
function build_out_of_tree() {
|
function build_out_of_tree() {
|
||||||
local torch_from_bin="$1"
|
local torch_from_bin="$1"
|
||||||
local python_version="$2"
|
local python_version="$2"
|
||||||
echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version"
|
|
||||||
|
|
||||||
local torch_version="$3"
|
local torch_version="$3"
|
||||||
|
echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version ($torch_version)"
|
||||||
|
|
||||||
local enable_ltc="ON"
|
local enable_ltc="ON"
|
||||||
if [[ "${torch_version}" == "stable" ]]
|
if [[ "${torch_version}" == "stable" ]]
|
||||||
then
|
then
|
||||||
|
|
|
@ -42,6 +42,6 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
PYTHONPATH="${pypath}" python \
|
PYTHONPATH="${pypath}" python \
|
||||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.abstract_interp_lib_gen \
|
-m torch_mlir.jit_ir_importer.build_tools.abstract_interp_lib_gen \
|
||||||
--pytorch_op_extensions=${ext_module:-""} \
|
--pytorch_op_extensions=${ext_module:-""} \
|
||||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
||||||
|
|
|
@ -43,7 +43,7 @@ fi
|
||||||
|
|
||||||
set +u
|
set +u
|
||||||
PYTHONPATH="${PYTHONPATH}:${pypath}" python \
|
PYTHONPATH="${PYTHONPATH}:${pypath}" python \
|
||||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
|
-m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen \
|
||||||
--torch_ir_include_dir="${torch_ir_include_dir}" \
|
--torch_ir_include_dir="${torch_ir_include_dir}" \
|
||||||
--pytorch_op_extensions="${ext_module}" \
|
--pytorch_op_extensions="${ext_module}" \
|
||||||
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
|
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
|
||||||
|
|
|
@ -17,7 +17,7 @@ The end-to-end test is important to check the correctness of the other steps.
|
||||||
|
|
||||||
### Step 2. Update ods
|
### Step 2. Update ods
|
||||||
|
|
||||||
Update [torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/main/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py) with the new op and run [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh) to generate the ods. Running `update_torch_ods.sh` would dump all the operators with schema into `JITOperatorRegistryDump.txt`. It’s convenient to look for ops signatures and operands names in this file.
|
Update [torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/main/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py) with the new op and run [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh) to generate the ods. Running `update_torch_ods.sh` would dump all the operators with schema into `JITOperatorRegistryDump.txt`. It’s convenient to look for ops signatures and operands names in this file.
|
||||||
|
|
||||||
### Step 3. Propagate types
|
### Step 3. Propagate types
|
||||||
It’s essential to make sure the new op implements shape and dtype inference. See [abstract_interp_lib](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for information on adding shape and dtype inference.
|
It’s essential to make sure the new op implements shape and dtype inference. See [abstract_interp_lib](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for information on adding shape and dtype inference.
|
||||||
|
|
|
@ -26,7 +26,7 @@ The two main use cases are:
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
Functions are defined as TorchScript-able Python functions in
|
Functions are defined as TorchScript-able Python functions in
|
||||||
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py`.
|
`python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py`.
|
||||||
The signatures of the functions are systematically derived from Torch JIT
|
The signatures of the functions are systematically derived from Torch JIT
|
||||||
operator registry. Most shape functions are expected to reuse the upstream
|
operator registry. Most shape functions are expected to reuse the upstream
|
||||||
helper functions
|
helper functions
|
||||||
|
|
|
@ -87,7 +87,7 @@ following order:
|
||||||
|
|
||||||
1. Shape of input tensor. Use `-1` for dynamic dimensions
|
1. Shape of input tensor. Use `-1` for dynamic dimensions
|
||||||
2. Dtype of the input tensor
|
2. Dtype of the input tensor
|
||||||
3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h#L54-L67). This
|
3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h#L54-L67). This
|
||||||
will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the
|
will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the
|
||||||
IR to eventually have value semantics.
|
IR to eventually have value semantics.
|
||||||
|
|
||||||
|
|
|
@ -55,14 +55,14 @@ factored such that we can handle this with one core import path, which is
|
||||||
through the PyTorch
|
through the PyTorch
|
||||||
"[JIT IR](https://github.com/pytorch/pytorch/blob/78c8a0d75220bdd4955415b5f81509e005af4232/torch/csrc/jit/OVERVIEW.md)",
|
"[JIT IR](https://github.com/pytorch/pytorch/blob/78c8a0d75220bdd4955415b5f81509e005af4232/torch/csrc/jit/OVERVIEW.md)",
|
||||||
and lives in
|
and lives in
|
||||||
[torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir).
|
[torch-mlir/python/torch_mlir/jit_ir_importer](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir).
|
||||||
The JIT IR is a highly principled IR that faithfully models a Python subset (+
|
The JIT IR is a highly principled IR that faithfully models a Python subset (+
|
||||||
tensors, the PyTorch op registry, and a few other things). All the other PyTorch
|
tensors, the PyTorch op registry, and a few other things). All the other PyTorch
|
||||||
program representations can eventually bottom-out on the JIT IR via some path
|
program representations can eventually bottom-out on the JIT IR via some path
|
||||||
provided by PyTorch. The `torch` dialect is almost entirely in 1:1
|
provided by PyTorch. The `torch` dialect is almost entirely in 1:1
|
||||||
correspondence with the JIT IR -- this allows the importer to be extremely small
|
correspondence with the JIT IR -- this allows the importer to be extremely small
|
||||||
(the core is
|
(the core is
|
||||||
[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp#L1)).
|
[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp#L1)).
|
||||||
|
|
||||||
### Ops
|
### Ops
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ See [TorchOps.td](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83
|
||||||
|
|
||||||
The ops in the `torch` dialect are almost entirely generated based on the
|
The ops in the `torch` dialect are almost entirely generated based on the
|
||||||
PyTorch JIT IR operator registry via the script
|
PyTorch JIT IR operator registry via the script
|
||||||
[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)).
|
[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)).
|
||||||
This script queries the registry and generates MLIR
|
This script queries the registry and generates MLIR
|
||||||
[ODS](https://mlir.llvm.org/docs/OpDefinitions/) in
|
[ODS](https://mlir.llvm.org/docs/OpDefinitions/) in
|
||||||
[GeneratedTorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L1). We have a guide for [adding a new op end-to-end](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
|
[GeneratedTorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L1). We have a guide for [adding a new op end-to-end](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
|
||||||
|
@ -195,7 +195,7 @@ values. When one `torch.jit.script`'s a `torch.nn.Module`, the result is
|
||||||
actually an `IValue` that represents the module, with a hierarchy of children
|
actually an `IValue` that represents the module, with a hierarchy of children
|
||||||
`IValue`'s. Strictly speaking, JIT IR `torch::jit::Graph`'s are only used to
|
`IValue`'s. Strictly speaking, JIT IR `torch::jit::Graph`'s are only used to
|
||||||
represent the bodies of methods on the modules. So in addition to importing the
|
represent the bodies of methods on the modules. So in addition to importing the
|
||||||
JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp#L1).
|
JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.cpp#L1).
|
||||||
|
|
||||||
Most of the IValue modeling can reuse `torch` dialect ops that already exist
|
Most of the IValue modeling can reuse `torch` dialect ops that already exist
|
||||||
otherwise, such as `torch.constant.int` to represent an int in the object graph.
|
otherwise, such as `torch.constant.int` to represent an int in the object graph.
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
|
[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
|
||||||
After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
|
After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
|
||||||
|
|
||||||
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
|
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../projects/ltc/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
|
||||||
Implementations based on this abstract class will be able to specify their own compile and execution workflows.
|
Implementations based on this abstract class will be able to specify their own compile and execution workflows.
|
||||||
Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend).
|
Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend).
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ View examples [here](ltc_examples.md).
|
||||||
- The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
|
- The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
|
||||||
excluding those explicitly blacklisted in the YAML file
|
excluding those explicitly blacklisted in the YAML file
|
||||||
|
|
||||||
### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated))
|
### Autogen Files ([`projects/ltc/csrc/base_lazy_backend/generated`](../projects/ltc/csrc/base_lazy_backend/generated))
|
||||||
Generated files are created in this directory, which is ignored by version control.
|
Generated files are created in this directory, which is ignored by version control.
|
||||||
|
|
||||||
- `LazyIr.h`
|
- `LazyIr.h`
|
||||||
|
@ -41,7 +41,7 @@ Generated files are created in this directory, which is ignored by version contr
|
||||||
- `shape_inference.{cpp,h}`
|
- `shape_inference.{cpp,h}`
|
||||||
- Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
|
- Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
|
||||||
|
|
||||||
### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend))
|
### Base Backend ([`projects/ltc/csrc/base_lazy_backend`](../projects/ltc/csrc/base_lazy_backend))
|
||||||
|
|
||||||
- `backend_impl.{cpp,h}`
|
- `backend_impl.{cpp,h}`
|
||||||
- Base LTC backend to setup Torch-MLIR lowering context
|
- Base LTC backend to setup Torch-MLIR lowering context
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
// This file is automatically generated. Please do not edit.
|
// This file is automatically generated. Please do not edit.
|
||||||
// Generated via:
|
// Generated via:
|
||||||
// ```
|
// ```
|
||||||
// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen
|
// python -m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen
|
||||||
// ```
|
// ```
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,44 @@
|
||||||
|
include(AddMLIRPython)
|
||||||
|
|
||||||
|
# Configure PyTorch if we have any features enabled which require it.
|
||||||
|
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
||||||
|
message(STATUS "Enabling PyTorch C++ dep (features depend on it)")
|
||||||
|
include(TorchMLIRPyTorch)
|
||||||
|
|
||||||
|
TorchMLIRProbeForPyTorchInstall()
|
||||||
|
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||||
|
TorchMLIRConfigurePyTorch()
|
||||||
|
else()
|
||||||
|
# Assume it is a sibling to the overall project.
|
||||||
|
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
||||||
|
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_package(Torch 1.11 REQUIRED)
|
||||||
|
|
||||||
|
set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen)
|
||||||
|
|
||||||
|
include_directories(BEFORE
|
||||||
|
${TORCH_INCLUDE_DIRS}
|
||||||
|
${Python3_INCLUDE_DIRS}
|
||||||
|
)
|
||||||
|
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
message(STATUS "libtorch_python CXXFLAGS is ...${TORCH_CXXFLAGS}")
|
||||||
|
message(STATUS "TORCH_LIBRARIES = ${TORCH_LIBRARIES}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Include jit_ir_common if the jit_ir importer or LTC is enabled,
|
||||||
|
# since they both require it.
|
||||||
|
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
||||||
|
add_subdirectory(jit_ir_common)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Include LTC.
|
||||||
|
if(TORCH_MLIR_ENABLE_LTC)
|
||||||
|
add_subdirectory(ltc)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Include overall PT1 project.
|
||||||
|
if(TORCH_MLIR_ENABLE_PROJECT_PT1)
|
||||||
|
add_subdirectory(pt1)
|
||||||
|
endif()
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(csrc/jit_ir_importer)
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Static library with core functionality.
|
||||||
|
# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
|
||||||
|
# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
|
||||||
|
add_library(TorchMLIRJITIRImporter STATIC
|
||||||
|
class_annotator.cpp
|
||||||
|
function_importer.cpp
|
||||||
|
node_importer.cpp
|
||||||
|
ivalue_importer.cpp
|
||||||
|
torch_to_mlir_utils.cpp
|
||||||
|
)
|
||||||
|
message(STATUS "Linking TorchMLIRJITImporter with ${TORCH_LIBRARIES}")
|
||||||
|
target_link_libraries(TorchMLIRJITIRImporter
|
||||||
|
TorchMLIRAggregateCAPI
|
||||||
|
${TORCH_LIBRARIES}
|
||||||
|
)
|
||||||
|
# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
|
||||||
|
target_include_directories(TorchMLIRJITIRImporter PUBLIC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||||
|
)
|
||||||
|
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||||
|
OUTPUT_NAME lib_jit_ir_importer
|
||||||
|
PREFIX ""
|
||||||
|
SUFFIX ".a"
|
||||||
|
CXX_VISIBILITY_PRESET "default"
|
||||||
|
COMPILE_FLAGS "${TORCH_CXXFLAGS}"
|
||||||
|
)
|
|
@ -190,7 +190,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||||
mlirRegionCreate());
|
mlirRegionCreate());
|
||||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
|
mlirRegionAppendOwnedBlock(nnModuleRegion,
|
||||||
|
mlirBlockCreate(0, nullptr, nullptr));
|
||||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||||
InserterGuard inserterGuard(importBlock, nnModule);
|
InserterGuard inserterGuard(importBlock, nnModule);
|
||||||
|
|
||||||
|
@ -491,8 +492,9 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"name", mlirStringAttrGet(
|
"name", mlirStringAttrGet(
|
||||||
context, toMlirStringRef(classAttribute.getName()))),
|
context, toMlirStringRef(classAttribute.getName()))),
|
||||||
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
toMlirNamedAttribute(
|
||||||
loc, classAttribute.getType(), importOptions))),
|
"type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
||||||
|
loc, classAttribute.getType(), importOptions))),
|
||||||
isPrivate);
|
isPrivate);
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,10 +41,9 @@ public:
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions &importOptions = {});
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MlirBlock
|
MlirBlock createBlockFor(Block *jitBlock,
|
||||||
createBlockFor(Block *jitBlock,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
const ImportOptions &importOptions = {});
|
||||||
const ImportOptions &importOptions = {});
|
|
||||||
void mapValue(Value *jitValue, MlirValue value);
|
void mapValue(Value *jitValue, MlirValue value);
|
||||||
void mapResults(Node *node, MlirOperation operation);
|
void mapResults(Node *node, MlirOperation operation);
|
||||||
MlirValue lookupMappedValue(Value *jitValue);
|
MlirValue lookupMappedValue(Value *jitValue);
|
||||||
|
@ -269,9 +268,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
terminatorOperandTypes,
|
terminatorOperandTypes,
|
||||||
/*userAllowsRefinement=*/false));
|
/*userAllowsRefinement=*/false));
|
||||||
};
|
};
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
||||||
mlirOperationGetRegion(operation, 0),
|
importBlock(node->blocks()[0], createTerminator,
|
||||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
c10::nullopt, importOptions));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,12 +289,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
resultTypes,
|
resultTypes,
|
||||||
/*userAllowsRefinement=*/false));
|
/*userAllowsRefinement=*/false));
|
||||||
};
|
};
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
||||||
mlirOperationGetRegion(operation, 0),
|
importBlock(node->blocks()[0], createTerminator,
|
||||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
c10::nullopt, importOptions));
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
|
||||||
mlirOperationGetRegion(operation, 1),
|
importBlock(node->blocks()[1], createTerminator,
|
||||||
importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions));
|
c10::nullopt, importOptions));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,8 +302,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||||
auto methodName = node->s(c10::attr::name);
|
auto methodName = node->s(c10::attr::name);
|
||||||
torch::jit::Function *function = classType->findMethod(methodName);
|
torch::jit::Function *function = classType->findMethod(methodName);
|
||||||
MlirType calleeType =
|
MlirType calleeType = getFunctionTypeFromSchema(
|
||||||
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
|
context, function->getSchema(), importOptions);
|
||||||
std::vector<MlirType> expectedTypes;
|
std::vector<MlirType> expectedTypes;
|
||||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
||||||
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
||||||
|
@ -361,10 +360,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock NodeImporter::importBlock(
|
MlirBlock
|
||||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
NodeImporter::importBlock(Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
const ImportOptions &importOptions) {
|
const ImportOptions &importOptions) {
|
||||||
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
|
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
|
||||||
for (Node *node : jitBlock->nodes()) {
|
for (Node *node : jitBlock->nodes()) {
|
||||||
importNode(node, block, importOptions);
|
importNode(node, block, importOptions);
|
||||||
|
@ -434,5 +433,6 @@ torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||||
const ImportOptions &importOptions) {
|
const ImportOptions &importOptions) {
|
||||||
NodeImporter importer(context);
|
NodeImporter importer(context);
|
||||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions);
|
return importer.importBlock(jitBlock, createTerminator, blockArgTypes,
|
||||||
|
importOptions);
|
||||||
}
|
}
|
|
@ -36,11 +36,11 @@ using CreateTerminatorFn =
|
||||||
/// are required to be for correctness. The code will internally attempt to
|
/// are required to be for correctness. The code will internally attempt to
|
||||||
/// adjust the types to the block argument types.
|
/// adjust the types to the block argument types.
|
||||||
/// TODO: Formalize what type conversions are allowed here.
|
/// TODO: Formalize what type conversions are allowed here.
|
||||||
MlirBlock importBlock(
|
MlirBlock
|
||||||
MlirContext context, torch::jit::Block *jitBlock,
|
importBlock(MlirContext context, torch::jit::Block *jitBlock,
|
||||||
CreateTerminatorFn createTerminator,
|
CreateTerminatorFn createTerminator,
|
||||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||||
const ImportOptions &importOptions = {});
|
const ImportOptions &importOptions = {});
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
add_subdirectory(csrc/base_lazy_backend)
|
|
@ -2,30 +2,6 @@
|
||||||
# Setup PyTorch/LTC
|
# Setup PyTorch/LTC
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
include(TorchMLIRPyTorch)
|
|
||||||
|
|
||||||
TorchMLIRProbeForPyTorchInstall()
|
|
||||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
|
||||||
TorchMLIRConfigurePyTorch()
|
|
||||||
else()
|
|
||||||
# Assume it is a sibling to the overall project.
|
|
||||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
|
||||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_package(Torch 1.11 REQUIRED)
|
|
||||||
|
|
||||||
set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen)
|
|
||||||
|
|
||||||
include_directories(BEFORE
|
|
||||||
${TORCH_INCLUDE_DIRS}
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
|
||||||
${Python3_INCLUDE_DIRS}
|
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python
|
|
||||||
)
|
|
||||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
|
||||||
|
|
||||||
set(LTC_GENERATED
|
set(LTC_GENERATED
|
||||||
generated/LazyNativeFunctions.cpp
|
generated/LazyNativeFunctions.cpp
|
||||||
generated/RegisterLazy.cpp
|
generated/RegisterLazy.cpp
|
||||||
|
@ -80,6 +56,12 @@ add_library(torch_mlir_ltc_backend SHARED
|
||||||
utils/tensor_utils.cpp
|
utils/tensor_utils.cpp
|
||||||
)
|
)
|
||||||
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
||||||
|
# Includes are resolved relative to csrc (i.e. #include "base_lazy_backend/...").
|
||||||
|
# Add both the source and generated include directories.
|
||||||
|
target_include_directories(torch_mlir_ltc_backend PUBLIC
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/..
|
||||||
|
)
|
||||||
|
|
||||||
add_dependencies(torch_mlir_ltc_backend
|
add_dependencies(torch_mlir_ltc_backend
|
||||||
TorchMLIRJITIRImporter
|
TorchMLIRJITIRImporter
|
||||||
|
@ -112,13 +94,13 @@ add_custom_command(
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/)
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/generated/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
@ -129,7 +111,7 @@ add_custom_command(
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/ops/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
@ -140,5 +122,5 @@ add_custom_command(
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||||
COMMAND cp
|
COMMAND cp
|
||||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/*.h
|
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/utils/*.h
|
||||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
|
@ -21,8 +21,8 @@
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
#include "mlir-c/Pass.h"
|
#include "mlir-c/Pass.h"
|
||||||
|
|
||||||
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
|
#include "jit_ir_importer/function_importer.h"
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
#include "mlir_node.h"
|
#include "mlir_node.h"
|
||||||
#include "utils/debug.h"
|
#include "utils/debug.h"
|
|
@ -92,8 +92,8 @@
|
||||||
"import torchvision\n",
|
"import torchvision\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import torch_mlir\n",
|
"import torch_mlir\n",
|
||||||
"from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder\n",
|
"from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder\n",
|
||||||
"from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations\n",
|
"from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from torch_mlir.passmanager import PassManager\n",
|
"from torch_mlir.passmanager import PassManager\n",
|
||||||
"from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend"
|
"from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend"
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
include(AddMLIRPython)
|
|
||||||
|
|
||||||
# Disables generation of "version soname" (i.e. libFoo.so.<version>), which
|
# Disables generation of "version soname" (i.e. libFoo.so.<version>), which
|
||||||
# causes pure duplication as part of Python wheels.
|
# causes pure duplication as part of Python wheels.
|
||||||
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
|
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
|
||||||
|
@ -90,9 +88,6 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
||||||
# Lazy Tensor Core
|
# Lazy Tensor Core
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_LTC)
|
|
||||||
add_subdirectory(torch_mlir/csrc/base_lazy_backend)
|
|
||||||
endif()
|
|
||||||
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it
|
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it
|
||||||
# generates a dummy Python library when disabled.
|
# generates a dummy Python library when disabled.
|
||||||
if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
||||||
|
@ -104,7 +99,8 @@ endif()
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||||
add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir)
|
add_subdirectory(torch_mlir/jit_ir_importer)
|
||||||
|
add_subdirectory(torch_mlir/csrc/jit_ir_importer)
|
||||||
add_subdirectory(torch_mlir_e2e_test)
|
add_subdirectory(torch_mlir_e2e_test)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator
|
from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -17,8 +17,8 @@ from torch_mlir.dynamo import _get_decomposition_table
|
||||||
from torch.fx.experimental.proxy_tensor import make_fx
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
|
|
||||||
from .compiler_utils import run_pipeline_with_repro_report
|
from .compiler_utils import run_pipeline_with_repro_report
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
|
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||||
|
|
||||||
|
|
||||||
class OutputType(Enum):
|
class OutputType(Enum):
|
||||||
|
|
|
@ -1,39 +1,3 @@
|
||||||
# Sharp edge: Torch extensions need to use the same pybind11 that torch
|
|
||||||
# was compiled with, or else there will be issues in cross module exception
|
|
||||||
# handling (which will abort instead of raise). We circumvent the possibility
|
|
||||||
# by forcing the torch directories first.
|
|
||||||
include_directories(BEFORE
|
|
||||||
${TORCH_INCLUDE_DIRS}
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
|
||||||
${Python3_INCLUDE_DIRS}
|
|
||||||
)
|
|
||||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
|
||||||
|
|
||||||
# Static library with core functionality.
|
|
||||||
# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
|
|
||||||
# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
|
|
||||||
add_library(TorchMLIRJITIRImporter STATIC
|
|
||||||
class_annotator.cpp
|
|
||||||
function_importer.cpp
|
|
||||||
node_importer.cpp
|
|
||||||
ivalue_importer.cpp
|
|
||||||
torch_to_mlir_utils.cpp
|
|
||||||
)
|
|
||||||
target_link_libraries(TorchMLIRJITIRImporter
|
|
||||||
TorchMLIRAggregateCAPI
|
|
||||||
${TORCH_LIBRARIES}
|
|
||||||
)
|
|
||||||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}")
|
|
||||||
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
|
|
||||||
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
|
||||||
OUTPUT_NAME lib_jit_ir_importer
|
|
||||||
PREFIX ""
|
|
||||||
SUFFIX ".a"
|
|
||||||
CXX_VISIBILITY_PRESET "default"
|
|
||||||
COMPILE_FLAGS "${TORCH_CXXFLAGS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Separate Pybind MODULE due to issues with a SHARED library.
|
# Separate Pybind MODULE due to issues with a SHARED library.
|
||||||
# https://github.com/llvm/torch-mlir/issues/1154
|
# https://github.com/llvm/torch-mlir/issues/1154
|
||||||
add_library(TorchMLIRJITIRImporterPybind MODULE
|
add_library(TorchMLIRJITIRImporterPybind MODULE
|
||||||
|
@ -62,7 +26,6 @@ if(Python3_LIBRARIES)
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}")
|
|
||||||
set_target_properties(TorchMLIRJITIRImporterPybind PROPERTIES
|
set_target_properties(TorchMLIRJITIRImporterPybind PROPERTIES
|
||||||
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||||
OUTPUT_NAME _jit_ir_importer
|
OUTPUT_NAME _jit_ir_importer
|
|
@ -8,7 +8,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "class_annotator_pybind.h"
|
#include "class_annotator_pybind.h"
|
||||||
#include "class_annotator.h"
|
#include "jit_ir_importer/class_annotator.h"
|
||||||
|
|
||||||
#include <torch/csrc/Dtype.h>
|
#include <torch/csrc/Dtype.h>
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
@ -18,7 +18,7 @@ using namespace torch_mlir;
|
||||||
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
||||||
if (THPDtype_Check(obj.ptr())) {
|
if (THPDtype_Check(obj.ptr())) {
|
||||||
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
||||||
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
THPDtype* dtype = reinterpret_cast<THPDtype*>(obj.ptr());
|
||||||
return dtype->scalar_type;
|
return dtype->scalar_type;
|
||||||
}
|
}
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
|
@ -48,16 +48,17 @@ static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
|
||||||
return argAnnotations;
|
return argAnnotations;
|
||||||
}
|
}
|
||||||
|
|
||||||
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
void torch_mlir::initClassAnnotatorBindings(py::module& m) {
|
||||||
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def("exportPath", &ClassAnnotator::exportPath)
|
.def("exportPath", &ClassAnnotator::exportPath)
|
||||||
.def("exportNone", &ClassAnnotator::exportNone)
|
.def("exportNone", &ClassAnnotator::exportNone)
|
||||||
.def("annotateArgs",
|
.def(
|
||||||
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType,
|
"annotateArgs",
|
||||||
std::vector<std::string> path, py::list argAnnotations) {
|
[&](ClassAnnotator& cls_annotator, c10::ClassType& rootClassType,
|
||||||
cls_annotator.annotateArgs(rootClassType, path,
|
std::vector<std::string> path, py::list argAnnotations) {
|
||||||
getArgAnnotations(argAnnotations));
|
cls_annotator.annotateArgs(
|
||||||
})
|
rootClassType, path, getArgAnnotations(argAnnotations));
|
||||||
|
})
|
||||||
.def("__repr__", &ClassAnnotator::toString);
|
.def("__repr__", &ClassAnnotator::toString);
|
||||||
}
|
}
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
void initClassAnnotatorBindings(py::module &m);
|
void initClassAnnotatorBindings(py::module& m);
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
|
@ -50,9 +50,9 @@ static py::list getRegisteredOps() {
|
||||||
// since the JIT has its own dispatch mechanism that it uses to implement
|
// since the JIT has its own dispatch mechanism that it uses to implement
|
||||||
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
||||||
// as `aten::__is__`.
|
// as `aten::__is__`.
|
||||||
for (const std::shared_ptr<torch::jit::Operator> &op :
|
for (const std::shared_ptr<torch::jit::Operator>& op :
|
||||||
torch::jit::getAllOperators()) {
|
torch::jit::getAllOperators()) {
|
||||||
const c10::FunctionSchema &schema = op->schema();
|
const c10::FunctionSchema& schema = op->schema();
|
||||||
|
|
||||||
py::dict record;
|
py::dict record;
|
||||||
{
|
{
|
||||||
|
@ -69,7 +69,7 @@ static py::list getRegisteredOps() {
|
||||||
|
|
||||||
py::list arguments;
|
py::list arguments;
|
||||||
py::list returns;
|
py::list returns;
|
||||||
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
auto addArgument = [](py::list& container, const c10::Argument& arg) {
|
||||||
py::dict argRecord;
|
py::dict argRecord;
|
||||||
argRecord["name"] = arg.name();
|
argRecord["name"] = arg.name();
|
||||||
argRecord["type"] = arg.type()->str();
|
argRecord["type"] = arg.type()->str();
|
||||||
|
@ -87,10 +87,10 @@ static py::list getRegisteredOps() {
|
||||||
py::dict aliasInfo;
|
py::dict aliasInfo;
|
||||||
py::list before;
|
py::list before;
|
||||||
py::list after;
|
py::list after;
|
||||||
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
for (auto& symbol : arg.alias_info()->beforeSets()) {
|
||||||
before.append(std::string(symbol.toQualString()));
|
before.append(std::string(symbol.toQualString()));
|
||||||
}
|
}
|
||||||
for (auto &symbol : arg.alias_info()->afterSets()) {
|
for (auto& symbol : arg.alias_info()->afterSets()) {
|
||||||
after.append(std::string(symbol.toQualString()));
|
after.append(std::string(symbol.toQualString()));
|
||||||
}
|
}
|
||||||
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
||||||
|
@ -101,10 +101,10 @@ static py::list getRegisteredOps() {
|
||||||
|
|
||||||
container.append(std::move(argRecord));
|
container.append(std::move(argRecord));
|
||||||
};
|
};
|
||||||
for (auto &argument : schema.arguments()) {
|
for (auto& argument : schema.arguments()) {
|
||||||
addArgument(arguments, argument);
|
addArgument(arguments, argument);
|
||||||
}
|
}
|
||||||
for (auto &returnArg : schema.returns()) {
|
for (auto& returnArg : schema.returns()) {
|
||||||
addArgument(returns, returnArg);
|
addArgument(returns, returnArg);
|
||||||
}
|
}
|
||||||
record["arguments"] = std::move(arguments);
|
record["arguments"] = std::move(arguments);
|
||||||
|
@ -115,6 +115,6 @@ static py::list getRegisteredOps() {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
void torch_mlir::initGetRegisteredOpsBindings(py::module &m) {
|
void torch_mlir::initGetRegisteredOpsBindings(py::module& m) {
|
||||||
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
|
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
|
||||||
}
|
}
|
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
void initGetRegisteredOpsBindings(py::module &m);
|
void initGetRegisteredOpsBindings(py::module& m);
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
|
@ -8,17 +8,19 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "import_options_pybind.h"
|
#include "import_options_pybind.h"
|
||||||
#include "import_options.h"
|
#include "jit_ir_importer/import_options.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
void torch_mlir::initImportOptionsBindings(py::module &m) {
|
void torch_mlir::initImportOptionsBindings(py::module& m) {
|
||||||
py::class_<ImportOptions>(m, "ImportOptions")
|
py::class_<ImportOptions>(m, "ImportOptions")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def_readwrite("assumeTensorsHaveValueSemantics",
|
.def_readwrite(
|
||||||
&ImportOptions::assumeTensorsHaveValueSemantics)
|
"assumeTensorsHaveValueSemantics",
|
||||||
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
|
&ImportOptions::assumeTensorsHaveValueSemantics)
|
||||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
.def_readwrite(
|
||||||
|
"ignoreExistingTensorShapesAndDtypes",
|
||||||
|
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
||||||
}
|
}
|
|
@ -13,7 +13,7 @@
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
void initImportOptionsBindings(pybind11::module &m);
|
void initImportOptionsBindings(pybind11::module& m);
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H
|
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H
|
|
@ -9,9 +9,9 @@
|
||||||
|
|
||||||
#include "module_builder.h"
|
#include "module_builder.h"
|
||||||
|
|
||||||
#include "function_importer.h"
|
#include "jit_ir_importer/function_importer.h"
|
||||||
#include "ivalue_importer.h"
|
#include "jit_ir_importer/ivalue_importer.h"
|
||||||
#include "mlir_utils.h"
|
#include "jit_ir_importer/mlir_utils.h"
|
||||||
|
|
||||||
#include "mlir-c/Bindings/Python/Interop.h"
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
|
@ -22,7 +22,7 @@
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
static py::object getMlirIrClass(const char *className) {
|
static py::object getMlirIrClass(const char* className) {
|
||||||
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
|
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ static py::object createPythonContextIfNone(py::object contextObj) {
|
||||||
return contextObj;
|
return contextObj;
|
||||||
}
|
}
|
||||||
|
|
||||||
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
|
static MlirContext castPythonObjectToMlirContext(py::object& contextObj) {
|
||||||
assert(!contextObj.is_none() && "context cannot be None");
|
assert(!contextObj.is_none() && "context cannot be None");
|
||||||
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||||
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
|
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
|
||||||
|
@ -77,15 +77,15 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
|
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
|
||||||
<< ": ";
|
<< ": ";
|
||||||
auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) {
|
auto stringCallback = [](MlirStringRef s, void* stringCallbackUserData) {
|
||||||
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
auto* ssp = static_cast<std::stringstream*>(stringCallbackUserData);
|
||||||
ssp->write(s.data, s.length);
|
ssp->write(s.data, s.length);
|
||||||
};
|
};
|
||||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void*>(&ss));
|
||||||
// Use pybind11's print:
|
// Use pybind11's print:
|
||||||
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||||
py::print(ss.str(),
|
py::print(
|
||||||
py::arg("file") = py::module_::import("sys").attr("stderr"));
|
ss.str(), py::arg("file") = py::module_::import("sys").attr("stderr"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
||||||
|
@ -93,7 +93,7 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
|
||||||
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
||||||
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||||
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
||||||
void *) -> MlirLogicalResult {
|
void*) -> MlirLogicalResult {
|
||||||
printDiagnostic(diagnostic);
|
printDiagnostic(diagnostic);
|
||||||
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
|
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
|
||||||
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
|
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
|
||||||
|
@ -101,7 +101,7 @@ static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||||
return mlirLogicalResultSuccess();
|
return mlirLogicalResultSuccess();
|
||||||
};
|
};
|
||||||
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||||
context, diagnosticHandler, nullptr, [](void *) { return; });
|
context, diagnosticHandler, nullptr, [](void*) { return; });
|
||||||
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
||||||
// of this context.
|
// of this context.
|
||||||
(void)id;
|
(void)id;
|
||||||
|
@ -123,28 +123,28 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr ModuleBuilder::importFunction(
|
||||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
|
torch::jit::StrongFunctionPtr function, py::object maybeImportOptions) {
|
||||||
py::object maybeImportOptions) {
|
|
||||||
ImportOptions importOptions;
|
ImportOptions importOptions;
|
||||||
if (!maybeImportOptions.is_none()) {
|
if (!maybeImportOptions.is_none()) {
|
||||||
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
||||||
}
|
}
|
||||||
MlirBlock block = getBodyBlock();
|
MlirBlock block = getBodyBlock();
|
||||||
MlirOperation terminator = this->terminator;
|
MlirOperation terminator = this->terminator;
|
||||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
|
MlirOperation func = importJitFunctionAsFuncOp(
|
||||||
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
context, function.function_,
|
||||||
|
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
||||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleBuilder::importModule(torch::jit::Module jitModule,
|
void ModuleBuilder::importModule(
|
||||||
py::object maybeClassAnnotator,
|
torch::jit::Module jitModule, py::object maybeClassAnnotator,
|
||||||
py::object maybeImportOptions) {
|
py::object maybeImportOptions) {
|
||||||
ClassAnnotator dummyAnnotator;
|
ClassAnnotator dummyAnnotator;
|
||||||
ClassAnnotator *classAnnotator = &dummyAnnotator;
|
ClassAnnotator* classAnnotator = &dummyAnnotator;
|
||||||
if (!maybeClassAnnotator.is_none()) {
|
if (!maybeClassAnnotator.is_none()) {
|
||||||
classAnnotator = py::cast<ClassAnnotator *>(maybeClassAnnotator);
|
classAnnotator = py::cast<ClassAnnotator*>(maybeClassAnnotator);
|
||||||
}
|
}
|
||||||
ImportOptions importOptions;
|
ImportOptions importOptions;
|
||||||
if (!maybeImportOptions.is_none()) {
|
if (!maybeImportOptions.is_none()) {
|
||||||
|
@ -168,14 +168,15 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule,
|
||||||
// precise `torch.class_type` names.
|
// precise `torch.class_type` names.
|
||||||
//
|
//
|
||||||
// This name is not semantically load-bearing!!!
|
// This name is not semantically load-bearing!!!
|
||||||
auto &name = *jitModule.type()->name();
|
auto& name = *jitModule.type()->name();
|
||||||
auto debugModuleNameAttr = mlirStringAttrGet(
|
auto debugModuleNameAttr = mlirStringAttrGet(
|
||||||
context, toMlirStringRef(name.atoms()[name.atoms().size() - 1]));
|
context, toMlirStringRef(name.atoms()[name.atoms().size() - 1]));
|
||||||
mlirOperationSetAttributeByName(mlirModuleGetOperation(module),
|
mlirOperationSetAttributeByName(
|
||||||
toMlirStringRef("torch.debug_module_name"),
|
mlirModuleGetOperation(module),
|
||||||
debugModuleNameAttr);
|
toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr);
|
||||||
importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
|
importIValue(
|
||||||
mlirModuleGetContext(module), *classAnnotator, importOptions);
|
jitModule._ivalue(), mlirModuleGetBody(module),
|
||||||
|
mlirModuleGetContext(module), *classAnnotator, importOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock ModuleBuilder::getBodyBlock() {
|
MlirBlock ModuleBuilder::getBodyBlock() {
|
||||||
|
@ -183,14 +184,16 @@ MlirBlock ModuleBuilder::getBodyBlock() {
|
||||||
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleBuilder::bind(py::module &m) {
|
void ModuleBuilder::bind(py::module& m) {
|
||||||
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
||||||
.def(py::init<py::object>(), py::arg("context") = py::none())
|
.def(py::init<py::object>(), py::arg("context") = py::none())
|
||||||
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||||
.def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
|
.def(
|
||||||
py::arg("importOptions") = py::none())
|
"import_function", &ModuleBuilder::importFunction,
|
||||||
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
|
py::arg("function"), py::arg("importOptions") = py::none())
|
||||||
py::arg("classAnnotator") = py::none(),
|
.def(
|
||||||
py::arg("importOptions") = py::none());
|
"import_module", &ModuleBuilder::importModule, py::arg("module"),
|
||||||
|
py::arg("classAnnotator") = py::none(),
|
||||||
|
py::arg("importOptions") = py::none());
|
||||||
}
|
}
|
|
@ -10,7 +10,7 @@
|
||||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||||
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||||
|
|
||||||
#include "class_annotator.h"
|
#include "jit_ir_importer/class_annotator.h"
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ public:
|
||||||
ModuleBuilder(pybind11::object contextObj);
|
ModuleBuilder(pybind11::object contextObj);
|
||||||
|
|
||||||
/// Creates Python bindings for the class.
|
/// Creates Python bindings for the class.
|
||||||
static void bind(pybind11::module &m);
|
static void bind(pybind11::module& m);
|
||||||
|
|
||||||
pybind11::object getContextObj() { return contextObj; }
|
pybind11::object getContextObj() { return contextObj; }
|
||||||
pybind11::object getModuleObj() { return moduleObj; }
|
pybind11::object getModuleObj() { return moduleObj; }
|
||||||
|
@ -38,16 +38,15 @@ public:
|
||||||
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
||||||
// Just a bit of naming cruft.
|
// Just a bit of naming cruft.
|
||||||
// Returns the same function, making it suitable as a nested decorator.
|
// Returns the same function, making it suitable as a nested decorator.
|
||||||
torch::jit::StrongFunctionPtr
|
torch::jit::StrongFunctionPtr importFunction(
|
||||||
importFunction(torch::jit::StrongFunctionPtr function,
|
torch::jit::StrongFunctionPtr function, py::object maybeImportOptions);
|
||||||
py::object maybeImportOptions);
|
|
||||||
|
|
||||||
// Imports a torch::jit::Module into the current module, using the
|
// Imports a torch::jit::Module into the current module, using the
|
||||||
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
||||||
// a ClassAnnotator.
|
// a ClassAnnotator.
|
||||||
void importModule(torch::jit::Module jitModule,
|
void importModule(
|
||||||
py::object maybeClassAnnotator,
|
torch::jit::Module jitModule, py::object maybeClassAnnotator,
|
||||||
py::object maybeImportOptions);
|
py::object maybeImportOptions);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MlirBlock getBodyBlock();
|
MlirBlock getBodyBlock();
|
|
@ -1,28 +1,3 @@
|
||||||
###########################################################################
|
|
||||||
# Setup PyTorch
|
|
||||||
###########################################################################
|
|
||||||
|
|
||||||
include(TorchMLIRPyTorch)
|
|
||||||
|
|
||||||
TorchMLIRProbeForPyTorchInstall()
|
|
||||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
|
||||||
TorchMLIRConfigurePyTorch()
|
|
||||||
else()
|
|
||||||
# Assume it is a sibling to the overall project.
|
|
||||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
|
||||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_package(Torch 1.11 REQUIRED)
|
|
||||||
|
|
||||||
###########################################################################
|
|
||||||
# Setup Python development
|
|
||||||
###########################################################################
|
|
||||||
|
|
||||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/externals/llvm-project/mlir/cmake/modules")
|
|
||||||
include(MLIRDetectPythonEnv)
|
|
||||||
mlir_configure_python_dev_packages()
|
|
||||||
|
|
||||||
###########################################################################
|
###########################################################################
|
||||||
# Library definition
|
# Library definition
|
||||||
###########################################################################
|
###########################################################################
|
||||||
|
|
|
@ -14,12 +14,12 @@
|
||||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||||
#include <torch/csrc/lazy/core/shape.h>
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
#include <base_lazy_backend/backend_impl.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
#include <base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/debug.h>
|
#include <base_lazy_backend/utils/debug.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/exception.h>
|
#include <base_lazy_backend/utils/exception.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
#include <base_lazy_backend/utils/string_utils.h>
|
||||||
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,10 @@
|
||||||
#include "torch/csrc/lazy/core/config.h"
|
#include "torch/csrc/lazy/core/config.h"
|
||||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||||
|
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
#include <base_lazy_backend/utils/string_utils.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h>
|
#include <base_lazy_backend/utils/sys_utils.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h>
|
#include <base_lazy_backend/utils/tensor_utils.h>
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
#-------------------------------------------------------------------------------
|
|
||||||
# Setup PyTorch
|
|
||||||
#-------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
include(TorchMLIRPyTorch)
|
|
||||||
|
|
||||||
TorchMLIRProbeForPyTorchInstall()
|
|
||||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
|
||||||
TorchMLIRConfigurePyTorch()
|
|
||||||
else()
|
|
||||||
# Assume it is a sibling to the overall project.
|
|
||||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
|
||||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_package(Torch 1.11 REQUIRED)
|
|
||||||
|
|
||||||
message(STATUS "libtorch_python CXXFLAGS is ...${TORCH_CXXFLAGS}")
|
|
||||||
#-------------------------------------------------------------------------------
|
|
||||||
# Subdirectories
|
|
||||||
#-------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
add_subdirectory(csrc)
|
|
||||||
|
|
||||||
## Declare the sources of the Python module.
|
|
||||||
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
|
||||||
ADD_TO_PARENT TorchMLIRPythonSources
|
|
||||||
SOURCES_GLOB
|
|
||||||
dialects/torch/importer/jit_ir/*.py
|
|
||||||
)
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
# Subdirectories
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
## Declare the sources of the Python module.
|
||||||
|
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
||||||
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonSources
|
||||||
|
SOURCES_GLOB
|
||||||
|
jit_ir_importer/*.py
|
||||||
|
)
|
|
@ -11,8 +11,11 @@ import torch
|
||||||
|
|
||||||
# Our native extension is not self-contained. It references libraries which
|
# Our native extension is not self-contained. It references libraries which
|
||||||
# must come in via the above first.
|
# must come in via the above first.
|
||||||
from ....._mlir_libs._jit_ir_importer import *
|
from .._mlir_libs._jit_ir_importer import *
|
||||||
|
|
||||||
|
# Ensure that the torch dialect has been loaded as it registers passes
|
||||||
|
# and other things the jit_ir_importer needs.
|
||||||
|
from ..dialects import torch as _unused_torch_dialect
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"debug_trace_to_stderr",
|
"debug_trace_to_stderr",
|
|
@ -10,7 +10,7 @@ import codecs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||||
from torch_mlir.passmanager import PassManager
|
from torch_mlir.passmanager import PassManager
|
||||||
|
|
||||||
from .registry import Registry
|
from .registry import Registry
|
|
@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator
|
from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||||
|
|
||||||
# Decorators
|
# Decorators
|
||||||
|
|
|
@ -23,4 +23,4 @@ add_lit_testsuite(check-torch-mlir-pt1 "Running the torch-mlir PT1 regression te
|
||||||
)
|
)
|
||||||
set_target_properties(check-torch-mlir-pt1 PROPERTIES FOLDER "Tests")
|
set_target_properties(check-torch-mlir-pt1 PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
add_lit_testsuites(TORCH_MLIR_PT1 ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
||||||
|
|
|
@ -19,7 +19,7 @@ from lit.llvm.subst import FindTool
|
||||||
# Configuration file for the 'lit' test runner.
|
# Configuration file for the 'lit' test runner.
|
||||||
|
|
||||||
# name: The name of this test suite.
|
# name: The name of this test suite.
|
||||||
config.name = 'TORCH_MLIR'
|
config.name = 'TORCH_MLIR_PT1'
|
||||||
|
|
||||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue