Add example Torch MLIR LTC Backend (#725)

Henry Tu 2022-04-14 12:53:00 -04:00 committed by Henry Tu
parent 3e9b1cbd36
commit a605fe279c
10 changed files with 434 additions and 2 deletions


@@ -192,3 +192,4 @@ else()
 endif()
 add_subdirectory(test)
+add_subdirectory(examples)


@@ -0,0 +1 @@
add_subdirectory(ltc_backend)


@@ -0,0 +1,63 @@
###########################################################################
# Setup PyTorch
###########################################################################
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
include(TorchMLIRPyTorch)
TorchMLIRProbeForPyTorchInstall()
find_package(Torch 1.11 REQUIRED)
TorchMLIRConfigurePyTorch()
###########################################################################
# Setup Python development
###########################################################################
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/external/llvm-project/mlir/cmake/modules")
include(MLIRDetectPythonEnv)
mlir_configure_python_dev_packages()
###########################################################################
# Library definition
###########################################################################
include_directories(BEFORE
${TORCH_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
${Python3_INCLUDE_DIRS}
${PYTHON_H_DIR}
${PROJECT_SOURCE_DIR}/python
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib)
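# Resolve shared-library dependencies relative to the module's own location
# ($ORIGIN) at runtime.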
add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib)
file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS
"ltc_backend/csrc/*.h"
"ltc_backend/csrc/*.cc"
"ltc_backend/csrc/*.cpp"
"ltc_backend/csrc/*/*.h"
"ltc_backend/csrc/*/*.cc"
"ltc_backend/csrc/*/*.cpp"
)
add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC})
add_dependencies(example_mlir_ltc_backend
torch_mlir_ltc_backend
)
target_link_libraries(example_mlir_ltc_backend
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
torch_python
torch_mlir_ltc_backend
)
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(example_mlir_ltc_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/"
OUTPUT_NAME _EXAMPLE_MLIR_BACKEND
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
)
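
As a quick sanity check, the module built above imports like any CPython extension; a minimal sketch, assuming the directory containing the ltc_backend/ package is on sys.path (as in example.py further down):

import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as backend
print(backend.__doc__)  # "pybind11 for example MLIR LTC backend."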


@@ -0,0 +1,140 @@
//===- backend_impl.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>
#include <torch_mlir/csrc/utils/exception.h>
#include "backend_impl.h"
using namespace torch::lazy;
namespace torch {
namespace lazy {
struct ExampleMlirBackendDeviceType : public BackendDeviceType {
ExampleMlirBackendDeviceType(std::string device_type)
: device_type_(device_type) {}
std::string toString() const override { return device_type_; }
std::string device_type_;
};
class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl {
public:
ExampleMlirBackendImpl() : default_device_type_("Magic") {}
/**
* Configuration
* */
void SetRngSeed(size_t seed) const override {
std::cout << "RNG Seed Set to: " << seed << std::endl;
}
/**
* Lowering, Compilation, Execution
* */
std::vector<std::string>
GetCompilationDevices(const std::string &device,
c10::ArrayRef<std::string> devices) const override {
return std::vector<std::string>(devices.begin(), devices.end());
};
std::vector<ComputationPtr>
Compile(std::vector<ComputationPtr> instances) const override {
PRINT_FUNCTION();
// Vendor backend-specific lowering can be executed here before returning.
for (const auto &instance : instances) {
std::cout << "Instance received at Compile: \n"
<< GetComputationBackendText(instance) << std::endl;
}
return instances;
}
std::vector<BackendDataPtr>
ExecuteComputation(Computation &computation,
c10::ArrayRef<BackendDataPtr> arguments,
const BackendDevice &device) const override {
PRINT_FUNCTION();
// `arguments` maps 1:1 with the parameters in the generated MLIR. In this
// function, we will generate a list of BackendData that corresponds to the
// return values in the MLIR.
std::vector<torch::lazy::BackendDataPtr> results;
// "Borrow" some tensor data from arguments to reuse in return. This ensures
// that the tensor device is correctly configured.
TORCH_CHECK(arguments.size() > 0,
"Need at least one argument for example execution.");
const TorchMlirBackendData *torch_mlir_data =
dynamic_cast<const TorchMlirBackendData *>(arguments[0].get());
TORCH_CHECK(torch_mlir_data,
"Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
// For this demo we aren't performing a legitimate execution, so we generate
// some dummy data to return based on the expected number of return values.
auto mlir_computation = static_cast<TorchMlirComputation *>(&computation);
for (unsigned i = 0; i < mlir_computation->num_results(); i++) {
results.push_back(std::make_shared<TorchMlirBackendData>(
torch_mlir_data->mlir_info()->tensor, device,
torch_mlir_data->shape()));
}
return results;
}
/**
* Device Configuration
* */
  std::shared_ptr<torch::lazy::BackendDeviceType>
  GetDefaultDeviceType() const override {
    // Construct the derived device type so the custom toString() above is
    // preserved; make_shared<BackendDeviceType> would slice it off.
    return std::make_shared<ExampleMlirBackendDeviceType>(default_device_type_);
  }
  void SetDefaultDeviceType(std::string device_type) override {
    default_device_type_ = ExampleMlirBackendDeviceType(device_type);
  }
/**
* Debug/Metrics
* */
std::string
GetComputationBackendText(const ComputationPtr computation) const override {
auto mlir_computation =
static_cast<TorchMlirComputation *>(computation.get());
return mlir_computation->to_string();
}
private:
ExampleMlirBackendDeviceType default_device_type_;
};
BackendImplInterface *GetExampleMlirBackendImpl() {
static ExampleMlirBackendImpl *example_mlir_backend_impl =
new ExampleMlirBackendImpl();
return example_mlir_backend_impl;
}
void InitExampleMlirBackend() {
at::RegisterTorchMlirLazyNativeFunctions();
static std::unique_ptr<BackendRegistrar> g_registrar;
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl()));
}
} // namespace lazy
} // namespace torch
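
A minimal sketch of exercising the hooks above from Python, assuming the _EXAMPLE_MLIR_BACKEND module defined further down is built and importable. mark_step() flushes the lazy graph, so Compile() prints the MLIR it received, and ExecuteComputation() returns placeholder tensors borrowed from the first argument, so the values that come back are not meaningful:

import torch
import torch._lazy
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as backend

backend._initialize()
x = torch.ones(2, 3, device="lazy")
y = x + 1
torch._lazy.mark_step()  # triggers Compile() and ExecuteComputation() above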


@@ -0,0 +1,27 @@
//===- backend_impl.h -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/lazy/backend/backend_interface.h>
namespace at {
// This function is defined in the code-generated RegisterLazy.cpp file.
TORCH_API void RegisterTorchMlirLazyNativeFunctions();
} // namespace at
namespace torch {
namespace lazy {
torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl();
void InitExampleMlirBackend();
} // namespace lazy
} // namespace torch


@@ -0,0 +1,73 @@
//===- example_mlir_backend_pybind.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch/csrc/jit/python/pybind.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include <exception>
#include <iostream>
#include <string>
#include "backend/backend_impl.h"
#include "utils/sys_utils.h"
namespace py = pybind11;
namespace {
bool verbose = sys_util::GetEnv("VERBOSE", false);
struct NoGilSection {
NoGilSection() : state(PyEval_SaveThread()) {}
~NoGilSection() { PyEval_RestoreThread(state); }
PyThreadState *state = nullptr;
};
/**
* @brief Install the plugin
*/
void Initialize() {
// Initialize the Example MLIR LTC Backend
torch::lazy::InitExampleMlirBackend();
// sanity check
const torch::lazy::BackendImplInterface *mlir_backend =
torch::lazy::GetExampleMlirBackendImpl();
const torch::lazy::BackendImplInterface *lazy_backend =
torch::lazy::getBackend();
if (lazy_backend != mlir_backend) {
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl;
throw std::runtime_error("Failed to initialize MLIR Lazy Backend");
}
if (verbose) {
std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl;
}
}
/**
* @brief Uninstall the plugin
*/
void Shutdown() {
if (verbose) {
std::cout << "MLIR LTC PyTorch Plugin Shut down." << std::endl;
}
}
} // anonymous namespace
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) {
m.doc() = ("pybind11 for example MLIR LTC backend.");
m.def("_initialize", []() {
NoGilSection gil;
Initialize();
});
m.def("_shutdown", []() {
NoGilSection gil;
Shutdown();
});
}
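
Note that `verbose` is captured when the shared library loads, so VERBOSE has to be set before the import; a hypothetical session:

import os
os.environ["VERBOSE"] = "1"  # read by sys_util::GetEnv at module load time

import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as backend
backend._initialize()  # prints "MLIR LTC PyTorch Plugin Initialized."
backend._shutdown()    # prints "MLIR LTC PyTorch Plugin Shut down."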


@@ -0,0 +1,26 @@
//===- sys_utils.h --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <cstdlib>
#include <string>
namespace sys_util {
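// GetEnv parses the variable with std::atoi: an unset variable yields
// `default_value`, while a set but non-numeric value parses as T(0).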
template <typename T>
T GetEnv(const std::string &name, const T &default_value = T(0)) {
const char *env = std::getenv(name.c_str());
if (!env) {
return default_value;
}
return T(std::atoi(env));
}
} // namespace sys_util


@@ -0,0 +1,86 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Example use of the example Torch MLIR LTC backend.
"""
import argparse

import torch
import torch.nn.functional as F


def main(device):
    if device in ("TS", "MLIR_EXAMPLE"):
        import torch._lazy

        if device == "TS":
            import torch._lazy.ts_backend

            torch._lazy.ts_backend.init()
        elif device == "MLIR_EXAMPLE":
            import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend

            ltc_backend._initialize()

        device = "lazy"
        print("Initialized backend")
    else:
        device = device.lower()

    inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device)
    assert inputs.device.type == device

    targets = torch.tensor([3], dtype=torch.int64, device=device)
    assert targets.device.type == device

    print("Initialized data")

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(5, 5)

        def forward(self, x):
            out = self.fc1(x)
            out = F.relu(out)
            return out

    model = Model().to(device)
    model.train()
    assert all(p.device.type == device for p in model.parameters())
    print("Initialized model")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    if device == "lazy":
        print("Calling Mark Step")
        torch._lazy.mark_step()

    print()
    print(loss)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--device",
        type=str.upper,
        choices=["CPU", "TS", "MLIR_EXAMPLE"],
        default="MLIR_EXAMPLE",
        help="The device type",
    )
    args = parser.parse_args()
    main(args.device)
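
main() can also be driven directly instead of via the CLI; a hypothetical one-liner, assuming the script above is saved as example.py and importable:

from example import main  # hypothetical module name for the script above
main("MLIR_EXAMPLE")      # equivalent to: example.py --device MLIR_EXAMPLE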


@@ -46,11 +46,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() {
 }
 
 void TorchMlirBackendData::Assign(const BackendData& data) {
+  const TorchMlirBackendData* torch_mlir_data =
+      dynamic_cast<const TorchMlirBackendData*>(&data);
+  TORCH_CHECK(
+      torch_mlir_data,
+      "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
+
   TorchMlirBackendData::Info* info =
-      dynamic_cast<TorchMlirBackendData::Info*>(data.info());
+      dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
   TORCH_CHECK(
       info,
       "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
   info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
 }
@@ -92,11 +99,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
     const BackendDataPtr data,
     c10::optional<at::ScalarType> logical_scalar_type) const {
   PRINT_FUNCTION();
+
+  TorchMlirBackendData* torch_mlir_data =
+      dynamic_cast<TorchMlirBackendData*>(data.get());
+  TORCH_CHECK(
+      torch_mlir_data,
+      "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
+
   TorchMlirBackendData::Info* info =
-      dynamic_cast<TorchMlirBackendData::Info*>(data->info());
+      dynamic_cast<TorchMlirBackendData::Info*>(torch_mlir_data->mlir_info());
   TORCH_CHECK(
       info,
       "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
+
   return info->tensor;
 }