mirror of https://github.com/llvm/torch-mlir
Breakup python pytorch deps (#2582)
This lifts the core of the jit_ir_importer and ltc out of the pt1 project, making them peers to it. As a side-effect of this layering, now the "MLIR bits" (dialects, etc) are not commingled with the various parts of the pt1 project, allowing pt1 and ltc to overlay cleanly onto a more fundamental "just MLIR" Python core. Prior to this, the Python namespace was polluted to the point that this could not happen. That "just MLIR" Python core will be introduced in a followup, which will create the space to upstream the FX and ONNX pure Python importers. This primary non-NFC change to the API is: * `torch_mlir.dialects.torch.importer.jit_ir` -> `torch_mlir.jit_ir_importer`. The rest is source code layering so that we can make the pt1 project optional without losing the other features. Progress on #2546.pull/2585/head
parent
facbe5d96b
commit
5eae0adff1
|
@ -27,15 +27,12 @@ jobs:
|
|||
matrix:
|
||||
os-arch: [ubuntu-x86_64, macos-arm64, windows-x86_64]
|
||||
llvm-build: [in-tree, out-of-tree]
|
||||
torch-binary: [ON, OFF]
|
||||
torch-binary: [ON]
|
||||
torch-version: [nightly, stable]
|
||||
exclude:
|
||||
# Exclude llvm in-tree and pytorch source
|
||||
- llvm-build: in-tree
|
||||
torch-binary: OFF
|
||||
# Exclude llvm out-of-tree and pytorch binary
|
||||
# Exclude llvm out-of-tree and pytorch stable (to save resources)
|
||||
- llvm-build: out-of-tree
|
||||
torch-binary: ON
|
||||
torch-version: stable
|
||||
# Exclude macos-arm64 and llvm out-of-tree altogether
|
||||
- os-arch: macos-arm64
|
||||
llvm-build: out-of-tree
|
||||
|
@ -45,9 +42,6 @@ jobs:
|
|||
llvm-build: out-of-tree
|
||||
- os-arch: windows-x86_64
|
||||
torch-version: stable
|
||||
# For PyTorch stable builds, we don't build PyTorch from source
|
||||
- torch-version: stable
|
||||
torch-binary: OFF
|
||||
include:
|
||||
# Specify OS versions
|
||||
- os-arch: ubuntu-x86_64
|
||||
|
|
|
@ -26,7 +26,7 @@ __pycache__
|
|||
bazel-*
|
||||
|
||||
# Autogenerated files
|
||||
/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated
|
||||
/projects/ltc/csrc/base_lazy_backend/generated
|
||||
|
||||
#Docker builds
|
||||
build_oot/
|
||||
|
|
|
@ -149,10 +149,12 @@ endfunction()
|
|||
# Configure CMake.
|
||||
list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
|
||||
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
|
||||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/build_tools/cmake)
|
||||
|
||||
include(TableGen)
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
include(AddMLIRPython)
|
||||
|
||||
################################################################################
|
||||
# Setup python.
|
||||
|
@ -231,6 +233,4 @@ endif()
|
|||
# Sub-projects
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
if(TORCH_MLIR_ENABLE_PROJECT_PT1)
|
||||
add_subdirectory(projects/pt1)
|
||||
endif()
|
||||
add_subdirectory(projects)
|
||||
|
|
|
@ -29,7 +29,6 @@ if not TORCH_INCLUDE_DIR.is_dir():
|
|||
TORCH_INCLUDE_DIR = TORCH_DIR
|
||||
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
|
||||
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
|
||||
TORCH_MLIR_PT1_DIR = TORCH_MLIR_DIR / "projects" / "pt1"
|
||||
|
||||
def reindent(text, prefix=""):
|
||||
return indent(dedent(text), prefix)
|
||||
|
@ -114,12 +113,12 @@ class GenTorchMlirLTC:
|
|||
self.binary_dir = Path(binary_dir)
|
||||
assert self.binary_dir.is_dir(), f"Binary directory not found: {self.binary_dir}"
|
||||
self.source_yaml = self.binary_dir.joinpath("generated_native_functions.yaml")
|
||||
self.backend_path = TORCH_MLIR_PT1_DIR.joinpath(
|
||||
"python", "torch_mlir", "csrc", "base_lazy_backend"
|
||||
self.backend_path = TORCH_MLIR_DIR.joinpath(
|
||||
"projects", "ltc", "csrc", "base_lazy_backend"
|
||||
)
|
||||
assert self.backend_path.is_dir(), f"Backend path not found: {self.backend_path}"
|
||||
self.generated_path = self.binary_dir.joinpath(
|
||||
"projects", "pt1", "python", "torch_mlir", "csrc", "base_lazy_backend", "generated"
|
||||
"projects", "ltc", "csrc", "base_lazy_backend", "generated"
|
||||
)
|
||||
self.generated_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
@ -415,7 +414,7 @@ class GenTorchMlirLTC:
|
|||
// for ops that dont have a corresponding structured kernel or shape definition
|
||||
|
||||
#include "shape_inference.h"
|
||||
#include "torch_mlir/csrc/base_lazy_backend/utils/exception.h"
|
||||
#include "base_lazy_backend/utils/exception.h"
|
||||
namespace torch {{
|
||||
namespace lazy {{
|
||||
{}
|
||||
|
@ -467,7 +466,7 @@ class GenTorchMlirLTC:
|
|||
node_base="torch::lazy::TorchMlirNode",
|
||||
node_base_hdr=str(self.backend_path.joinpath("mlir_node.h")),
|
||||
tensor_class=self.tensor_class,
|
||||
tensor_class_hdr="torch_mlir/csrc/base_lazy_backend/tensor.h",
|
||||
tensor_class_hdr="base_lazy_backend/tensor.h",
|
||||
create_aten_from_ltc_tensor="CreateFunctionalizedAtenFromLtcTensor",
|
||||
shape_inference_hdr=str(self.generated_path.joinpath("shape_inference.h")),
|
||||
lazy_ir_generator=GenMlirLazyIr,
|
||||
|
|
|
@ -364,9 +364,9 @@ function setup_venv() {
|
|||
function build_out_of_tree() {
|
||||
local torch_from_bin="$1"
|
||||
local python_version="$2"
|
||||
echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version"
|
||||
|
||||
local torch_version="$3"
|
||||
echo ":::: Build out-of-tree Torch from binary: $torch_from_bin with Python: $python_version ($torch_version)"
|
||||
|
||||
local enable_ltc="ON"
|
||||
if [[ "${torch_version}" == "stable" ]]
|
||||
then
|
||||
|
|
|
@ -42,6 +42,6 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
|||
fi
|
||||
|
||||
PYTHONPATH="${pypath}" python \
|
||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.abstract_interp_lib_gen \
|
||||
-m torch_mlir.jit_ir_importer.build_tools.abstract_interp_lib_gen \
|
||||
--pytorch_op_extensions=${ext_module:-""} \
|
||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
||||
|
|
|
@ -43,7 +43,7 @@ fi
|
|||
|
||||
set +u
|
||||
PYTHONPATH="${PYTHONPATH}:${pypath}" python \
|
||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
|
||||
-m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen \
|
||||
--torch_ir_include_dir="${torch_ir_include_dir}" \
|
||||
--pytorch_op_extensions="${ext_module}" \
|
||||
--debug_registry_dump="${torch_ir_include_dir}/JITOperatorRegistryDump.txt"
|
||||
|
|
|
@ -17,7 +17,7 @@ The end-to-end test is important to check the correctness of the other steps.
|
|||
|
||||
### Step 2. Update ods
|
||||
|
||||
Update [torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/main/projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py) with the new op and run [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh) to generate the ods. Running `update_torch_ods.sh` would dump all the operators with schema into `JITOperatorRegistryDump.txt`. It’s convenient to look for ops signatures and operands names in this file.
|
||||
Update [torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/main/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py) with the new op and run [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh) to generate the ods. Running `update_torch_ods.sh` would dump all the operators with schema into `JITOperatorRegistryDump.txt`. It’s convenient to look for ops signatures and operands names in this file.
|
||||
|
||||
### Step 3. Propagate types
|
||||
It’s essential to make sure the new op implements shape and dtype inference. See [abstract_interp_lib](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for information on adding shape and dtype inference.
|
||||
|
|
|
@ -26,7 +26,7 @@ The two main use cases are:
|
|||
## Architecture
|
||||
|
||||
Functions are defined as TorchScript-able Python functions in
|
||||
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py`.
|
||||
`python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py`.
|
||||
The signatures of the functions are systematically derived from Torch JIT
|
||||
operator registry. Most shape functions are expected to reuse the upstream
|
||||
helper functions
|
||||
|
|
|
@ -87,7 +87,7 @@ following order:
|
|||
|
||||
1. Shape of input tensor. Use `-1` for dynamic dimensions
|
||||
2. Dtype of the input tensor
|
||||
3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h#L54-L67). This
|
||||
3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h#L54-L67). This
|
||||
will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the
|
||||
IR to eventually have value semantics.
|
||||
|
||||
|
|
|
@ -55,14 +55,14 @@ factored such that we can handle this with one core import path, which is
|
|||
through the PyTorch
|
||||
"[JIT IR](https://github.com/pytorch/pytorch/blob/78c8a0d75220bdd4955415b5f81509e005af4232/torch/csrc/jit/OVERVIEW.md)",
|
||||
and lives in
|
||||
[torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir).
|
||||
[torch-mlir/python/torch_mlir/jit_ir_importer](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir).
|
||||
The JIT IR is a highly principled IR that faithfully models a Python subset (+
|
||||
tensors, the PyTorch op registry, and a few other things). All the other PyTorch
|
||||
program representations can eventually bottom-out on the JIT IR via some path
|
||||
provided by PyTorch. The `torch` dialect is almost entirely in 1:1
|
||||
correspondence with the JIT IR -- this allows the importer to be extremely small
|
||||
(the core is
|
||||
[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp#L1)).
|
||||
[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp#L1)).
|
||||
|
||||
### Ops
|
||||
|
||||
|
@ -70,7 +70,7 @@ See [TorchOps.td](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83
|
|||
|
||||
The ops in the `torch` dialect are almost entirely generated based on the
|
||||
PyTorch JIT IR operator registry via the script
|
||||
[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)).
|
||||
[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py#L1) (invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)).
|
||||
This script queries the registry and generates MLIR
|
||||
[ODS](https://mlir.llvm.org/docs/OpDefinitions/) in
|
||||
[GeneratedTorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L1). We have a guide for [adding a new op end-to-end](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
|
||||
|
@ -195,7 +195,7 @@ values. When one `torch.jit.script`'s a `torch.nn.Module`, the result is
|
|||
actually an `IValue` that represents the module, with a hierarchy of children
|
||||
`IValue`'s. Strictly speaking, JIT IR `torch::jit::Graph`'s are only used to
|
||||
represent the bodies of methods on the modules. So in addition to importing the
|
||||
JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp#L1).
|
||||
JIT IR, we also need to import the `IValue`'s. This happens inside [ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/jit_ir_importer/csrc/ivalue_importer.cpp#L1).
|
||||
|
||||
Most of the IValue modeling can reuse `torch` dialect ops that already exist
|
||||
otherwise, such as `torch.constant.int` to represent an int in the object graph.
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR.
|
||||
After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation.
|
||||
|
||||
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
|
||||
LTC support is provided through an abstract [`TorchMlirBackendImpl`](../projects/ltc/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR.
|
||||
Implementations based on this abstract class will be able to specify their own compile and execution workflows.
|
||||
Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend).
|
||||
|
||||
|
@ -27,7 +27,7 @@ View examples [here](ltc_examples.md).
|
|||
- The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td),
|
||||
excluding those explicitly blacklisted in the YAML file
|
||||
|
||||
### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated))
|
||||
### Autogen Files ([`projects/ltc/csrc/base_lazy_backend/generated`](../projects/ltc/csrc/base_lazy_backend/generated))
|
||||
Generated files are created in this directory, which is ignored by version control.
|
||||
|
||||
- `LazyIr.h`
|
||||
|
@ -41,7 +41,7 @@ Generated files are created in this directory, which is ignored by version contr
|
|||
- `shape_inference.{cpp,h}`
|
||||
- Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions
|
||||
|
||||
### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend))
|
||||
### Base Backend ([`projects/ltc/csrc/base_lazy_backend`](../projects/ltc/csrc/base_lazy_backend))
|
||||
|
||||
- `backend_impl.{cpp,h}`
|
||||
- Base LTC backend to setup Torch-MLIR lowering context
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
// This file is automatically generated. Please do not edit.
|
||||
// Generated via:
|
||||
// ```
|
||||
// python -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen
|
||||
// python -m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen
|
||||
// ```
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,44 @@
|
|||
include(AddMLIRPython)
|
||||
|
||||
# Configure PyTorch if we have any features enabled which require it.
|
||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
||||
message(STATUS "Enabling PyTorch C++ dep (features depend on it)")
|
||||
include(TorchMLIRPyTorch)
|
||||
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
TorchMLIRConfigurePyTorch()
|
||||
else()
|
||||
# Assume it is a sibling to the overall project.
|
||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
||||
endif()
|
||||
|
||||
find_package(Torch 1.11 REQUIRED)
|
||||
|
||||
set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen)
|
||||
|
||||
include_directories(BEFORE
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${Python3_INCLUDE_DIRS}
|
||||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
message(STATUS "libtorch_python CXXFLAGS is ...${TORCH_CXXFLAGS}")
|
||||
message(STATUS "TORCH_LIBRARIES = ${TORCH_LIBRARIES}")
|
||||
endif()
|
||||
|
||||
# Include jit_ir_common if the jit_ir importer or LTC is enabled,
|
||||
# since they both require it.
|
||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
||||
add_subdirectory(jit_ir_common)
|
||||
endif()
|
||||
|
||||
# Include LTC.
|
||||
if(TORCH_MLIR_ENABLE_LTC)
|
||||
add_subdirectory(ltc)
|
||||
endif()
|
||||
|
||||
# Include overall PT1 project.
|
||||
if(TORCH_MLIR_ENABLE_PROJECT_PT1)
|
||||
add_subdirectory(pt1)
|
||||
endif()
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(csrc/jit_ir_importer)
|
|
@ -0,0 +1,27 @@
|
|||
# Static library with core functionality.
|
||||
# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
|
||||
# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
|
||||
add_library(TorchMLIRJITIRImporter STATIC
|
||||
class_annotator.cpp
|
||||
function_importer.cpp
|
||||
node_importer.cpp
|
||||
ivalue_importer.cpp
|
||||
torch_to_mlir_utils.cpp
|
||||
)
|
||||
message(STATUS "Linking TorchMLIRJITImporter with ${TORCH_LIBRARIES}")
|
||||
target_link_libraries(TorchMLIRJITIRImporter
|
||||
TorchMLIRAggregateCAPI
|
||||
${TORCH_LIBRARIES}
|
||||
)
|
||||
# Includes are relative to the csrc dir (i.e. #include "jit_ir_importer/...")
|
||||
target_include_directories(TorchMLIRJITIRImporter PUBLIC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
)
|
||||
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||
OUTPUT_NAME lib_jit_ir_importer
|
||||
PREFIX ""
|
||||
SUFFIX ".a"
|
||||
CXX_VISIBILITY_PRESET "default"
|
||||
COMPILE_FLAGS "${TORCH_CXXFLAGS}"
|
||||
)
|
|
@ -190,7 +190,8 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
|||
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||
mlirRegionCreate());
|
||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
|
||||
mlirRegionAppendOwnedBlock(nnModuleRegion,
|
||||
mlirBlockCreate(0, nullptr, nullptr));
|
||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||
InserterGuard inserterGuard(importBlock, nnModule);
|
||||
|
||||
|
@ -491,8 +492,9 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
|||
toMlirNamedAttribute(
|
||||
"name", mlirStringAttrGet(
|
||||
context, toMlirStringRef(classAttribute.getName()))),
|
||||
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
||||
loc, classAttribute.getType(), importOptions))),
|
||||
toMlirNamedAttribute(
|
||||
"type", mlirTypeAttrGet(getMlirTypeFromTorchType(
|
||||
loc, classAttribute.getType(), importOptions))),
|
||||
isPrivate);
|
||||
}
|
||||
|
|
@ -41,10 +41,9 @@ public:
|
|||
const ImportOptions &importOptions = {});
|
||||
|
||||
private:
|
||||
MlirBlock
|
||||
createBlockFor(Block *jitBlock,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions = {});
|
||||
MlirBlock createBlockFor(Block *jitBlock,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions = {});
|
||||
void mapValue(Value *jitValue, MlirValue value);
|
||||
void mapResults(Node *node, MlirOperation operation);
|
||||
MlirValue lookupMappedValue(Value *jitValue);
|
||||
|
@ -269,9 +268,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
|||
terminatorOperandTypes,
|
||||
/*userAllowsRefinement=*/false));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator,
|
||||
c10::nullopt, importOptions));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -290,12 +289,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
|||
resultTypes,
|
||||
/*userAllowsRefinement=*/false));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 1),
|
||||
importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions));
|
||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator,
|
||||
c10::nullopt, importOptions));
|
||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
|
||||
importBlock(node->blocks()[1], createTerminator,
|
||||
c10::nullopt, importOptions));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -303,8 +302,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
|||
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||
auto methodName = node->s(c10::attr::name);
|
||||
torch::jit::Function *function = classType->findMethod(methodName);
|
||||
MlirType calleeType =
|
||||
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
|
||||
MlirType calleeType = getFunctionTypeFromSchema(
|
||||
context, function->getSchema(), importOptions);
|
||||
std::vector<MlirType> expectedTypes;
|
||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
||||
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
||||
|
@ -361,10 +360,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
|||
}
|
||||
}
|
||||
|
||||
MlirBlock NodeImporter::importBlock(
|
||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions) {
|
||||
MlirBlock
|
||||
NodeImporter::importBlock(Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions) {
|
||||
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
|
||||
for (Node *node : jitBlock->nodes()) {
|
||||
importNode(node, block, importOptions);
|
||||
|
@ -434,5 +433,6 @@ torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
|||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions) {
|
||||
NodeImporter importer(context);
|
||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions);
|
||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes,
|
||||
importOptions);
|
||||
}
|
|
@ -36,11 +36,11 @@ using CreateTerminatorFn =
|
|||
/// are required to be for correctness. The code will internally attempt to
|
||||
/// adjust the types to the block argument types.
|
||||
/// TODO: Formalize what type conversions are allowed here.
|
||||
MlirBlock importBlock(
|
||||
MlirContext context, torch::jit::Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||
const ImportOptions &importOptions = {});
|
||||
MlirBlock
|
||||
importBlock(MlirContext context, torch::jit::Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(csrc/base_lazy_backend)
|
|
@ -2,30 +2,6 @@
|
|||
# Setup PyTorch/LTC
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
include(TorchMLIRPyTorch)
|
||||
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
TorchMLIRConfigurePyTorch()
|
||||
else()
|
||||
# Assume it is a sibling to the overall project.
|
||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
||||
endif()
|
||||
|
||||
find_package(Torch 1.11 REQUIRED)
|
||||
|
||||
set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen)
|
||||
|
||||
include_directories(BEFORE
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Python3_INCLUDE_DIRS}
|
||||
${PROJECT_SOURCE_DIR}/projects/pt1/python
|
||||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
set(LTC_GENERATED
|
||||
generated/LazyNativeFunctions.cpp
|
||||
generated/RegisterLazy.cpp
|
||||
|
@ -80,6 +56,12 @@ add_library(torch_mlir_ltc_backend SHARED
|
|||
utils/tensor_utils.cpp
|
||||
)
|
||||
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
||||
# Includes are resolved relative to csrc (i.e. #include "base_lazy_backend/...").
|
||||
# Add both the source and generated include directories.
|
||||
target_include_directories(torch_mlir_ltc_backend PUBLIC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
${CMAKE_CURRENT_BINARY_DIR}/..
|
||||
)
|
||||
|
||||
add_dependencies(torch_mlir_ltc_backend
|
||||
TorchMLIRJITIRImporter
|
||||
|
@ -112,13 +94,13 @@ add_custom_command(
|
|||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/*.h
|
||||
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/)
|
||||
|
||||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/generated/*.h
|
||||
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/generated/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
|
||||
|
||||
add_custom_command(
|
||||
|
@ -129,7 +111,7 @@ add_custom_command(
|
|||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/ops/*.h
|
||||
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/ops/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
||||
|
||||
add_custom_command(
|
||||
|
@ -140,5 +122,5 @@ add_custom_command(
|
|||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/projects/pt1/python/torch_mlir/csrc/base_lazy_backend/utils/*.h
|
||||
${PROJECT_SOURCE_DIR}/projects/ltc/csrc/base_lazy_backend/utils/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
|
@ -21,8 +21,8 @@
|
|||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Pass.h"
|
||||
|
||||
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
||||
#include "backend_impl.h"
|
||||
#include "jit_ir_importer/function_importer.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
#include "mlir_node.h"
|
||||
#include "utils/debug.h"
|
|
@ -92,8 +92,8 @@
|
|||
"import torchvision\n",
|
||||
"\n",
|
||||
"import torch_mlir\n",
|
||||
"from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder\n",
|
||||
"from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations\n",
|
||||
"from torch_mlir.jit_ir_importer import ClassAnnotator, ModuleBuilder\n",
|
||||
"from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations\n",
|
||||
"\n",
|
||||
"from torch_mlir.passmanager import PassManager\n",
|
||||
"from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend"
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
include(AddMLIRPython)
|
||||
|
||||
# Disables generation of "version soname" (i.e. libFoo.so.<version>), which
|
||||
# causes pure duplication as part of Python wheels.
|
||||
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
|
||||
|
@ -90,9 +88,6 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
|||
# Lazy Tensor Core
|
||||
################################################################################
|
||||
|
||||
if(TORCH_MLIR_ENABLE_LTC)
|
||||
add_subdirectory(torch_mlir/csrc/base_lazy_backend)
|
||||
endif()
|
||||
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it
|
||||
# generates a dummy Python library when disabled.
|
||||
if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
||||
|
@ -104,7 +99,8 @@ endif()
|
|||
################################################################################
|
||||
|
||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||
add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir)
|
||||
add_subdirectory(torch_mlir/jit_ir_importer)
|
||||
add_subdirectory(torch_mlir/csrc/jit_ir_importer)
|
||||
add_subdirectory(torch_mlir_e2e_test)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -8,8 +8,8 @@
|
|||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
|
||||
from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||
from torch_mlir.jit_ir_importer.torchscript_annotations import extract_annotations
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -17,8 +17,8 @@ from torch_mlir.dynamo import _get_decomposition_table
|
|||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from .compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
|
||||
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||
|
||||
|
||||
class OutputType(Enum):
|
||||
|
|
|
@ -1,39 +1,3 @@
|
|||
# Sharp edge: Torch extensions need to use the same pybind11 that torch
|
||||
# was compiled with, or else there will be issues in cross module exception
|
||||
# handling (which will abort instead of raise). We circumvent the possibility
|
||||
# by forcing the torch directories first.
|
||||
include_directories(BEFORE
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Python3_INCLUDE_DIRS}
|
||||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
# Static library with core functionality.
|
||||
# We can't use a shared library here, due to issues with linking on macOS-arm64 (the library itself won't build)
|
||||
# For details, see: https://github.com/llvm/torch-mlir/runs/7919012376
|
||||
add_library(TorchMLIRJITIRImporter STATIC
|
||||
class_annotator.cpp
|
||||
function_importer.cpp
|
||||
node_importer.cpp
|
||||
ivalue_importer.cpp
|
||||
torch_to_mlir_utils.cpp
|
||||
)
|
||||
target_link_libraries(TorchMLIRJITIRImporter
|
||||
TorchMLIRAggregateCAPI
|
||||
${TORCH_LIBRARIES}
|
||||
)
|
||||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}")
|
||||
set_target_properties(TorchMLIRJITIRImporter PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||
OUTPUT_NAME lib_jit_ir_importer
|
||||
PREFIX ""
|
||||
SUFFIX ".a"
|
||||
CXX_VISIBILITY_PRESET "default"
|
||||
COMPILE_FLAGS "${TORCH_CXXFLAGS}"
|
||||
)
|
||||
|
||||
# Separate Pybind MODULE due to issues with a SHARED library.
|
||||
# https://github.com/llvm/torch-mlir/issues/1154
|
||||
add_library(TorchMLIRJITIRImporterPybind MODULE
|
||||
|
@ -62,7 +26,6 @@ if(Python3_LIBRARIES)
|
|||
)
|
||||
endif()
|
||||
|
||||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}")
|
||||
set_target_properties(TorchMLIRJITIRImporterPybind PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||
OUTPUT_NAME _jit_ir_importer
|
|
@ -8,7 +8,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "class_annotator_pybind.h"
|
||||
#include "class_annotator.h"
|
||||
#include "jit_ir_importer/class_annotator.h"
|
||||
|
||||
#include <torch/csrc/Dtype.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
@ -18,7 +18,7 @@ using namespace torch_mlir;
|
|||
static c10::ScalarType convertToC10ScalarType(py::object obj) {
|
||||
if (THPDtype_Check(obj.ptr())) {
|
||||
// Need reinterpret_cast, since no C++-level inheritance is involved.
|
||||
THPDtype *dtype = reinterpret_cast<THPDtype *>(obj.ptr());
|
||||
THPDtype* dtype = reinterpret_cast<THPDtype*>(obj.ptr());
|
||||
return dtype->scalar_type;
|
||||
}
|
||||
std::stringstream ss;
|
||||
|
@ -48,16 +48,17 @@ static std::vector<ArgAnnotation> getArgAnnotations(py::list pyArgAnnotations) {
|
|||
return argAnnotations;
|
||||
}
|
||||
|
||||
void torch_mlir::initClassAnnotatorBindings(py::module &m) {
|
||||
void torch_mlir::initClassAnnotatorBindings(py::module& m) {
|
||||
py::class_<ClassAnnotator>(m, "ClassAnnotator")
|
||||
.def(py::init<>())
|
||||
.def("exportPath", &ClassAnnotator::exportPath)
|
||||
.def("exportNone", &ClassAnnotator::exportNone)
|
||||
.def("annotateArgs",
|
||||
[&](ClassAnnotator &cls_annotator, c10::ClassType &rootClassType,
|
||||
std::vector<std::string> path, py::list argAnnotations) {
|
||||
cls_annotator.annotateArgs(rootClassType, path,
|
||||
getArgAnnotations(argAnnotations));
|
||||
})
|
||||
.def(
|
||||
"annotateArgs",
|
||||
[&](ClassAnnotator& cls_annotator, c10::ClassType& rootClassType,
|
||||
std::vector<std::string> path, py::list argAnnotations) {
|
||||
cls_annotator.annotateArgs(
|
||||
rootClassType, path, getArgAnnotations(argAnnotations));
|
||||
})
|
||||
.def("__repr__", &ClassAnnotator::toString);
|
||||
}
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
namespace py = pybind11;
|
||||
namespace torch_mlir {
|
||||
void initClassAnnotatorBindings(py::module &m);
|
||||
void initClassAnnotatorBindings(py::module& m);
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_CLASS_ANNOTATOR_PYBIND_H
|
|
@ -50,9 +50,9 @@ static py::list getRegisteredOps() {
|
|||
// since the JIT has its own dispatch mechanism that it uses to implement
|
||||
// "prim" ops and a handful of "aten" ops that are effectively prim ops, such
|
||||
// as `aten::__is__`.
|
||||
for (const std::shared_ptr<torch::jit::Operator> &op :
|
||||
for (const std::shared_ptr<torch::jit::Operator>& op :
|
||||
torch::jit::getAllOperators()) {
|
||||
const c10::FunctionSchema &schema = op->schema();
|
||||
const c10::FunctionSchema& schema = op->schema();
|
||||
|
||||
py::dict record;
|
||||
{
|
||||
|
@ -69,7 +69,7 @@ static py::list getRegisteredOps() {
|
|||
|
||||
py::list arguments;
|
||||
py::list returns;
|
||||
auto addArgument = [](py::list &container, const c10::Argument &arg) {
|
||||
auto addArgument = [](py::list& container, const c10::Argument& arg) {
|
||||
py::dict argRecord;
|
||||
argRecord["name"] = arg.name();
|
||||
argRecord["type"] = arg.type()->str();
|
||||
|
@ -87,10 +87,10 @@ static py::list getRegisteredOps() {
|
|||
py::dict aliasInfo;
|
||||
py::list before;
|
||||
py::list after;
|
||||
for (auto &symbol : arg.alias_info()->beforeSets()) {
|
||||
for (auto& symbol : arg.alias_info()->beforeSets()) {
|
||||
before.append(std::string(symbol.toQualString()));
|
||||
}
|
||||
for (auto &symbol : arg.alias_info()->afterSets()) {
|
||||
for (auto& symbol : arg.alias_info()->afterSets()) {
|
||||
after.append(std::string(symbol.toQualString()));
|
||||
}
|
||||
aliasInfo["is_write"] = arg.alias_info()->isWrite();
|
||||
|
@ -101,10 +101,10 @@ static py::list getRegisteredOps() {
|
|||
|
||||
container.append(std::move(argRecord));
|
||||
};
|
||||
for (auto &argument : schema.arguments()) {
|
||||
for (auto& argument : schema.arguments()) {
|
||||
addArgument(arguments, argument);
|
||||
}
|
||||
for (auto &returnArg : schema.returns()) {
|
||||
for (auto& returnArg : schema.returns()) {
|
||||
addArgument(returns, returnArg);
|
||||
}
|
||||
record["arguments"] = std::move(arguments);
|
||||
|
@ -115,6 +115,6 @@ static py::list getRegisteredOps() {
|
|||
return results;
|
||||
}
|
||||
|
||||
void torch_mlir::initGetRegisteredOpsBindings(py::module &m) {
|
||||
void torch_mlir::initGetRegisteredOpsBindings(py::module& m) {
|
||||
m.def("get_registered_ops", &getRegisteredOps, kGetRegisteredOpsDocstring);
|
||||
}
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
namespace torch_mlir {
|
||||
|
||||
void initGetRegisteredOpsBindings(py::module &m);
|
||||
void initGetRegisteredOpsBindings(py::module& m);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
|
@ -8,17 +8,19 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "import_options_pybind.h"
|
||||
#include "import_options.h"
|
||||
#include "jit_ir_importer/import_options.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
void torch_mlir::initImportOptionsBindings(py::module &m) {
|
||||
void torch_mlir::initImportOptionsBindings(py::module& m) {
|
||||
py::class_<ImportOptions>(m, "ImportOptions")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("assumeTensorsHaveValueSemantics",
|
||||
&ImportOptions::assumeTensorsHaveValueSemantics)
|
||||
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
|
||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
||||
.def_readwrite(
|
||||
"assumeTensorsHaveValueSemantics",
|
||||
&ImportOptions::assumeTensorsHaveValueSemantics)
|
||||
.def_readwrite(
|
||||
"ignoreExistingTensorShapesAndDtypes",
|
||||
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
|
||||
}
|
|
@ -13,7 +13,7 @@
|
|||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
void initImportOptionsBindings(pybind11::module &m);
|
||||
void initImportOptionsBindings(pybind11::module& m);
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_PYBIND_H
|
|
@ -9,9 +9,9 @@
|
|||
|
||||
#include "module_builder.h"
|
||||
|
||||
#include "function_importer.h"
|
||||
#include "ivalue_importer.h"
|
||||
#include "mlir_utils.h"
|
||||
#include "jit_ir_importer/function_importer.h"
|
||||
#include "jit_ir_importer/ivalue_importer.h"
|
||||
#include "jit_ir_importer/mlir_utils.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
|
@ -22,7 +22,7 @@
|
|||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
||||
static py::object getMlirIrClass(const char *className) {
|
||||
static py::object getMlirIrClass(const char* className) {
|
||||
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,7 @@ static py::object createPythonContextIfNone(py::object contextObj) {
|
|||
return contextObj;
|
||||
}
|
||||
|
||||
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
|
||||
static MlirContext castPythonObjectToMlirContext(py::object& contextObj) {
|
||||
assert(!contextObj.is_none() && "context cannot be None");
|
||||
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
|
||||
|
@ -77,15 +77,15 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
|
|||
std::stringstream ss;
|
||||
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
|
||||
<< ": ";
|
||||
auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) {
|
||||
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
||||
auto stringCallback = [](MlirStringRef s, void* stringCallbackUserData) {
|
||||
auto* ssp = static_cast<std::stringstream*>(stringCallbackUserData);
|
||||
ssp->write(s.data, s.length);
|
||||
};
|
||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void*>(&ss));
|
||||
// Use pybind11's print:
|
||||
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||
py::print(ss.str(),
|
||||
py::arg("file") = py::module_::import("sys").attr("stderr"));
|
||||
py::print(
|
||||
ss.str(), py::arg("file") = py::module_::import("sys").attr("stderr"));
|
||||
}
|
||||
|
||||
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
||||
|
@ -93,7 +93,7 @@ static void printDiagnostic(MlirDiagnostic diagnostic) {
|
|||
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
||||
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
||||
void *) -> MlirLogicalResult {
|
||||
void*) -> MlirLogicalResult {
|
||||
printDiagnostic(diagnostic);
|
||||
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
|
||||
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
|
||||
|
@ -101,7 +101,7 @@ static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
|||
return mlirLogicalResultSuccess();
|
||||
};
|
||||
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||
context, diagnosticHandler, nullptr, [](void *) { return; });
|
||||
context, diagnosticHandler, nullptr, [](void*) { return; });
|
||||
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
||||
// of this context.
|
||||
(void)id;
|
||||
|
@ -123,28 +123,28 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
|||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||
}
|
||||
|
||||
torch::jit::StrongFunctionPtr
|
||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
|
||||
py::object maybeImportOptions) {
|
||||
torch::jit::StrongFunctionPtr ModuleBuilder::importFunction(
|
||||
torch::jit::StrongFunctionPtr function, py::object maybeImportOptions) {
|
||||
ImportOptions importOptions;
|
||||
if (!maybeImportOptions.is_none()) {
|
||||
importOptions = py::cast<ImportOptions>(maybeImportOptions);
|
||||
}
|
||||
MlirBlock block = getBodyBlock();
|
||||
MlirOperation terminator = this->terminator;
|
||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
|
||||
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
||||
MlirOperation func = importJitFunctionAsFuncOp(
|
||||
context, function.function_,
|
||||
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
|
||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||
return function;
|
||||
}
|
||||
|
||||
void ModuleBuilder::importModule(torch::jit::Module jitModule,
|
||||
py::object maybeClassAnnotator,
|
||||
py::object maybeImportOptions) {
|
||||
void ModuleBuilder::importModule(
|
||||
torch::jit::Module jitModule, py::object maybeClassAnnotator,
|
||||
py::object maybeImportOptions) {
|
||||
ClassAnnotator dummyAnnotator;
|
||||
ClassAnnotator *classAnnotator = &dummyAnnotator;
|
||||
ClassAnnotator* classAnnotator = &dummyAnnotator;
|
||||
if (!maybeClassAnnotator.is_none()) {
|
||||
classAnnotator = py::cast<ClassAnnotator *>(maybeClassAnnotator);
|
||||
classAnnotator = py::cast<ClassAnnotator*>(maybeClassAnnotator);
|
||||
}
|
||||
ImportOptions importOptions;
|
||||
if (!maybeImportOptions.is_none()) {
|
||||
|
@ -168,14 +168,15 @@ void ModuleBuilder::importModule(torch::jit::Module jitModule,
|
|||
// precise `torch.class_type` names.
|
||||
//
|
||||
// This name is not semantically load-bearing!!!
|
||||
auto &name = *jitModule.type()->name();
|
||||
auto& name = *jitModule.type()->name();
|
||||
auto debugModuleNameAttr = mlirStringAttrGet(
|
||||
context, toMlirStringRef(name.atoms()[name.atoms().size() - 1]));
|
||||
mlirOperationSetAttributeByName(mlirModuleGetOperation(module),
|
||||
toMlirStringRef("torch.debug_module_name"),
|
||||
debugModuleNameAttr);
|
||||
importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
|
||||
mlirModuleGetContext(module), *classAnnotator, importOptions);
|
||||
mlirOperationSetAttributeByName(
|
||||
mlirModuleGetOperation(module),
|
||||
toMlirStringRef("torch.debug_module_name"), debugModuleNameAttr);
|
||||
importIValue(
|
||||
jitModule._ivalue(), mlirModuleGetBody(module),
|
||||
mlirModuleGetContext(module), *classAnnotator, importOptions);
|
||||
}
|
||||
|
||||
MlirBlock ModuleBuilder::getBodyBlock() {
|
||||
|
@ -183,14 +184,16 @@ MlirBlock ModuleBuilder::getBodyBlock() {
|
|||
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
||||
}
|
||||
|
||||
void ModuleBuilder::bind(py::module &m) {
|
||||
void ModuleBuilder::bind(py::module& m) {
|
||||
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
||||
.def(py::init<py::object>(), py::arg("context") = py::none())
|
||||
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||
.def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
|
||||
py::arg("importOptions") = py::none())
|
||||
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
|
||||
py::arg("classAnnotator") = py::none(),
|
||||
py::arg("importOptions") = py::none());
|
||||
.def(
|
||||
"import_function", &ModuleBuilder::importFunction,
|
||||
py::arg("function"), py::arg("importOptions") = py::none())
|
||||
.def(
|
||||
"import_module", &ModuleBuilder::importModule, py::arg("module"),
|
||||
py::arg("classAnnotator") = py::none(),
|
||||
py::arg("importOptions") = py::none());
|
||||
}
|
|
@ -10,7 +10,7 @@
|
|||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_BUILDER_H
|
||||
|
||||
#include "class_annotator.h"
|
||||
#include "jit_ir_importer/class_annotator.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
|
@ -29,7 +29,7 @@ public:
|
|||
ModuleBuilder(pybind11::object contextObj);
|
||||
|
||||
/// Creates Python bindings for the class.
|
||||
static void bind(pybind11::module &m);
|
||||
static void bind(pybind11::module& m);
|
||||
|
||||
pybind11::object getContextObj() { return contextObj; }
|
||||
pybind11::object getModuleObj() { return moduleObj; }
|
||||
|
@ -38,16 +38,15 @@ public:
|
|||
// torch.jit.ScriptFunction is the C++ type torch::jit::StrongFunctionPtr.
|
||||
// Just a bit of naming cruft.
|
||||
// Returns the same function, making it suitable as a nested decorator.
|
||||
torch::jit::StrongFunctionPtr
|
||||
importFunction(torch::jit::StrongFunctionPtr function,
|
||||
py::object maybeImportOptions);
|
||||
torch::jit::StrongFunctionPtr importFunction(
|
||||
torch::jit::StrongFunctionPtr function, py::object maybeImportOptions);
|
||||
|
||||
// Imports a torch::jit::Module into the current module, using the
|
||||
// annotations, if not none, provided in `maybeClassAnnotator` which should be
|
||||
// a ClassAnnotator.
|
||||
void importModule(torch::jit::Module jitModule,
|
||||
py::object maybeClassAnnotator,
|
||||
py::object maybeImportOptions);
|
||||
void importModule(
|
||||
torch::jit::Module jitModule, py::object maybeClassAnnotator,
|
||||
py::object maybeImportOptions);
|
||||
|
||||
private:
|
||||
MlirBlock getBodyBlock();
|
|
@ -1,28 +1,3 @@
|
|||
###########################################################################
|
||||
# Setup PyTorch
|
||||
###########################################################################
|
||||
|
||||
include(TorchMLIRPyTorch)
|
||||
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
TorchMLIRConfigurePyTorch()
|
||||
else()
|
||||
# Assume it is a sibling to the overall project.
|
||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
||||
endif()
|
||||
|
||||
find_package(Torch 1.11 REQUIRED)
|
||||
|
||||
###########################################################################
|
||||
# Setup Python development
|
||||
###########################################################################
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/externals/llvm-project/mlir/cmake/modules")
|
||||
include(MLIRDetectPythonEnv)
|
||||
mlir_configure_python_dev_packages()
|
||||
|
||||
###########################################################################
|
||||
# Library definition
|
||||
###########################################################################
|
||||
|
|
|
@ -14,12 +14,12 @@
|
|||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||
#include <torch/csrc/lazy/core/shape.h>
|
||||
|
||||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/debug.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/exception.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
||||
#include <base_lazy_backend/backend_impl.h>
|
||||
#include <base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <base_lazy_backend/utils/debug.h>
|
||||
#include <base_lazy_backend/utils/exception.h>
|
||||
#include <base_lazy_backend/utils/string_utils.h>
|
||||
|
||||
#include "backend_impl.h"
|
||||
|
||||
|
|
|
@ -11,10 +11,10 @@
|
|||
#include "torch/csrc/lazy/core/config.h"
|
||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||
|
||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h>
|
||||
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <base_lazy_backend/utils/string_utils.h>
|
||||
#include <base_lazy_backend/utils/sys_utils.h>
|
||||
#include <base_lazy_backend/utils/tensor_utils.h>
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
#-------------------------------------------------------------------------------
|
||||
# Setup PyTorch
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
include(TorchMLIRPyTorch)
|
||||
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||
TorchMLIRConfigurePyTorch()
|
||||
else()
|
||||
# Assume it is a sibling to the overall project.
|
||||
set(Torch_DIR "${PROJECT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
||||
message(STATUS "Attempting to locate libtorch as a sibling to the project: ${Torch_DIR}")
|
||||
endif()
|
||||
|
||||
find_package(Torch 1.11 REQUIRED)
|
||||
|
||||
message(STATUS "libtorch_python CXXFLAGS is ...${TORCH_CXXFLAGS}")
|
||||
#-------------------------------------------------------------------------------
|
||||
# Subdirectories
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
add_subdirectory(csrc)
|
||||
|
||||
## Declare the sources of the Python module.
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES_GLOB
|
||||
dialects/torch/importer/jit_ir/*.py
|
||||
)
|
|
@ -0,0 +1,12 @@
|
|||
#-------------------------------------------------------------------------------
|
||||
# Subdirectories
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
## Declare the sources of the Python module.
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES_GLOB
|
||||
jit_ir_importer/*.py
|
||||
)
|
|
@ -11,8 +11,11 @@ import torch
|
|||
|
||||
# Our native extension is not self-contained. It references libraries which
|
||||
# must come in via the above first.
|
||||
from ....._mlir_libs._jit_ir_importer import *
|
||||
from .._mlir_libs._jit_ir_importer import *
|
||||
|
||||
# Ensure that the torch dialect has been loaded as it registers passes
|
||||
# and other things the jit_ir_importer needs.
|
||||
from ..dialects import torch as _unused_torch_dialect
|
||||
|
||||
__all__ = [
|
||||
"debug_trace_to_stderr",
|
|
@ -10,7 +10,7 @@ import codecs
|
|||
|
||||
import torch
|
||||
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
from torch_mlir.jit_ir_importer import ModuleBuilder
|
||||
from torch_mlir.passmanager import PassManager
|
||||
|
||||
from .registry import Registry
|
|
@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
|
|||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator
|
||||
from torch_mlir.jit_ir_importer import ClassAnnotator
|
||||
|
||||
# Decorators
|
||||
|
|
@ -23,4 +23,4 @@ add_lit_testsuite(check-torch-mlir-pt1 "Running the torch-mlir PT1 regression te
|
|||
)
|
||||
set_target_properties(check-torch-mlir-pt1 PROPERTIES FOLDER "Tests")
|
||||
|
||||
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
||||
add_lit_testsuites(TORCH_MLIR_PT1 ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
||||
|
|
|
@ -19,7 +19,7 @@ from lit.llvm.subst import FindTool
|
|||
# Configuration file for the 'lit' test runner.
|
||||
|
||||
# name: The name of this test suite.
|
||||
config.name = 'TORCH_MLIR'
|
||||
config.name = 'TORCH_MLIR_PT1'
|
||||
|
||||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue