mirror of https://github.com/llvm/torch-mlir
Reference Lazy Backend (#1045)
* Changed Example MLIR backend to Reference MLIR backend * Moved reference_ltc_backend into csrc * Merged sys_utils.h * Renamed reference_ltc_backend to reference_lazy_backend * Addressed review comments * Update docs with new library name * Removed _REFERENCE_LAZY_BACKEND from .gitignore * Added reference_lazy_backend to the TorchMLIRPythonModules dependency list Fixed typo in `ltc_examples.md` Missed instance where `ltc_backend` was used instead of `lazy_backend`.pull/1125/head
parent
f5acad8512
commit
47bb38d180
|
@ -26,6 +26,3 @@ bazel-*
|
||||||
|
|
||||||
# Autogenerated files
|
# Autogenerated files
|
||||||
/python/torch_mlir/csrc/base_lazy_backend/generated
|
/python/torch_mlir/csrc/base_lazy_backend/generated
|
||||||
|
|
||||||
# Example backend
|
|
||||||
examples/ltc_backend/ltc_backend/_EXAMPLE_MLIR_BACKEND.cpython-37m-x86_64-linux-gnu.so
|
|
||||||
|
|
|
@ -192,4 +192,3 @@ else()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
add_subdirectory(examples)
|
|
||||||
|
|
|
@ -377,7 +377,7 @@ class GenTorchMlirLTC:
|
||||||
// for ops that dont have a corresponding structured kernel or shape definition
|
// for ops that dont have a corresponding structured kernel or shape definition
|
||||||
|
|
||||||
#include "shape_inference.h"
|
#include "shape_inference.h"
|
||||||
#include "../../utils/exception.h"
|
#include "../utils/exception.h"
|
||||||
namespace torch {{
|
namespace torch {{
|
||||||
namespace lazy {{
|
namespace lazy {{
|
||||||
{}
|
{}
|
||||||
|
|
|
@ -60,12 +60,15 @@ Generated files are created in this directory, which is ignored by version contr
|
||||||
- `shape_inference.cpp`
|
- `shape_inference.cpp`
|
||||||
- Implementation of select shape inference functions (most functions are [implemented upstream](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/shape_inference.cpp))
|
- Implementation of select shape inference functions (most functions are [implemented upstream](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/shape_inference.cpp))
|
||||||
|
|
||||||
|
### Reference Backend ([`python/torch_mlir/csrc/reference_lazy_backend`](../python/torch_mlir/csrc/reference_lazy_backend))
|
||||||
|
|
||||||
|
- `backend_impl.{cpp,h}`
|
||||||
|
- Reference Torch-MLIR LTC backend implementation, which simply stores the MLIR as a string and executes computation on CPU
|
||||||
|
- `reference_lazy_backend_pybind.cpp`
|
||||||
|
- pybind for reference Torch-MLIR LTC backend
|
||||||
|
|
||||||
### Examples ([`examples`](../examples))
|
### Examples ([`examples`](../examples))
|
||||||
|
|
||||||
- `examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.{cpp,h}`
|
|
||||||
- Example Torch-MLIR LTC backend implementation, which simply stores the MLIR as a string and executes computation on CPU
|
|
||||||
- `examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp`
|
|
||||||
- pybind for example Torch-MLIR LTC backend
|
|
||||||
- `ltc_backend_bert.py`
|
- `ltc_backend_bert.py`
|
||||||
- Example HuggingFace BERT model traced by LTC to MLIR
|
- Example HuggingFace BERT model traced by LTC to MLIR
|
||||||
- `ltc_backend_mnist.py`
|
- `ltc_backend_mnist.py`
|
||||||
|
@ -77,7 +80,7 @@ Generated files are created in this directory, which is ignored by version contr
|
||||||
|
|
||||||
The journey begins with a tensor in PyTorch on the `lazy` device, which may undergo a number of operations during its lifetime.
|
The journey begins with a tensor in PyTorch on the `lazy` device, which may undergo a number of operations during its lifetime.
|
||||||
```python
|
```python
|
||||||
>>> ltc_backend._initialize()
|
>>> lazy_backend._initialize()
|
||||||
>>> x = torch.tensor(..., device='lazy')
|
>>> x = torch.tensor(..., device='lazy')
|
||||||
>>> y = torch.tanh(x)
|
>>> y = torch.tanh(x)
|
||||||
...
|
...
|
||||||
|
@ -116,17 +119,17 @@ Finally, the compiled computation is sent to `TorchMlirBackendImpl::ExecuteCompu
|
||||||
|
|
||||||
## Implementing a custom backend
|
## Implementing a custom backend
|
||||||
|
|
||||||
An example implementation of a custom backend is available [here](../examples/ltc_backend/ltc_backend).
|
A reference implementation of a custom backend is available [here](../python/torch_mlir/csrc/reference_lazy_backend/).
|
||||||
All the work involved with generating MLIR is handled in the base LTC backend, so vendors only need to worry about implementing `Compile`, `ExecuteComputation`, and some other minor methods to interface with the device.
|
All the work involved with generating MLIR is handled in the base LTC backend, so vendors only need to worry about implementing `Compile`, `ExecuteComputation`, and some other minor methods to interface with the device.
|
||||||
|
|
||||||
A pybind is needed to invoke C++ code to register the autogen PyTorch kernels and the custom backend itself.
|
A pybind is needed to invoke C++ code to register the autogen PyTorch kernels and the custom backend itself.
|
||||||
Most of the code in the example implementation should be reusable, excluding some debug related function (e.g. `get_latest_computation`).
|
Most of the code in the reference implementation should be reusable, excluding some debug related function (e.g. `get_latest_computation`).
|
||||||
|
|
||||||
## Future Expansion
|
## Future Expansion
|
||||||
|
|
||||||
There are a number of areas for future improvement:
|
There are a number of areas for future improvement:
|
||||||
- Generate source information in `jit::Graph` so it can be embedded in the MLIR
|
- Generate source information in `jit::Graph` so it can be embedded in the MLIR
|
||||||
- Currently the example backend implementation executes via the `jit::Graph` instead of the MLIR since we currently lack lowerings for many ops, which would make it difficult to run models such as HF BERT
|
- Currently the reference backend implementation executes via the `jit::Graph` instead of the MLIR since we currently lack lowerings for many ops, which would make it difficult to run models such as HF BERT
|
||||||
- In the future, we should change the implementation to lower the MLIR to linalg and execute on a reference backend
|
- In the future, we should change the implementation to lower the MLIR to linalg and execute on a reference backend
|
||||||
- As new models get tested, we will inevitably run into errors related to unimplemented shape inference functions.
|
- As new models get tested, we will inevitably run into errors related to unimplemented shape inference functions.
|
||||||
This problem is simply solved by implementing the missing function, or adding a structured kernel to PyTorch.
|
This problem is simply solved by implementing the missing function, or adding a structured kernel to PyTorch.
|
||||||
|
|
|
@ -6,10 +6,10 @@ Refer to the main documentation [here](ltc_backend.md).
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
import torch._lazy
|
import torch._lazy
|
||||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
import torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND as lazy_backend
|
||||||
|
|
||||||
# Register the example LTC backend.
|
# Register the example LTC backend.
|
||||||
ltc_backend._initialize()
|
lazy_backend._initialize()
|
||||||
|
|
||||||
device = 'lazy'
|
device = 'lazy'
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ torch._lazy.mark_step()
|
||||||
print('Results:', outputs)
|
print('Results:', outputs)
|
||||||
|
|
||||||
# Optionally dump MLIR graph generated from LTC trace.
|
# Optionally dump MLIR graph generated from LTC trace.
|
||||||
computation = ltc_backend.get_latest_computation()
|
computation = lazy_backend.get_latest_computation()
|
||||||
if computation:
|
if computation:
|
||||||
print(computation.debug_string())
|
print(computation.debug_string())
|
||||||
```
|
```
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
add_subdirectory(ltc_backend)
|
|
|
@ -1,26 +0,0 @@
|
||||||
//===- sys_utils.h --------------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace sys_util {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T GetEnv(const std::string &name, const T &default_value = T(0)) {
|
|
||||||
const char *env = std::getenv(name.c_str());
|
|
||||||
if (!env) {
|
|
||||||
return default_value;
|
|
||||||
}
|
|
||||||
return T(std::atoi(env));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace sys_util
|
|
|
@ -113,8 +113,8 @@ def main(device='lazy', full_size=False):
|
||||||
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
|
losses = train(model, num_epochs, num_training_steps, train_dataloader, device)
|
||||||
|
|
||||||
# Get debug information from LTC
|
# Get debug information from LTC
|
||||||
if 'ltc_backend' in sys.modules:
|
if 'torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND' in sys.modules:
|
||||||
computation = ltc_backend.get_latest_computation()
|
computation = lazy_backend.get_latest_computation()
|
||||||
if computation:
|
if computation:
|
||||||
print(computation.debug_string())
|
print(computation.debug_string())
|
||||||
|
|
||||||
|
@ -148,9 +148,9 @@ if __name__ == "__main__":
|
||||||
torch._lazy.ts_backend.init()
|
torch._lazy.ts_backend.init()
|
||||||
|
|
||||||
elif args.device == "MLIR_EXAMPLE":
|
elif args.device == "MLIR_EXAMPLE":
|
||||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
import torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND as lazy_backend
|
||||||
|
|
||||||
ltc_backend._initialize()
|
lazy_backend._initialize()
|
||||||
|
|
||||||
device = "lazy"
|
device = "lazy"
|
||||||
print("Initialized backend")
|
print("Initialized backend")
|
||||||
|
|
|
@ -65,8 +65,8 @@ def main(device='lazy'):
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
# Get debug information from LTC
|
# Get debug information from LTC
|
||||||
if 'ltc_backend' in sys.modules:
|
if 'torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND' in sys.modules:
|
||||||
computation = ltc_backend.get_latest_computation()
|
computation = lazy_backend.get_latest_computation()
|
||||||
if computation:
|
if computation:
|
||||||
print(computation.debug_string())
|
print(computation.debug_string())
|
||||||
|
|
||||||
|
@ -93,9 +93,9 @@ if __name__ == "__main__":
|
||||||
torch._lazy.ts_backend.init()
|
torch._lazy.ts_backend.init()
|
||||||
|
|
||||||
elif args.device == "MLIR_EXAMPLE":
|
elif args.device == "MLIR_EXAMPLE":
|
||||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
import torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND as lazy_backend
|
||||||
|
|
||||||
ltc_backend._initialize()
|
lazy_backend._initialize()
|
||||||
|
|
||||||
device = "lazy"
|
device = "lazy"
|
||||||
print("Initialized backend")
|
print("Initialized backend")
|
||||||
|
|
|
@ -60,7 +60,8 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
||||||
# Lazy Tensor Core
|
# Lazy Tensor Core
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
add_subdirectory(torch_mlir/csrc)
|
add_subdirectory(torch_mlir/csrc/base_lazy_backend)
|
||||||
|
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Optionally handle JIT IR importer.
|
# Optionally handle JIT IR importer.
|
||||||
|
@ -155,6 +156,6 @@ endif()
|
||||||
|
|
||||||
# Add Torch-MLIR LTC backend as dependency
|
# Add Torch-MLIR LTC backend as dependency
|
||||||
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
|
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
|
||||||
|
add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
|
||||||
|
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
|
|
||||||
|
|
|
@ -20,15 +20,15 @@ include_directories(BEFORE
|
||||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
set(LTC_GENERATED
|
set(LTC_GENERATED
|
||||||
base_lazy_backend/generated/LazyNativeFunctions.cpp
|
generated/LazyNativeFunctions.cpp
|
||||||
base_lazy_backend/generated/RegisterLazy.cpp
|
generated/RegisterLazy.cpp
|
||||||
base_lazy_backend/generated/shape_inference.cpp
|
generated/shape_inference.cpp
|
||||||
)
|
)
|
||||||
set(LTC_BACKEND_DEPENDS
|
set(LTC_BACKEND_DEPENDS
|
||||||
base_lazy_backend/mlir_lowering_context.cpp
|
mlir_lowering_context.cpp
|
||||||
base_lazy_backend/mlir_native_functions.cpp
|
mlir_native_functions.cpp
|
||||||
base_lazy_backend/mlir_node_lowering.cpp
|
mlir_node_lowering.cpp
|
||||||
base_lazy_backend/shape_inference.cpp
|
shape_inference.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate Lazy IR Nodes
|
# Generate Lazy IR Nodes
|
||||||
|
@ -57,10 +57,10 @@ add_custom_target(
|
||||||
add_library(torch_mlir_ltc_backend SHARED
|
add_library(torch_mlir_ltc_backend SHARED
|
||||||
${LTC_GENERATED}
|
${LTC_GENERATED}
|
||||||
${LTC_BACKEND_DEPENDS}
|
${LTC_BACKEND_DEPENDS}
|
||||||
base_lazy_backend/backend_impl.cpp
|
backend_impl.cpp
|
||||||
base_lazy_backend/mlir_node.cpp
|
mlir_node.cpp
|
||||||
base_lazy_backend/ops/device_data.cpp
|
ops/device_data.cpp
|
||||||
base_lazy_backend/ops/generic.cpp
|
ops/generic.cpp
|
||||||
)
|
)
|
||||||
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
||||||
|
|
|
@ -15,12 +15,12 @@
|
||||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
#include <torch/csrc/lazy/core/shape.h>
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
#include "../utils/debug.h"
|
|
||||||
#include "../utils/exception.h"
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
#include "ir_builder.h"
|
#include "ir_builder.h"
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
#include "ops/device_data.h"
|
#include "ops/device_data.h"
|
||||||
|
#include "utils/debug.h"
|
||||||
|
#include "utils/exception.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
#include "mlir_node.h"
|
#include "mlir_node.h"
|
||||||
#include "ops/device_data.h"
|
#include "ops/device_data.h"
|
||||||
#include "ops/generic.h"
|
#include "ops/generic.h"
|
||||||
#include "../utils/exception.h"
|
#include "utils/exception.h"
|
||||||
|
|
||||||
// This file contains the TorchMlir IrBuilder
|
// This file contains the TorchMlir IrBuilder
|
||||||
|
|
||||||
|
|
|
@ -17,13 +17,13 @@
|
||||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||||
|
|
||||||
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
||||||
#include "../utils/debug.h"
|
|
||||||
#include "../utils/exception.h"
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
#include "mlir-c/Registration.h"
|
#include "mlir-c/Registration.h"
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
#include "mlir_node.h"
|
#include "mlir_node.h"
|
||||||
#include "torch-mlir-c/Registration.h"
|
#include "torch-mlir-c/Registration.h"
|
||||||
|
#include "utils/debug.h"
|
||||||
|
#include "utils/exception.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
|
@ -10,11 +10,11 @@
|
||||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include <ATen/InferSize.h>
|
|
||||||
#include <ATen/Operators.h>
|
|
||||||
#include <ATen/FunctionalTensorWrapper.h>
|
#include <ATen/FunctionalTensorWrapper.h>
|
||||||
|
#include <ATen/InferSize.h>
|
||||||
#include <ATen/MetaFunctions.h>
|
#include <ATen/MetaFunctions.h>
|
||||||
#include <ATen/NativeFunctions.h>
|
#include <ATen/NativeFunctions.h>
|
||||||
|
#include <ATen/Operators.h>
|
||||||
#include <ATen/native/BinaryOps.h>
|
#include <ATen/native/BinaryOps.h>
|
||||||
#include <ATen/native/CPUFallback.h>
|
#include <ATen/native/CPUFallback.h>
|
||||||
#include <ATen/ops/empty.h>
|
#include <ATen/ops/empty.h>
|
||||||
|
@ -28,12 +28,11 @@
|
||||||
#include <torch/csrc/lazy/core/tensor_util.h>
|
#include <torch/csrc/lazy/core/tensor_util.h>
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
|
||||||
#include "../utils/exception.h"
|
|
||||||
#include "../utils/sys_utils.h"
|
|
||||||
#include "generated/shape_inference.h"
|
|
||||||
#include "generated/LazyNativeFunctions.h"
|
#include "generated/LazyNativeFunctions.h"
|
||||||
|
#include "generated/shape_inference.h"
|
||||||
#include "ops/to_copy.h"
|
#include "ops/to_copy.h"
|
||||||
|
#include "utils/exception.h"
|
||||||
|
#include "utils/sys_utils.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
@ -174,7 +173,6 @@ at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) {
|
||||||
// return result;
|
// return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// clone is special in LT because we make it a no-op.
|
// clone is special in LT because we make it a no-op.
|
||||||
// This should be safe to do, because every operator in the LT is functional.
|
// This should be safe to do, because every operator in the LT is functional.
|
||||||
at::Tensor LazyNativeFunctions::clone(
|
at::Tensor LazyNativeFunctions::clone(
|
||||||
|
@ -290,12 +288,16 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
auto lazy_self = torch::lazy::TryGetLtcTensor(self);
|
auto lazy_self = torch::lazy::TryGetLtcTensor(self);
|
||||||
if (!lazy_self && device && device->type() == c10::kLazy) {
|
if (!lazy_self && device && device->type() == c10::kLazy) {
|
||||||
// Case 1: eager->lazy (we create a new lazy tensor)
|
// Case 1: eager->lazy (we create a new lazy tensor)
|
||||||
// See Note [Lazy Tensor Functionalization]
|
// See Note [Lazy Tensor Functionalization]
|
||||||
// Invariant: if the functionalization key is in the exclude set, then we're expected
|
// Invariant: if the functionalization key is in the exclude set, then we're expected
|
||||||
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
|
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
|
||||||
bool functionalize_output = !c10::impl::tls_local_dispatch_key_set().excluded_.has(c10::DispatchKey::Functionalize);
|
bool functionalize_output =
|
||||||
return torch::lazy::to_lazy_tensor(self, options, *device, /*non_blocking=*/non_blocking, /*functionalize_output=*/functionalize_output);
|
!c10::impl::tls_local_dispatch_key_set().excluded_.has(
|
||||||
|
c10::DispatchKey::Functionalize);
|
||||||
|
return torch::lazy::to_lazy_tensor(
|
||||||
|
self, options, *device, /*non_blocking=*/non_blocking,
|
||||||
|
/*functionalize_output=*/functionalize_output);
|
||||||
} else if (device && device->type() != c10::kLazy) {
|
} else if (device && device->type() != c10::kLazy) {
|
||||||
// Case 2: lazy->eager (forces a graph break since we are materializing a tensor)
|
// Case 2: lazy->eager (forces a graph break since we are materializing a tensor)
|
||||||
|
|
||||||
|
@ -368,7 +370,8 @@ at::Tensor LazyNativeFunctions::empty(
|
||||||
auto x_result = at::empty(size, options, memory_format);
|
auto x_result = at::empty(size, options, memory_format);
|
||||||
auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
|
auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
|
||||||
// See Note [Lazy Tensor Functionalization]
|
// See Note [Lazy Tensor Functionalization]
|
||||||
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(c10::DispatchKey::Functionalize)) {
|
if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
|
||||||
|
c10::DispatchKey::Functionalize)) {
|
||||||
// Invariant: if the functionalization key is in the exclude set, then we're expected
|
// Invariant: if the functionalization key is in the exclude set, then we're expected
|
||||||
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
|
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
|
||||||
return tensor;
|
return tensor;
|
||||||
|
@ -409,7 +412,8 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
|
||||||
// LazyTensor always opts into functionalization.
|
// LazyTensor always opts into functionalization.
|
||||||
// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.
|
// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.
|
||||||
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
|
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
|
||||||
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(tensor));
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
!at::functionalization::impl::isFunctionalTensor(tensor));
|
||||||
return at::functionalization::impl::to_functional_tensor(tensor);
|
return at::functionalization::impl::to_functional_tensor(tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -418,43 +422,75 @@ at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
|
||||||
// These are all composite ops that LTC can technically re-use / get for free,
|
// These are all composite ops that LTC can technically re-use / get for free,
|
||||||
// but we need to "functionalize" them to remove the view ops before we can use them.
|
// but we need to "functionalize" them to remove the view ops before we can use them.
|
||||||
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
|
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(block_diag)>::call(tensors);
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
block_diag)>::call(tensors);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::new_empty_strided(const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
at::Tensor LazyNativeFunctions::new_empty_strided(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout, device, pin_memory);
|
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
|
||||||
|
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||||
|
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||||
|
return at::functionalization::
|
||||||
|
functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(
|
||||||
|
self, size, stride, dtype, layout, device, pin_memory);
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::narrow_copy(const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
|
at::Tensor LazyNativeFunctions::narrow_copy(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(narrow_copy)>::call(self, dim, start, length);
|
const at::Tensor& self, int64_t dim, int64_t start, int64_t length) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
narrow_copy)>::call(self, dim, start, length);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::pixel_shuffle(const at::Tensor & self, int64_t upscale_factor) {
|
at::Tensor LazyNativeFunctions::pixel_shuffle(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(pixel_shuffle)>::call(self, upscale_factor);
|
const at::Tensor& self, int64_t upscale_factor) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
pixel_shuffle)>::call(self, upscale_factor);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::pixel_unshuffle(const at::Tensor & self, int64_t downscale_factor) {
|
at::Tensor LazyNativeFunctions::pixel_unshuffle(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(pixel_unshuffle)>::call(self, downscale_factor);
|
const at::Tensor& self, int64_t downscale_factor) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
pixel_unshuffle)>::call(self, downscale_factor);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::select_backward(const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) {
|
at::Tensor LazyNativeFunctions::select_backward(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(select_backward)>::call(grad_output, input_sizes, dim, index);
|
const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim,
|
||||||
|
int64_t index) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
select_backward)>::call(grad_output, input_sizes, dim, index);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::slice_backward(const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
at::Tensor LazyNativeFunctions::slice_backward(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
|
const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim,
|
||||||
|
int64_t start, int64_t end, int64_t step) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::diagonal_backward(const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
|
at::Tensor LazyNativeFunctions::diagonal_backward(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
|
const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t offset,
|
||||||
|
int64_t dim1, int64_t dim2) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::_trilinear(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim) {
|
at::Tensor LazyNativeFunctions::_trilinear(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
|
const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3,
|
||||||
|
at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3,
|
||||||
|
at::IntArrayRef sumdim, int64_t unroll_dim) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
|
||||||
|
call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
|
||||||
}
|
}
|
||||||
::std::tuple<at::Tensor,at::Tensor> LazyNativeFunctions::linalg_inv_ex(const at::Tensor & self, bool check_errors) {
|
::std::tuple<at::Tensor, at::Tensor>
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(linalg_inv_ex)>::call(self, check_errors);
|
LazyNativeFunctions::linalg_inv_ex(const at::Tensor& self, bool check_errors) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP(
|
||||||
|
linalg_inv_ex)>::call(self, check_errors);
|
||||||
}
|
}
|
||||||
at::Tensor LazyNativeFunctions::linalg_pinv(const at::Tensor & self, const c10::optional<at::Tensor> & atol, const c10::optional<at::Tensor> & rtol, bool hermitian) {
|
at::Tensor LazyNativeFunctions::linalg_pinv(
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP2(linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
|
const at::Tensor& self, const c10::optional<at::Tensor>& atol,
|
||||||
|
const c10::optional<at::Tensor>& rtol, bool hermitian) {
|
||||||
|
return at::functionalization::functionalize_aten_op<ATEN_OP2(
|
||||||
|
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
|
||||||
}
|
}
|
||||||
|
|
||||||
// functionalize_aten_op can't handle out= ops directly.
|
// functionalize_aten_op can't handle out= ops directly.
|
||||||
// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs.
|
// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs.
|
||||||
at::Tensor & LazyNativeFunctions::logsumexp_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor& out) {
|
at::Tensor& LazyNativeFunctions::logsumexp_out(
|
||||||
|
const at::Tensor& self, at::IntArrayRef dim, bool keepdim,
|
||||||
|
at::Tensor& out) {
|
||||||
auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
|
auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
|
||||||
auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
|
auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
|
||||||
// directly call the composite kernel from core.
|
// directly call the composite kernel from core.
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir_node.h"
|
#include "mlir_node.h"
|
||||||
#include "../utils/exception.h"
|
#include "utils/exception.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
@ -74,7 +74,8 @@ hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
|
||||||
OpKind TorchMlirTensorList::ClassOpKind() {
|
OpKind TorchMlirTensorList::ClassOpKind() {
|
||||||
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
|
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
|
||||||
// import otherwise
|
// import otherwise
|
||||||
static const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
|
static const OpKind tensor_list_opkind =
|
||||||
|
OpKind::Get("lazy_tensors::tensor_list");
|
||||||
return tensor_list_opkind;
|
return tensor_list_opkind;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,9 @@
|
||||||
#include <torch/csrc/lazy/core/ir.h>
|
#include <torch/csrc/lazy/core/ir.h>
|
||||||
#include <torch/csrc/lazy/core/shape.h>
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
#include "../utils/debug.h"
|
|
||||||
#include "../utils/exception.h"
|
|
||||||
#include "mlir_lowering_context.h"
|
#include "mlir_lowering_context.h"
|
||||||
|
#include "utils/debug.h"
|
||||||
|
#include "utils/exception.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
@ -60,7 +60,6 @@ private:
|
||||||
hash_t dag_hash_;
|
hash_t dag_hash_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// TensorList represents an at::TensorList which is a vector[Tensor] but is also
|
// TensorList represents an at::TensorList which is a vector[Tensor] but is also
|
||||||
// a first-class IValue and can be fed as a single input to a TS program. It is
|
// a first-class IValue and can be fed as a single input to a TS program. It is
|
||||||
// much easier to handle TensorLists in Lazy Tensor code if they are represented
|
// much easier to handle TensorLists in Lazy Tensor code if they are represented
|
||||||
|
|
|
@ -209,17 +209,17 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
|
||||||
return cloned.front();
|
return cloned.front();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GenerateCopy(
|
||||||
void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source, TorchMlirFunction function) {
|
torch::jit::Value* destination, torch::jit::Value* source,
|
||||||
std::vector<torch::jit::NamedValue> arguments;
|
TorchMlirFunction function) {
|
||||||
arguments.emplace_back(destination);
|
std::vector<torch::jit::NamedValue> arguments;
|
||||||
arguments.emplace_back(source);
|
arguments.emplace_back(destination);
|
||||||
LowerBuiltin(
|
arguments.emplace_back(source);
|
||||||
at::aten::copy_,
|
LowerBuiltin(
|
||||||
c10::ArrayRef<Shape>(compute_shape_copy(source->type())), function, arguments);
|
at::aten::copy_, c10::ArrayRef<Shape>(compute_shape_copy(source->type())),
|
||||||
|
function, arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
torch::jit::Value* GenerateSlice(
|
torch::jit::Value* GenerateSlice(
|
||||||
torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
|
torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
|
||||||
int64_t step, TorchMlirFunction function) {
|
int64_t step, TorchMlirFunction function) {
|
||||||
|
@ -234,8 +234,7 @@ torch::jit::Value* GenerateSlice(
|
||||||
at::aten::slice,
|
at::aten::slice,
|
||||||
c10::ArrayRef<Shape>(
|
c10::ArrayRef<Shape>(
|
||||||
compute_shape_slice(base->type(), dim, start, end, step)),
|
compute_shape_slice(base->type(), dim, start, end, step)),
|
||||||
function,
|
function, arguments);
|
||||||
arguments);
|
|
||||||
CHECK_EQ(selected.size(), 1);
|
CHECK_EQ(selected.size(), 1);
|
||||||
return selected.front();
|
return selected.front();
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,8 @@
|
||||||
#include <c10/util/Optional.h>
|
#include <c10/util/Optional.h>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "../utils/exception.h"
|
|
||||||
#include "generated/shape_inference.h"
|
#include "generated/shape_inference.h"
|
||||||
|
#include "utils/exception.h"
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
@ -20,7 +20,7 @@ namespace lazy {
|
||||||
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
|
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape>
|
std::vector<torch::lazy::Shape>
|
||||||
compute_shape_div(const at::Tensor& self, const at::Scalar & other) {
|
compute_shape_div(const at::Tensor& self, const at::Scalar& other) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,15 @@
|
||||||
|
|
||||||
namespace sys_util {
|
namespace sys_util {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static T GetEnv(const std::string& name, const T& default_value = T(0)) {
|
||||||
|
const char* env = std::getenv(name.c_str());
|
||||||
|
if (!env) {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
return T(std::atoi(env));
|
||||||
|
}
|
||||||
|
|
||||||
static bool GetEnvBool(const char* name, bool defval) {
|
static bool GetEnvBool(const char* name, bool defval) {
|
||||||
const char* env = std::getenv(name);
|
const char* env = std::getenv(name);
|
||||||
if (env == nullptr) {
|
if (env == nullptr) {
|
|
@ -30,22 +30,18 @@ include_directories(BEFORE
|
||||||
${PROJECT_SOURCE_DIR}/python
|
${PROJECT_SOURCE_DIR}/python
|
||||||
)
|
)
|
||||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib)
|
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib)
|
||||||
add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib)
|
add_link_options(-Wl,-rpath,$ORIGIN/lib)
|
||||||
|
|
||||||
file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS
|
set(REFERENCE_LAZY_BACKEND_CSRC
|
||||||
"ltc_backend/csrc/*.h"
|
backend_impl.cpp
|
||||||
"ltc_backend/csrc/*.cc"
|
reference_lazy_backend_pybind.cpp
|
||||||
"ltc_backend/csrc/*.cpp"
|
|
||||||
"ltc_backend/csrc/*/*.h"
|
|
||||||
"ltc_backend/csrc/*/*.cc"
|
|
||||||
"ltc_backend/csrc/*/*.cpp"
|
|
||||||
)
|
)
|
||||||
add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC})
|
add_library(reference_lazy_backend SHARED ${REFERENCE_LAZY_BACKEND_CSRC})
|
||||||
add_dependencies(example_mlir_ltc_backend
|
add_dependencies(reference_lazy_backend
|
||||||
torch_mlir_ltc_backend
|
torch_mlir_ltc_backend
|
||||||
)
|
)
|
||||||
target_link_libraries(example_mlir_ltc_backend
|
target_link_libraries(reference_lazy_backend
|
||||||
${TORCH_LIBRARIES}
|
${TORCH_LIBRARIES}
|
||||||
${Python3_LIBRARIES}
|
${Python3_LIBRARIES}
|
||||||
torch_python
|
torch_python
|
||||||
|
@ -53,9 +49,9 @@ target_link_libraries(example_mlir_ltc_backend
|
||||||
)
|
)
|
||||||
|
|
||||||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
|
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
|
||||||
set_target_properties(example_mlir_ltc_backend PROPERTIES
|
set_target_properties(reference_lazy_backend PROPERTIES
|
||||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/"
|
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/reference_lazy_backend"
|
||||||
OUTPUT_NAME _EXAMPLE_MLIR_BACKEND
|
OUTPUT_NAME _REFERENCE_LAZY_BACKEND
|
||||||
PREFIX "${PYTHON_MODULE_PREFIX}"
|
PREFIX "${PYTHON_MODULE_PREFIX}"
|
||||||
SUFFIX "${PYTHON_MODULE_EXTENSION}"
|
SUFFIX "${PYTHON_MODULE_EXTENSION}"
|
||||||
CXX_VISIBILITY_PRESET "hidden"
|
CXX_VISIBILITY_PRESET "hidden"
|
|
@ -15,8 +15,8 @@
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||||
#include <torch_mlir/csrc/utils/debug.h>
|
#include <torch_mlir/csrc/base_lazy_backend/utils/debug.h>
|
||||||
#include <torch_mlir/csrc/utils/exception.h>
|
#include <torch_mlir/csrc/base_lazy_backend/utils/exception.h>
|
||||||
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
|
|
||||||
|
@ -25,8 +25,8 @@ using namespace torch::lazy;
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
struct ExampleMlirBackendDeviceType : public BackendDeviceType {
|
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
||||||
ExampleMlirBackendDeviceType(std::string device_type)
|
ReferenceLazyBackendDeviceType(std::string device_type)
|
||||||
: device_type_(device_type) {}
|
: device_type_(device_type) {}
|
||||||
|
|
||||||
std::string toString() const override { return device_type_; }
|
std::string toString() const override { return device_type_; }
|
||||||
|
@ -34,9 +34,9 @@ struct ExampleMlirBackendDeviceType : public BackendDeviceType {
|
||||||
std::string device_type_;
|
std::string device_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl {
|
class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl {
|
||||||
public:
|
public:
|
||||||
ExampleMlirBackendImpl() : default_device_type_("Magic") {}
|
ReferenceLazyBackendImpl() : default_device_type_("Magic") {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Configuration
|
* Configuration
|
||||||
|
@ -48,9 +48,9 @@ public:
|
||||||
/**
|
/**
|
||||||
* Lowering, Compilation, Execution
|
* Lowering, Compilation, Execution
|
||||||
* */
|
* */
|
||||||
std::vector<std::string>
|
std::vector<std::string> GetCompilationDevices(
|
||||||
GetCompilationDevices(const std::string &device,
|
const std::string& device,
|
||||||
c10::ArrayRef<std::string> devices) const override {
|
c10::ArrayRef<std::string> devices) const override {
|
||||||
return std::vector<std::string>(devices.begin(), devices.end());
|
return std::vector<std::string>(devices.begin(), devices.end());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ public:
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
// Vendor backend specific lowering can be exec here before returning.
|
// Vendor backend specific lowering can be exec here before returning.
|
||||||
for (const auto &instance : instances) {
|
for (const auto& instance : instances) {
|
||||||
// Store computation instance for external access after compilation.
|
// Store computation instance for external access after compilation.
|
||||||
GetLatestComputation() = instance;
|
GetLatestComputation() = instance;
|
||||||
}
|
}
|
||||||
|
@ -70,17 +70,18 @@ public:
|
||||||
return instances;
|
return instances;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<BackendDataPtr>
|
std::vector<BackendDataPtr> ExecuteComputation(
|
||||||
ExecuteComputation(torch::lazy::ComputationPtr computation,
|
torch::lazy::ComputationPtr computation,
|
||||||
c10::ArrayRef<BackendDataPtr> arguments,
|
c10::ArrayRef<BackendDataPtr> arguments,
|
||||||
const BackendDevice &device) const override {
|
const BackendDevice& device) const override {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
// `arguments` maps 1:1 with the parameters in the generated MLIR. In this
|
// `arguments` maps 1:1 with the parameters in the generated MLIR. In this
|
||||||
// function, we will generate a list of BackendData that corresponds to the
|
// function, we will generate a list of BackendData that corresponds to the
|
||||||
// return values in the MLIR.
|
// return values in the MLIR.
|
||||||
|
|
||||||
auto mlir_computation = static_cast<TorchMlirComputation *>(computation.get());
|
auto mlir_computation =
|
||||||
|
static_cast<TorchMlirComputation*>(computation.get());
|
||||||
|
|
||||||
// Vendor backend specific execution can be inserted here.
|
// Vendor backend specific execution can be inserted here.
|
||||||
//
|
//
|
||||||
|
@ -91,7 +92,7 @@ public:
|
||||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
|
||||||
torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), "");
|
torch::jit::GraphExecutor graph_executor(mlir_computation->graph(), "");
|
||||||
std::vector<torch::jit::IValue> stack;
|
std::vector<torch::jit::IValue> stack;
|
||||||
for (const auto &argument : arguments) {
|
for (const auto& argument : arguments) {
|
||||||
const auto mlir_data =
|
const auto mlir_data =
|
||||||
std::static_pointer_cast<TorchMlirBackendData>(argument);
|
std::static_pointer_cast<TorchMlirBackendData>(argument);
|
||||||
if (mlir_data->mlir_info()->scalar.has_value()) {
|
if (mlir_data->mlir_info()->scalar.has_value()) {
|
||||||
|
@ -128,7 +129,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDefaultDeviceType(std::string device_type) {
|
void SetDefaultDeviceType(std::string device_type) {
|
||||||
default_device_type_ = ExampleMlirBackendDeviceType(device_type);
|
default_device_type_ = ReferenceLazyBackendDeviceType(device_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -146,22 +147,22 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ExampleMlirBackendDeviceType default_device_type_;
|
ReferenceLazyBackendDeviceType default_device_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
BackendImplInterface *GetExampleMlirBackendImpl() {
|
BackendImplInterface* GetReferenceLazyBackendImpl() {
|
||||||
static ExampleMlirBackendImpl *example_mlir_backend_impl =
|
static ReferenceLazyBackendImpl* reference_lazy_backend_impl =
|
||||||
new ExampleMlirBackendImpl();
|
new ReferenceLazyBackendImpl();
|
||||||
return example_mlir_backend_impl;
|
return reference_lazy_backend_impl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitExampleMlirBackend() {
|
void InitReferenceLazyBackend() {
|
||||||
at::RegisterTorchMlirLazyNativeFunctions();
|
at::RegisterTorchMlirLazyNativeFunctions();
|
||||||
static std::unique_ptr<BackendRegistrar> g_registrar;
|
static std::unique_ptr<BackendRegistrar> g_registrar;
|
||||||
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl()));
|
g_registrar.reset(new BackendRegistrar(GetReferenceLazyBackendImpl()));
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputationPtr &GetLatestComputation() {
|
ComputationPtr& GetLatestComputation() {
|
||||||
// Store the computation from the most recent compile.
|
// Store the computation from the most recent compile.
|
||||||
static ComputationPtr computation;
|
static ComputationPtr computation;
|
||||||
return computation;
|
return computation;
|
|
@ -19,11 +19,11 @@ TORCH_API void RegisterTorchMlirLazyNativeFunctions();
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl();
|
torch::lazy::BackendImplInterface* GetReferenceLazyBackendImpl();
|
||||||
|
|
||||||
void InitExampleMlirBackend();
|
void InitReferenceLazyBackend();
|
||||||
|
|
||||||
ComputationPtr &GetLatestComputation();
|
ComputationPtr& GetLatestComputation();
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
|
@ -1,4 +1,4 @@
|
||||||
//===- example_mlir_backend_pybind.cpp ------------------------------------===//
|
//===- reference_lazy_backend_pybind.cpp ----------------------------------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -11,13 +11,13 @@
|
||||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||||
|
|
||||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||||
|
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h>
|
||||||
|
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "backend/backend_impl.h"
|
#include "backend_impl.h"
|
||||||
#include "utils/sys_utils.h"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
@ -27,20 +27,20 @@ bool verbose = sys_util::GetEnv("VERBOSE", false);
|
||||||
struct NoGilSection {
|
struct NoGilSection {
|
||||||
NoGilSection() : state(PyEval_SaveThread()) {}
|
NoGilSection() : state(PyEval_SaveThread()) {}
|
||||||
~NoGilSection() { PyEval_RestoreThread(state); }
|
~NoGilSection() { PyEval_RestoreThread(state); }
|
||||||
PyThreadState *state = nullptr;
|
PyThreadState* state = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Install the plugin
|
* @brief Install the plugin
|
||||||
*/
|
*/
|
||||||
void Initialize() {
|
void Initialize() {
|
||||||
// Initialize the Example MLIR LTC Backend
|
// Initialize the Reference Lazy Backend
|
||||||
torch::lazy::InitExampleMlirBackend();
|
torch::lazy::InitReferenceLazyBackend();
|
||||||
|
|
||||||
// sanity check
|
// sanity check
|
||||||
const torch::lazy::BackendImplInterface *mlir_backend =
|
const torch::lazy::BackendImplInterface* mlir_backend =
|
||||||
torch::lazy::GetExampleMlirBackendImpl();
|
torch::lazy::GetReferenceLazyBackendImpl();
|
||||||
const torch::lazy::BackendImplInterface *lazy_backend =
|
const torch::lazy::BackendImplInterface* lazy_backend =
|
||||||
torch::lazy::getBackend();
|
torch::lazy::getBackend();
|
||||||
if (lazy_backend != mlir_backend) {
|
if (lazy_backend != mlir_backend) {
|
||||||
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl;
|
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl;
|
||||||
|
@ -62,14 +62,14 @@ void Shutdown() {
|
||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) {
|
PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) {
|
||||||
py::class_<torch::lazy::TorchMlirComputation>(m, "TorchMlirComputation")
|
py::class_<torch::lazy::TorchMlirComputation>(m, "TorchMlirComputation")
|
||||||
.def("to_string", &torch::lazy::TorchMlirComputation::to_string)
|
.def("to_string", &torch::lazy::TorchMlirComputation::to_string)
|
||||||
.def("debug_string", &torch::lazy::TorchMlirComputation::debug_string);
|
.def("debug_string", &torch::lazy::TorchMlirComputation::debug_string);
|
||||||
|
|
||||||
m.doc() = ("pybind11 for example MLIR LTC backend.");
|
m.doc() = ("pybind11 for the Reference Lazy backend.");
|
||||||
m.def("get_latest_computation", []() {
|
m.def("get_latest_computation", []() {
|
||||||
auto computation = static_cast<torch::lazy::TorchMlirComputation *>(
|
auto computation = static_cast<torch::lazy::TorchMlirComputation*>(
|
||||||
torch::lazy::GetLatestComputation().get());
|
torch::lazy::GetLatestComputation().get());
|
||||||
return py::cast(computation);
|
return py::cast(computation);
|
||||||
});
|
});
|
|
@ -3,7 +3,7 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
import torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND as lazy_backend
|
||||||
import torch
|
import torch
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ class LazyTensorCoreTestConfig(TestConfig):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
ltc_backend._initialize()
|
lazy_backend._initialize()
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||||
return program.to('lazy')
|
return program.to('lazy')
|
||||||
|
|
Loading…
Reference in New Issue