mirror of https://github.com/llvm/torch-mlir
Propagate device data names (#1157)
* Propagate device data names * Address PR comment * Add example usage * Add test for device data names * Make TorchMlirComputation fields protected * Add lazy backend device data name unit tests * Disable lazy backend tests if LTC is disabled * Add commentspull/1233/head
parent
84d345c650
commit
0af55781ae
|
@ -52,6 +52,12 @@ endif()
|
|||
# TODO: Reenable LTC once OOT build is successful (https://github.com/llvm/torch-mlir/issues/1154)
|
||||
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_LTC)
|
||||
set(ENV{TORCH_MLIR_ENABLE_LTC} 1)
|
||||
else()
|
||||
set(ENV{TORCH_MLIR_ENABLE_LTC} 0)
|
||||
endif()
|
||||
|
||||
torch_mlir_add_llvm_external_project(
|
||||
torch-mlir-dialects
|
||||
TORCH_MLIR_DIALECTS
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
|
||||
import torch
|
||||
import torch._lazy
|
||||
|
||||
import torch_mlir.reference_lazy_backend._REFERENCE_LAZY_BACKEND as lazy_backend
|
||||
|
||||
from run_test import run_test
|
||||
|
||||
lazy_backend._initialize()
|
||||
|
||||
device = "lazy"
|
||||
|
||||
|
||||
# CHECK: 0 input tensors found
|
||||
# -----
|
||||
# CHECK: PASS - test_no_device_data_name
|
||||
@run_test
|
||||
def test_no_device_data_name():
|
||||
x = torch.tensor(1).to(device)
|
||||
y = torch.tensor(2).to(device)
|
||||
z = x + y
|
||||
torch._lazy.mark_step()
|
||||
|
||||
|
||||
# CHECK: Input tensor: input_x
|
||||
# CHECK: 1 input tensors found
|
||||
# -----
|
||||
# CHECK: PASS - test_device_data_name
|
||||
@run_test
|
||||
def test_device_data_name():
|
||||
x = torch.tensor(1).to(device)
|
||||
y = torch.tensor(2).to(device)
|
||||
|
||||
lazy_backend.set_parameter_name(x, "input_x")
|
||||
|
||||
z = x + y
|
||||
torch._lazy.mark_step()
|
|
@ -0,0 +1,23 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: true
|
||||
|
||||
|
||||
def run_test(*args, XPASS=False, XFAIL=False):
|
||||
def _run_test(test):
|
||||
test_name = test.__name__
|
||||
try:
|
||||
test()
|
||||
print(("X" if XPASS else "") + f"PASS - {test_name}")
|
||||
except Exception as e:
|
||||
print(("X" if XFAIL else "") + f"FAIL - {test_name}")
|
||||
print("Errors: ", e)
|
||||
print(flush=True)
|
||||
|
||||
if len(args):
|
||||
_run_test(args[0])
|
||||
else:
|
||||
return _run_test
|
|
@ -51,6 +51,9 @@ llvm_config.use_default_substitutions()
|
|||
# directories.
|
||||
config.excludes = ['lit.cfg.py', 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
|
||||
|
||||
if not bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0))):
|
||||
config.excludes.append("lazy_backend")
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
||||
|
|
|
@ -113,3 +113,25 @@ add_custom_command(
|
|||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/generated/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/generated/)
|
||||
|
||||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND mkdir -p
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
||||
|
||||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/ops/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/ops/)
|
||||
|
||||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND mkdir -p
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
||||
|
||||
add_custom_command(
|
||||
TARGET torch_mlir_ltc_backend POST_BUILD
|
||||
COMMAND cp
|
||||
${PROJECT_SOURCE_DIR}/python/torch_mlir/csrc/base_lazy_backend/utils/*.h
|
||||
${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/base_lazy_backend/utils/)
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||
|
@ -29,13 +31,20 @@ public:
|
|||
at::Tensor tensor;
|
||||
c10::optional<at::Scalar> scalar;
|
||||
bool requires_grad;
|
||||
std::string name;
|
||||
|
||||
Info() {}
|
||||
Info(const Info& other)
|
||||
: tensor{other.tensor}, scalar{other.scalar},
|
||||
requires_grad{other.requires_grad} {}
|
||||
requires_grad{other.requires_grad}, name{other.name} {}
|
||||
Info(const at::Tensor& tensor)
|
||||
: tensor{tensor}, requires_grad{tensor.requires_grad()} {}
|
||||
: tensor{tensor}, requires_grad{tensor.requires_grad()} {
|
||||
static int num_tensors = 0;
|
||||
std::ostringstream oss;
|
||||
oss << "tensor" << num_tensors;
|
||||
this->name = oss.str();
|
||||
++num_tensors;
|
||||
}
|
||||
Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {}
|
||||
};
|
||||
|
||||
|
|
|
@ -135,11 +135,11 @@ public:
|
|||
|
||||
MlirOperation func_op() const;
|
||||
|
||||
const std::string debug_string() const;
|
||||
virtual const std::string debug_string() const;
|
||||
|
||||
const std::string to_string() const override;
|
||||
virtual const std::string to_string() const override;
|
||||
|
||||
private:
|
||||
protected:
|
||||
std::vector<std::string> parameter_names_;
|
||||
std::vector<Shape> parameter_shapes_;
|
||||
Shape result_shape_;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include <torch/csrc/lazy/core/ir_builder.h>
|
||||
|
||||
#include "device_data.h"
|
||||
#include "../backend_impl.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
@ -13,11 +14,37 @@ DeviceData::DeviceData(std::shared_ptr<BackendData> data)
|
|||
data->shape(),
|
||||
/*num_outputs=*/1,
|
||||
/*hash_seed=*/static_cast<uint32_t>(101)),
|
||||
data_(std::move(data)) {}
|
||||
data_(std::move(data)) {
|
||||
propagate_name();
|
||||
}
|
||||
|
||||
void DeviceData::propagate_name() {
|
||||
if (data_ && name_ != "") {
|
||||
// Add device data name to backend data
|
||||
TorchMlirBackendData* mlir_data = dynamic_cast<TorchMlirBackendData*>(data_.get());
|
||||
TORCH_CHECK(mlir_data);
|
||||
TorchMlirBackendData::Info* info = mlir_data->mlir_info();
|
||||
TORCH_CHECK(info);
|
||||
info->name = name_;
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceData::SetData(std::shared_ptr<BackendData> data) {
|
||||
data_ = data;
|
||||
propagate_name();
|
||||
}
|
||||
|
||||
void DeviceData::SetName(const std::string& name) {
|
||||
name_ = name;
|
||||
propagate_name();
|
||||
}
|
||||
|
||||
std::string DeviceData::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << TorchMlirNode::ToString() << ", device=" << data_->device();
|
||||
if (name_ != "") {
|
||||
ss << ", name=" << name_;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -27,13 +27,9 @@ class TORCH_API DeviceData : public TorchMlirNode {
|
|||
|
||||
std::string ToString() const override;
|
||||
|
||||
const std::shared_ptr<BackendData>& data() const {
|
||||
return data_;
|
||||
}
|
||||
const std::shared_ptr<BackendData>& data() const { return data_; }
|
||||
|
||||
void SetData(std::shared_ptr<BackendData> data) {
|
||||
data_ = data;
|
||||
}
|
||||
void SetData(std::shared_ptr<BackendData> data);
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
|
@ -43,8 +39,14 @@ class TORCH_API DeviceData : public TorchMlirNode {
|
|||
// instead of calling the constructor directly.
|
||||
static NodePtr Create(std::shared_ptr<BackendData> data);
|
||||
|
||||
const std::string& GetName() const { return name_; }
|
||||
void SetName(const std::string& name);
|
||||
|
||||
private:
|
||||
void propagate_name();
|
||||
|
||||
std::shared_ptr<BackendData> data_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::ostream& string_join(std::ostream& out, const std::vector<T>& v, const std::string& delimiter) {
|
||||
size_t i = 0;
|
||||
for (const T& e : v) {
|
||||
if ((i++) > 0) { out << delimiter; }
|
||||
out << e;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string string_join(const std::vector<T>& v, const std::string& delimiter) {
|
||||
std::ostringstream joined;
|
||||
string_join(joined, v, delimiter);
|
||||
return joined.str();
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Returns true if str starts with prefix
|
||||
*/
|
||||
inline bool startswith(const std::string& str, const std::string& prefix) {
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
#pragma once
|
||||
|
||||
#include "torch/csrc/lazy/backend/backend_device.h"
|
||||
#include "torch/csrc/lazy/core/tensor.h"
|
||||
|
||||
#include "../ops/device_data.h"
|
||||
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
inline torch::lazy::DeviceData* device_data_cast(
|
||||
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
|
||||
) {
|
||||
if (!device) {
|
||||
device = torch::lazy::GetBackendDevice(tensor);
|
||||
}
|
||||
TORCH_CHECK(device);
|
||||
torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
|
||||
if (lazy_tensor) {
|
||||
torch::lazy::Value param_value = lazy_tensor->GetIrValue();
|
||||
if (param_value && param_value->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||
return dynamic_cast<torch::lazy::DeviceData*>(param_value.node.get());
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -18,6 +18,7 @@
|
|||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/debug.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/exception.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
||||
|
||||
#include "backend_impl.h"
|
||||
|
||||
|
@ -88,6 +89,8 @@ public:
|
|||
auto mlir_computation =
|
||||
static_cast<TorchMlirComputation*>(computation.get());
|
||||
|
||||
int num_inputs = 0;
|
||||
|
||||
// Vendor backend specific execution can be inserted here.
|
||||
//
|
||||
// We don't have a way to execute a computation based on the generated MLIR,
|
||||
|
@ -106,7 +109,17 @@ public:
|
|||
at::Tensor tensor = mlir_data->mlir_info()->tensor;
|
||||
stack.emplace_back(tensor);
|
||||
}
|
||||
|
||||
// count number of inputs
|
||||
auto name = mlir_data->mlir_info()->name;
|
||||
if (startswith(name, "input_")) {
|
||||
// Printing tensor name for testing purposes
|
||||
std::cout << "Input tensor: " << name << std::endl;
|
||||
++num_inputs;
|
||||
}
|
||||
}
|
||||
// Printing number of input tensors for testing purposes
|
||||
std::cout << num_inputs << " input tensors found" << std::endl;
|
||||
graph_executor.run(stack);
|
||||
std::vector<torch::lazy::BackendDataPtr> results;
|
||||
for (torch::jit::IValue component : stack) {
|
||||
|
|
|
@ -11,7 +11,9 @@
|
|||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||
|
||||
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/string_utils.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h>
|
||||
#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h>
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
|
@ -73,6 +75,15 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) {
|
|||
torch::lazy::GetLatestComputation().get());
|
||||
return py::cast(computation);
|
||||
});
|
||||
m.def("set_parameter_name",
|
||||
[](const at::Tensor& tensor, const std::string& name) -> bool {
|
||||
torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor);
|
||||
if (ir_node) {
|
||||
ir_node->SetName(name);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
m.def("_initialize", []() {
|
||||
NoGilSection gil;
|
||||
Initialize();
|
||||
|
|
10
setup.py
10
setup.py
|
@ -20,6 +20,10 @@
|
|||
# prevent this script from attempting to build the directory, and will simply
|
||||
# use the (presumed already built) directory as-is.
|
||||
#
|
||||
# By default the lazy tensor backend is disabled and not built to avoid conflicts
|
||||
# with the out-of-tree build. To enable it, set the TORCH_MLIR_ENABLE_LTC
|
||||
# environment variable to 1.
|
||||
#
|
||||
# The package version can be set with the TORCH_MLIR_PYTHON_PACKAGE_VERSION
|
||||
# environment variable. For example, this can be "20220330.357" for a snapshot
|
||||
# release on 2022-03-30 with build number 357.
|
||||
|
@ -64,7 +68,6 @@ class CMakeBuild(build_py):
|
|||
python_package_dir = os.path.join(cmake_build_dir,
|
||||
"tools", "torch-mlir", "python_packages",
|
||||
"torch_mlir")
|
||||
|
||||
if not os.getenv("TORCH_MLIR_CMAKE_BUILD_DIR_ALREADY_BUILT"):
|
||||
src_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
llvm_dir = os.path.join(
|
||||
|
@ -83,6 +86,11 @@ class CMakeBuild(build_py):
|
|||
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
|
||||
f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
|
||||
]
|
||||
# TODO: Enable LTC by default once JIT importer linkage issue is fixed (https://github.com/llvm/torch-mlir/issues/1154)
|
||||
enable_ltc = bool(int(os.environ.get("TORCH_MLIR_ENABLE_LTC", 0)))
|
||||
if not enable_ltc:
|
||||
cmake_args.append("-DTORCH_MLIR_ENABLE_LTC=OFF")
|
||||
|
||||
os.makedirs(cmake_build_dir, exist_ok=True)
|
||||
cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt")
|
||||
if os.path.exists(cmake_cache_file):
|
||||
|
|
Loading…
Reference in New Issue