mirror of https://github.com/llvm/torch-mlir
Add example Torch MLIR LTC Backend (#725)
parent
3e9b1cbd36
commit
a605fe279c
|
@ -192,3 +192,4 @@ else()
|
|||
endif()
|
||||
|
||||
add_subdirectory(test)
|
||||
add_subdirectory(examples)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(ltc_backend)
|
|
@ -0,0 +1,63 @@
|
|||
###########################################################################
|
||||
# Setup PyTorch
|
||||
###########################################################################
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
|
||||
include(TorchMLIRPyTorch)
|
||||
TorchMLIRProbeForPyTorchInstall()
|
||||
find_package(Torch 1.11 REQUIRED)
|
||||
|
||||
TorchMLIRConfigurePyTorch()
|
||||
|
||||
###########################################################################
|
||||
# Setup Python development
|
||||
###########################################################################
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/external/llvm-project/mlir/cmake/modules")
|
||||
include(MLIRDetectPythonEnv)
|
||||
mlir_configure_python_dev_packages()
|
||||
|
||||
###########################################################################
|
||||
# Library definition
|
||||
###########################################################################
|
||||
|
||||
include_directories(BEFORE
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Python3_INCLUDE_DIRS}
|
||||
${PYTHON_H_DIR}
|
||||
${PROJECT_SOURCE_DIR}/python
|
||||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib)
|
||||
add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib)
|
||||
|
||||
file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS
|
||||
"ltc_backend/csrc/*.h"
|
||||
"ltc_backend/csrc/*.cc"
|
||||
"ltc_backend/csrc/*.cpp"
|
||||
"ltc_backend/csrc/*/*.h"
|
||||
"ltc_backend/csrc/*/*.cc"
|
||||
"ltc_backend/csrc/*/*.cpp"
|
||||
)
|
||||
add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC})
|
||||
add_dependencies(example_mlir_ltc_backend
|
||||
torch_mlir_ltc_backend
|
||||
)
|
||||
target_link_libraries(example_mlir_ltc_backend
|
||||
${TORCH_LIBRARIES}
|
||||
${Python3_LIBRARIES}
|
||||
torch_python
|
||||
torch_mlir_ltc_backend
|
||||
)
|
||||
|
||||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
|
||||
set_target_properties(example_mlir_ltc_backend PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/"
|
||||
OUTPUT_NAME _EXAMPLE_MLIR_BACKEND
|
||||
PREFIX "${PYTHON_MODULE_PREFIX}"
|
||||
SUFFIX "${PYTHON_MODULE_EXTENSION}"
|
||||
CXX_VISIBILITY_PRESET "hidden"
|
||||
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
|
||||
)
|
|
@ -0,0 +1,140 @@
|
|||
//===- backend_impl.cpp ---------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||
#include <torch/csrc/lazy/core/shape.h>
|
||||
|
||||
#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <torch_mlir/csrc/utils/debug.h>
|
||||
#include <torch_mlir/csrc/utils/exception.h>
|
||||
|
||||
#include "backend_impl.h"
|
||||
|
||||
using namespace torch::lazy;
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
struct ExampleMlirBackendDeviceType : public BackendDeviceType {
|
||||
ExampleMlirBackendDeviceType(std::string device_type)
|
||||
: device_type_(device_type) {}
|
||||
|
||||
std::string toString() const override { return device_type_; }
|
||||
|
||||
std::string device_type_;
|
||||
};
|
||||
|
||||
class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl {
|
||||
public:
|
||||
ExampleMlirBackendImpl() : default_device_type_("Magic") {}
|
||||
|
||||
/**
|
||||
* Configuration
|
||||
* */
|
||||
void SetRngSeed(size_t seed) const override {
|
||||
std::cout << "RNG Seed Set to: " << seed << std::endl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Lowering, Compilation, Execution
|
||||
* */
|
||||
std::vector<std::string>
|
||||
GetCompilationDevices(const std::string &device,
|
||||
c10::ArrayRef<std::string> devices) const override {
|
||||
return std::vector<std::string>(devices.begin(), devices.end());
|
||||
};
|
||||
|
||||
std::vector<ComputationPtr>
|
||||
Compile(std::vector<ComputationPtr> instances) const override {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
// Vendor backend specific lowering can be exec here before returning.
|
||||
for (const auto &instance : instances) {
|
||||
std::cout << "Instance received at Compile: \n"
|
||||
<< GetComputationBackendText(instance) << std::endl;
|
||||
}
|
||||
|
||||
return instances;
|
||||
}
|
||||
|
||||
std::vector<BackendDataPtr>
|
||||
ExecuteComputation(Computation &computation,
|
||||
c10::ArrayRef<BackendDataPtr> arguments,
|
||||
const BackendDevice &device) const override {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
// `arguments` maps 1:1 with the parameters in the generated MLIR. In this
|
||||
// function, we will generate a list of BackendData that corresponds to the
|
||||
// return values in the MLIR.
|
||||
std::vector<torch::lazy::BackendDataPtr> results;
|
||||
|
||||
// "Borrow" some tensor data from arguments to reuse in return. This ensures
|
||||
// that the tensor device is correctly configured.
|
||||
TORCH_CHECK(arguments.size() > 0,
|
||||
"Need at least one argument for example execution.");
|
||||
const TorchMlirBackendData *torch_mlir_data =
|
||||
dynamic_cast<const TorchMlirBackendData *>(arguments[0].get());
|
||||
TORCH_CHECK(torch_mlir_data,
|
||||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
|
||||
|
||||
// For this demo we aren't performing a legitimate execution, so we generate
|
||||
// some dummy data to return based on the expected number of return values.
|
||||
auto mlir_computation = static_cast<TorchMlirComputation *>(&computation);
|
||||
for (unsigned i = 0; i < mlir_computation->num_results(); i++) {
|
||||
results.push_back(std::make_shared<TorchMlirBackendData>(
|
||||
torch_mlir_data->mlir_info()->tensor, device,
|
||||
torch_mlir_data->shape()));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Device Configuration
|
||||
* */
|
||||
std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const {
|
||||
return std::make_shared<BackendDeviceType>(default_device_type_);
|
||||
}
|
||||
|
||||
void SetDefaultDeviceType(std::string device_type) {
|
||||
default_device_type_ = ExampleMlirBackendDeviceType(device_type);
|
||||
}
|
||||
|
||||
/**
|
||||
* Debug/Metrics
|
||||
* */
|
||||
std::string
|
||||
GetComputationBackendText(const ComputationPtr computation) const override {
|
||||
auto mlir_computation =
|
||||
static_cast<TorchMlirComputation *>(computation.get());
|
||||
return mlir_computation->to_string();
|
||||
}
|
||||
|
||||
private:
|
||||
ExampleMlirBackendDeviceType default_device_type_;
|
||||
};
|
||||
|
||||
BackendImplInterface *GetExampleMlirBackendImpl() {
|
||||
static ExampleMlirBackendImpl *example_mlir_backend_impl =
|
||||
new ExampleMlirBackendImpl();
|
||||
return example_mlir_backend_impl;
|
||||
}
|
||||
|
||||
void InitExampleMlirBackend() {
|
||||
at::RegisterTorchMlirLazyNativeFunctions();
|
||||
static std::unique_ptr<BackendRegistrar> g_registrar;
|
||||
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl()));
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,27 @@
|
|||
//===- backend_impl.h -----------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||
|
||||
namespace at {
|
||||
// This function is defined in the codegenerated RegisterLazy.cpp file.
|
||||
TORCH_API void RegisterTorchMlirLazyNativeFunctions();
|
||||
} // namespace at
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl();
|
||||
|
||||
void InitExampleMlirBackend();
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,73 @@
|
|||
//===- example_mlir_backend_pybind.cpp ------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch/csrc/jit/python/pybind.h"
|
||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "backend/backend_impl.h"
|
||||
#include "utils/sys_utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
bool verbose = sys_util::GetEnv("VERBOSE", false);
|
||||
|
||||
struct NoGilSection {
|
||||
NoGilSection() : state(PyEval_SaveThread()) {}
|
||||
~NoGilSection() { PyEval_RestoreThread(state); }
|
||||
PyThreadState *state = nullptr;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Install the plugin
|
||||
*/
|
||||
void Initialize() {
|
||||
// Initialize the Example MLIR LTC Backend
|
||||
torch::lazy::InitExampleMlirBackend();
|
||||
|
||||
// sanity check
|
||||
const torch::lazy::BackendImplInterface *mlir_backend =
|
||||
torch::lazy::GetExampleMlirBackendImpl();
|
||||
const torch::lazy::BackendImplInterface *lazy_backend =
|
||||
torch::lazy::getBackend();
|
||||
if (lazy_backend != mlir_backend) {
|
||||
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl;
|
||||
throw std::runtime_error("Failed to initialize MLIR Lazy Backend");
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Uninstall the plugin
|
||||
*/
|
||||
void Shutdown() {
|
||||
if (verbose) {
|
||||
std::cout << "MLIR LTC PyTorch Plugin Shut down." << std::endl;
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) {
|
||||
m.doc() = ("pybind11 for example MLIR LTC backend.");
|
||||
m.def("_initialize", []() {
|
||||
NoGilSection gil;
|
||||
Initialize();
|
||||
});
|
||||
m.def("_shutdown", []() {
|
||||
NoGilSection gil;
|
||||
Shutdown();
|
||||
});
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
//===- sys_utils.h --------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
namespace sys_util {
|
||||
|
||||
template <typename T>
|
||||
T GetEnv(const std::string &name, const T &default_value = T(0)) {
|
||||
const char *env = std::getenv(name.c_str());
|
||||
if (!env) {
|
||||
return default_value;
|
||||
}
|
||||
return T(std::atoi(env));
|
||||
}
|
||||
|
||||
} // namespace sys_util
|
|
@ -0,0 +1,86 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
"""
|
||||
Example use of the example Torch MLIR LTC backend.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def main(device):
|
||||
import torch
|
||||
|
||||
if device in ("TS", "MLIR_EXAMPLE"):
|
||||
import torch._lazy
|
||||
|
||||
if device == "TS":
|
||||
import torch._lazy.ts_backend
|
||||
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
elif device == "MLIR_EXAMPLE":
|
||||
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
|
||||
|
||||
ltc_backend._initialize()
|
||||
|
||||
device = "lazy"
|
||||
print("Initialized backend")
|
||||
else:
|
||||
device = device.lower()
|
||||
|
||||
inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
|
||||
assert inputs.device.type == device
|
||||
|
||||
targets = torch.tensor([3], dtype=torch.int64, device=device)
|
||||
assert targets.device.type == device
|
||||
|
||||
print("Initialized data")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(5, 5)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
model = Model().to(device)
|
||||
model.train()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
|
||||
print("Initialized model")
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if device == "lazy":
|
||||
print("Calling Mark Step")
|
||||
torch._lazy.mark_step()
|
||||
|
||||
print()
|
||||
print(loss)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--device",
|
||||
type=str.upper,
|
||||
choices=["CPU", "TS", "MLIR_EXAMPLE"],
|
||||
default="MLIR_EXAMPLE",
|
||||
help="The device type",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.device)
|
|
@ -46,11 +46,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() {
|
|||
}
|
||||
|
||||
void TorchMlirBackendData::Assign(const BackendData& data) {
|
||||
const TorchMlirBackendData* torch_mlir_data =
|
||||
dynamic_cast<const TorchMlirBackendData*>(&data);
|
||||
TORCH_CHECK(
|
||||
torch_mlir_data,
|
||||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
|
||||
|
||||
TorchMlirBackendData::Info* info =
|
||||
dynamic_cast<TorchMlirBackendData::Info*>(data.info());
|
||||
dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
|
||||
TORCH_CHECK(
|
||||
info,
|
||||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
|
||||
|
||||
info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
|
||||
}
|
||||
|
||||
|
@ -92,11 +99,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
|
|||
const BackendDataPtr data,
|
||||
c10::optional<at::ScalarType> logical_scalar_type) const {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
TorchMlirBackendData* torch_mlir_data =
|
||||
dynamic_cast<TorchMlirBackendData*>(data.get());
|
||||
TORCH_CHECK(
|
||||
torch_mlir_data,
|
||||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
|
||||
|
||||
TorchMlirBackendData::Info* info =
|
||||
dynamic_cast<TorchMlirBackendData::Info*>(data->info());
|
||||
dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
|
||||
TORCH_CHECK(
|
||||
info,
|
||||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
|
||||
|
||||
return info->tensor;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue