mirror of https://github.com/llvm/torch-mlir
Add initial LTC backend (#610)
* Add initial LTC backend skeleton * Disable CI build and move TorchMLIRPyTorch.cmakepull/1125/head
parent
8b5631d4c5
commit
2f22e2ef40
|
@ -5,6 +5,8 @@ on:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
|
@ -56,6 +56,12 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
||||||
LLVMSupport
|
LLVMSupport
|
||||||
)
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Lazy Tensor Core
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
add_subdirectory(torch_mlir/csrc)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Optionally handle JIT IR importer.
|
# Optionally handle JIT IR importer.
|
||||||
################################################################################
|
################################################################################
|
||||||
|
@ -147,5 +153,8 @@ endif()
|
||||||
# TODO: Add after macOS builds are fixed
|
# TODO: Add after macOS builds are fixed
|
||||||
#add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)
|
#add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)
|
||||||
|
|
||||||
|
# Add Torch-MLIR LTC backend as dependency
|
||||||
|
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
|
||||||
|
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
# Setup PyTorch/LTC
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
|
||||||
|
include(TorchMLIRPyTorch)
|
||||||
|
TorchMLIRProbeForPyTorchInstall()
|
||||||
|
find_package(Torch 1.11 REQUIRED)
|
||||||
|
|
||||||
|
TorchMLIRConfigurePyTorch()
|
||||||
|
|
||||||
|
include_directories(BEFORE
|
||||||
|
${TORCH_INCLUDE_DIRS}
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}
|
||||||
|
${Python3_INCLUDE_DIRS}
|
||||||
|
)
|
||||||
|
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
|
|
||||||
|
|
||||||
|
add_library(torch_mlir_ltc_backend SHARED
|
||||||
|
backend/backend_impl.cc
|
||||||
|
backend/mlir_lowering_context.cc
|
||||||
|
backend/mlir_node.cc
|
||||||
|
)
|
||||||
|
|
||||||
|
target_link_libraries(torch_mlir_ltc_backend
|
||||||
|
TorchMLIRAggregateCAPI
|
||||||
|
${TORCH_LIBRARIES}
|
||||||
|
${Python3_LIBRARIES}
|
||||||
|
torch_python
|
||||||
|
)
|
||||||
|
|
||||||
|
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wpedantic")
|
||||||
|
set_target_properties(torch_mlir_ltc_backend PROPERTIES
|
||||||
|
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
|
||||||
|
OUTPUT_NAME _MLIR_LTC
|
||||||
|
PREFIX "${PYTHON_MODULE_PREFIX}"
|
||||||
|
SUFFIX "${PYTHON_MODULE_EXTENSION}"
|
||||||
|
CXX_VISIBILITY_PRESET "hidden"
|
||||||
|
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wpedantic"
|
||||||
|
)
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Torch-MLIR Lazy Tensor Core Backend
|
||||||
|
|
||||||
|
Contained within this directory are the components that implement the
|
||||||
|
Torch-MLIR LTC backend.
|
||||||
|
|
||||||
|
The components are subclasses of the backend API interface classes found under
|
||||||
|
[torch/csrc/lazy/backend](https://github.com/pytorch/pytorch/tree/master/torch/csrc/lazy/backend).
|
||||||
|
|
||||||
|
Importantly, the subclasses are still abstract classes. Pure virtual methods
|
||||||
|
such as `Compile` were purposefully not overridden as Torch-MLIR does not know
|
||||||
|
how to compile the model for the target hardware.
|
||||||
|
|
||||||
|
The intent is that vendor hardware specific plugins will subclass the Torch-MLIR
|
||||||
|
backend classes and override the remaining pure virtual functions to complete
|
||||||
|
the backend.
|
||||||
|
|
||||||
|
The Torch-MLIR LTC backend's job is to perform the lowering from ATen to MLIR. A
|
||||||
|
hardware vendor's backend job is to take care of the actual compile and
|
||||||
|
execution of the lowered MLIR.
|
|
@ -0,0 +1,159 @@
|
||||||
|
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||||
|
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||||
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
|
#include "backend_impl.h"
|
||||||
|
#include "mlir_lowering_context.h"
|
||||||
|
#include "../utils/exception.h"
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
// Payload stored on an MlirBackendData. It carries a tensor and an optional
// scalar; which of the two is populated depends on the constructor used.
struct MlirBackendData::Info : public BackendData::Info {
  at::Tensor tensor;
  c10::optional<at::Scalar> scalar;

  // Leaves both members default-constructed.
  Info() = default;

  // Copies the tensor and the optional scalar from `other`.
  Info(const Info& other) : tensor(other.tensor), scalar(other.scalar) {}

  // Wraps a tensor; the scalar stays empty.
  Info(const at::Tensor& tensor) : tensor(tensor) {}

  // Wraps a scalar; the tensor stays default-constructed.
  Info(const at::Scalar& scalar) : scalar(scalar) {}
};
|
||||||
|
|
||||||
|
// Builds backend data for `device`/`shape` and attaches a default-constructed
// Info (no tensor, no scalar).
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape)
    : BackendData(device, shape) {
  SetInfo(std::make_shared<MlirBackendData::Info>());
}

// Builds backend data around a scalar. The shape is derived from the scalar's
// type with an empty dimension list.
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device)
    : BackendData(device, torch::lazy::Shape(scalar.type(), {})) {
  SetInfo(std::make_shared<MlirBackendData::Info>(scalar));
}

// Builds backend data around a tensor with an explicitly supplied shape.
MlirBackendData::MlirBackendData(
    const at::Tensor& tensor, BackendDevice device, Shape shape)
    : BackendData(device, shape) {
  SetInfo(std::make_shared<MlirBackendData::Info>(tensor));
}
|
||||||
|
|
||||||
|
// Returns this object's address reinterpreted as an integer handle.
BackendData::Handle MlirBackendData::GetHandle() { return reinterpret_cast<int64_t>(this); }
|
||||||
|
|
||||||
|
void MlirBackendData::Assign(const BackendData& data) {
|
||||||
|
MlirBackendData::Info* info =
|
||||||
|
dynamic_cast<MlirBackendData::Info*>(data.info());
|
||||||
|
TORCH_CHECK(
|
||||||
|
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
|
||||||
|
);
|
||||||
|
auto new_info = std::make_shared<MlirBackendData::Info>(*info);
|
||||||
|
SetInfo(new_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MlirBackendData::HasValue() const {
|
||||||
|
return bool(info());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialization/Teardown
|
||||||
|
* */
|
||||||
|
// No-op: the body is empty.
void MlirBackendImpl::PrepareToExit() const {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data Transfer
|
||||||
|
* */
|
||||||
|
|
||||||
|
// Wraps `tensor` in a new MlirBackendData for the given device and shape.
BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
    const at::Tensor& tensor,
    const Shape& shape,
    const BackendDevice& device) const {
  BackendDataPtr wrapped =
      std::make_shared<MlirBackendData>(tensor, device, shape);
  return wrapped;
}
|
||||||
|
|
||||||
|
// Wraps `scalar` in a new MlirBackendData for the given device.
BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
    const at::Scalar& scalar,
    const torch::lazy::BackendDevice& device) const {
  BackendDataPtr wrapped = std::make_shared<MlirBackendData>(scalar, device);
  return wrapped;
}
|
||||||
|
|
||||||
|
// Creates an MlirBackendData for `device`/`shape` that carries neither a
// tensor nor a scalar.
BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
    const BackendDevice& device, const Shape& shape) const {
  BackendDataPtr placeholder = std::make_shared<MlirBackendData>(device, shape);
  return placeholder;
}
|
||||||
|
|
||||||
|
// Returns the tensor stored in `data`'s MlirBackendData::Info.
// Fails the TORCH_CHECK when `data` carries a different Info type.
// `logical_scalar_type` is accepted but not used by this implementation.
at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
    const BackendDataPtr data,
    c10::optional<at::ScalarType> logical_scalar_type) const {
  auto* mlir_info = dynamic_cast<MlirBackendData::Info*>(data->info());
  TORCH_CHECK(
      mlir_info,
      "Invalid Backend Data Pointer. Expected MlirBackendData::Info.");
  return mlir_info->tensor;
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lowering, Compilation, Execution
|
||||||
|
* */
|
||||||
|
|
||||||
|
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
|
||||||
|
const std::string& name,
|
||||||
|
BackendDevice device,
|
||||||
|
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||||
|
Util::EmissionMap emit_status
|
||||||
|
) const {
|
||||||
|
return std::make_unique<MlirLoweringContext>(
|
||||||
|
name,
|
||||||
|
std::forward<BackendDevice>(device),
|
||||||
|
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
|
||||||
|
std::forward<Util::EmissionMap>(emit_status)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
|
||||||
|
const std::string& name, BackendDevice device
|
||||||
|
) const {
|
||||||
|
return std::make_unique<MlirLoweringContext>(
|
||||||
|
name, std::forward<BackendDevice>(device)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Device Configuration
|
||||||
|
* */
|
||||||
|
|
||||||
|
// Set or get the default device type.
|
||||||
|
// For backends used with virtual c10:: Devices, this configures what real
|
||||||
|
// device type the backend should use, and matters if the backend supports
|
||||||
|
// more than one type of real device.
|
||||||
|
|
||||||
|
// Specify which aten device should be used for eager fallback
|
||||||
|
// may change depending on current 'Default' DeviceType
|
||||||
|
// Always reports CPU as the device type for eager fallback.
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
  constexpr at::DeviceType fallback = at::DeviceType::CPU;
  return fallback;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Query all available backend devices
|
||||||
|
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
|
||||||
|
return {
|
||||||
|
GetBackendDevice(c10::Device(c10::kCPU, 0)),
|
||||||
|
GetBackendDevice(c10::Device(c10::kLazy, 0))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map a particular c10:: device to a concrete backend device
|
||||||
|
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
|
||||||
|
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
|
||||||
|
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
|
||||||
|
// through a mode, in which case these APIs should still work, but should be
|
||||||
|
// identity mappings.
|
||||||
|
// Builds a BackendDevice from the default device type and the index of
// `device`. Only the index of `device` is read here.
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
  torch::lazy::BackendDevice backend_device(
      GetDefaultDeviceType(), device.index());
  return backend_device;
}
|
||||||
|
|
||||||
|
} // lazy
|
||||||
|
} // torch
|
|
@ -0,0 +1,136 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||||
|
#include <torch/csrc/lazy/backend/backend_device.h>
|
||||||
|
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||||
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
class MlirBackendData : public torch::lazy::BackendData {
|
||||||
|
public:
|
||||||
|
struct Info;
|
||||||
|
|
||||||
|
MlirBackendData(torch::lazy::BackendDevice device, torch::lazy::Shape shape);
|
||||||
|
MlirBackendData(const at::Scalar& scalar, torch::lazy::BackendDevice device);
|
||||||
|
MlirBackendData(const at::Tensor& tensor, torch::lazy::BackendDevice device, torch::lazy::Shape shape);
|
||||||
|
|
||||||
|
virtual torch::lazy::BackendData::Handle GetHandle() override;
|
||||||
|
|
||||||
|
virtual void Assign(const torch::lazy::BackendData& data) override;
|
||||||
|
|
||||||
|
virtual bool HasValue() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirBackendImpl : public torch::lazy::BackendImplInterface {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* Initialization/Teardown
|
||||||
|
* */
|
||||||
|
virtual void PrepareToExit() const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration
|
||||||
|
* */
|
||||||
|
// virtual void SetRngSeed(size_t seed) const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data Transfer
|
||||||
|
* */
|
||||||
|
|
||||||
|
virtual torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
|
||||||
|
const at::Tensor& tensor,
|
||||||
|
const torch::lazy::Shape& shape,
|
||||||
|
const torch::lazy::BackendDevice& device
|
||||||
|
) const override;
|
||||||
|
|
||||||
|
virtual torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
|
||||||
|
const at::Scalar& scalar,
|
||||||
|
const torch::lazy::BackendDevice& device
|
||||||
|
) const override;
|
||||||
|
|
||||||
|
virtual torch::lazy::BackendDataPtr CreateDataPlaceholder(
|
||||||
|
const torch::lazy::BackendDevice& device, const torch::lazy::Shape& shape
|
||||||
|
) const override;
|
||||||
|
|
||||||
|
virtual at::Tensor MakeTensorFromComputationData(
|
||||||
|
const torch::lazy::BackendDataPtr data,
|
||||||
|
c10::optional<at::ScalarType> logical_scalar_type
|
||||||
|
) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lowering, Compilation, Execution
|
||||||
|
* */
|
||||||
|
|
||||||
|
virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
|
||||||
|
const std::string& name,
|
||||||
|
torch::lazy::BackendDevice device,
|
||||||
|
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||||
|
torch::lazy::Util::EmissionMap emit_status
|
||||||
|
) const override;
|
||||||
|
|
||||||
|
virtual std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
|
||||||
|
const std::string& name, torch::lazy::BackendDevice device
|
||||||
|
) const override;
|
||||||
|
|
||||||
|
// TODO(whc) need to keep this?
|
||||||
|
// virtual std::vector<std::string> GetCompilationDevices(
|
||||||
|
// const std::string& device, c10::ArrayRef<std::string> devices
|
||||||
|
// ) const = 0;
|
||||||
|
|
||||||
|
// virtual std::vector<torch::lazy::ComputationPtr> Compile(
|
||||||
|
// std::vector<torch::lazy::ComputationPtr> instances
|
||||||
|
// ) const = 0;
|
||||||
|
|
||||||
|
// virtual std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
|
||||||
|
// torch::lazy::Computation& computation,
|
||||||
|
// c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
|
||||||
|
// const torch::lazy::BackendDevice& device
|
||||||
|
// ) const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Device Configuration
|
||||||
|
* */
|
||||||
|
|
||||||
|
// Set or get the default device type.
|
||||||
|
// For backends used with virtual c10:: Devices, this configures what real
|
||||||
|
// device type the backend should use, and matters if the backend supports
|
||||||
|
// more than one type of real device.
|
||||||
|
|
||||||
|
// virtual std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const = 0;
|
||||||
|
// virtual void SetDefaultDeviceType(std::string device_type) = 0;
|
||||||
|
|
||||||
|
// Specify which aten device should be used for eager fallback
|
||||||
|
// may change depending on current 'Default' DeviceType
|
||||||
|
virtual at::DeviceType EagerFallbackDeviceType() const override;
|
||||||
|
|
||||||
|
|
||||||
|
// Query all available backend devices
|
||||||
|
virtual std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;
|
||||||
|
|
||||||
|
// Map a particular c10:: device to a concrete backend device
|
||||||
|
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
|
||||||
|
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
|
||||||
|
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
|
||||||
|
// through a mode, in which case these APIs should still work, but should be
|
||||||
|
// identity mappings.
|
||||||
|
virtual torch::lazy::BackendDevice GetBackendDevice(c10::Device device) const override;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Debug/Metrics
|
||||||
|
* */
|
||||||
|
|
||||||
|
// virtual std::map<std::string, Metric> GetMetrics() const = 0;
|
||||||
|
|
||||||
|
// virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;
|
||||||
|
|
||||||
|
// virtual std::string GetComputationBackendText(
|
||||||
|
// const torch::lazy::ComputationPtr computation
|
||||||
|
// ) const = 0;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
} // lazy
|
||||||
|
} // torch
|
|
@ -0,0 +1,93 @@
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlir_lowering_context.h"
|
||||||
|
#include "../utils/exception.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
// Constructs an empty lowering context for `device`.
// `device` is a by-value parameter, so it is passed to the base class with
// std::move rather than std::forward (which is meant for forwarding
// references only).
MlirLoweringContext::MlirLoweringContext(
    const std::string& name, BackendDevice device)
    : LoweringContext(name, std::move(device)) {}
|
||||||
|
|
||||||
|
// Constructs a lowering context seeded with a post-order and emission map.
// By-value parameters are passed to the base class with std::move rather than
// std::forward (which is meant for forwarding references only); `post_order`
// is a non-owning ArrayRef view and is passed as is.
MlirLoweringContext::MlirLoweringContext(
    const std::string& name,
    BackendDevice device,
    c10::ArrayRef<torch::lazy::Node*> post_order,
    Util::EmissionMap emit_status)
    : LoweringContext(
          name, std::move(device), post_order, std::move(emit_status)) {}
|
||||||
|
|
||||||
|
int MlirComputation::parameters_size() const {
|
||||||
|
UNIMPLEMENTED_ERROR("MlirComputation::parameters_size");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<torch::lazy::Shape>& MlirComputation::parameter_shapes() const {
|
||||||
|
UNIMPLEMENTED_ERROR("MlirComputation::parameter_shapes");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::string>& MlirComputation::parameter_names() const {
|
||||||
|
UNIMPLEMENTED_ERROR("MlirComputation::parameter_names");
|
||||||
|
}
|
||||||
|
|
||||||
|
const torch::lazy::Shape& MlirComputation::result_shape() const {
|
||||||
|
UNIMPLEMENTED_ERROR("MlirComputation::result_shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Get the shape of the result tuple component, given by index.
|
||||||
|
// Not implemented yet: always throws via UNIMPLEMENTED_ERROR, with the
// requested index included in the message.
torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
  UNIMPLEMENTED_ERROR("MlirLoweringContext::GetResultShape( " << index << " )");
}
|
||||||
|
|
||||||
|
// Adds the given output as a component of the result tuple and returns its
|
||||||
|
// assigned position within the tuple.
|
||||||
|
// Appends the node behind `output` to the result tuple and returns its
// position in the tuple.
//
// The first time an output is seen, its post-order is computed and the output
// is recorded in emitted_outputs_. Later calls with the same output reuse the
// recorded node.
size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
  const torch::lazy::Node* node = nullptr;
  auto it = emitted_outputs_.find(output);
  if (it == emitted_outputs_.end()) {
    node = output.node;

    // Still called even though nothing is lowered yet: emit_status_ is passed
    // by pointer to ComputePostOrder.
    auto post_order = Util::ComputePostOrder(node, &emit_status_);
    for (auto po_node : post_order) {
      // TODO: uncomment after lowering is implemented
      // bool ok = lowering_->Lower(po_node);
      // TORCH_CHECK(ok, "Failed to lower: ", po_node->ToString());
      (void)po_node; // unused until the lowering above is enabled
    }
    // The lookup above already established that the key is absent.
    emitted_outputs_.emplace(output, node);
  } else {
    node = it->second;
  }
  result_tuple_.emplace_back(node);
  return result_tuple_.size() - 1;
}
|
||||||
|
|
||||||
|
// Associates the given output with the input parameter of the given index and
|
||||||
|
// shape. Only used for the operator-by-operator execution, mostly for
|
||||||
|
// debugging purposes.
|
||||||
|
void MlirLoweringContext::AddParameter(
|
||||||
|
const torch::lazy::Output& output,
|
||||||
|
size_t index,
|
||||||
|
const torch::lazy::Shape& shape,
|
||||||
|
const std::string& name
|
||||||
|
) {
|
||||||
|
UNIMPLEMENTED_ERROR("MlirLoweringContext::AddParameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the computation capturing all the operations created with the
|
||||||
|
// embedded builder (returned by the builder() API).
|
||||||
|
// Returns a freshly constructed MlirComputation.
//
// TODO: consume result_tuple_ once lowering is implemented. The previous
// placeholder loop over result_tuple_ had an empty body and an unused loop
// variable, so it has been dropped.
ComputationPtr MlirLoweringContext::Build() {
  return std::make_shared<MlirComputation>();
}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace lazy
|
||||||
|
} // namespace torch
|
|
@ -0,0 +1,55 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
class MlirComputation : public torch::lazy::Computation {
|
||||||
|
public:
|
||||||
|
int parameters_size() const override;
|
||||||
|
|
||||||
|
virtual const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
|
||||||
|
|
||||||
|
virtual const std::vector<std::string>& parameter_names() const override;
|
||||||
|
|
||||||
|
virtual const torch::lazy::Shape& result_shape() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MlirLoweringContext : public torch::lazy::LoweringContext {
|
||||||
|
public:
|
||||||
|
|
||||||
|
MlirLoweringContext(const std::string& name, torch::lazy::BackendDevice device);
|
||||||
|
MlirLoweringContext(const std::string& name,
|
||||||
|
torch::lazy::BackendDevice device,
|
||||||
|
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||||
|
torch::lazy::Util::EmissionMap emit_status);
|
||||||
|
|
||||||
|
// Get the shape of the result tuple component, given by index.
|
||||||
|
virtual torch::lazy::Shape GetResultShape(size_t index) const override;
|
||||||
|
|
||||||
|
// Adds the given output as a component of the result tuple and returns its
|
||||||
|
// assigned position within the tuple.
|
||||||
|
virtual size_t AddResult(const torch::lazy::Output& output) override;
|
||||||
|
|
||||||
|
// Associates the given output with the input parameter of the given index and
|
||||||
|
// shape. Only used for the operator-by-operator execution, mostly for
|
||||||
|
// debugging purposes.
|
||||||
|
virtual void AddParameter(const torch::lazy::Output& output,
|
||||||
|
size_t index,
|
||||||
|
const torch::lazy::Shape& shape,
|
||||||
|
const std::string& name) override;
|
||||||
|
|
||||||
|
// Build the computation capturing all the operations created with the
|
||||||
|
// embedded builder (returned by the builder() API).
|
||||||
|
virtual torch::lazy::ComputationPtr Build() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<const torch::lazy::Node*> result_tuple_;
|
||||||
|
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace lazy
|
||||||
|
} // namespace torch
|
|
@ -0,0 +1,124 @@
|
||||||
|
#include <torch/csrc/lazy/core/cache.h>
|
||||||
|
|
||||||
|
#include "mlir_node.h"
|
||||||
|
#include "../utils/exception.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Folds the hash of every operand into `seed`, in order.
// Null operands contribute kNullOpt. Non-null operands contribute their
// with-sizes or without-sizes hash, depending on `bakeInSizes`.
hash_t OperandHashes(const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
  hash_t combined = seed;
  for (const auto& operand : operands) {
    if (operand) {
      auto operand_hash = bakeInSizes ? operand.hash_with_sizes()
                                      : operand.hash_without_sizes();
      combined = HashCombine(combined, operand_hash);
    } else {
      combined = HashCombine(combined, static_cast<uint64_t>(kNullOpt));
    }
  }
  return combined;
}
|
||||||
|
|
||||||
|
// Combines the op hash with the shape hash, then with `hash_seed`.
hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
  const hash_t op_and_shape = HashCombine(op.hash(), shape.hash(bakeInSizes));
  return HashCombine(op_and_shape, hash_seed);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
|
MlirNode::MlirNode(
|
||||||
|
OpKind op, OpList operands, std::vector<Shape>&& shapes,
|
||||||
|
size_t num_outputs, hash_t hash_seed
|
||||||
|
) : Node(
|
||||||
|
op, num_outputs,
|
||||||
|
/* node_hash */ HashCombine(op.hash(), hash_seed),
|
||||||
|
/* dag_hash */
|
||||||
|
[&](bool bakeInSizes) -> hash_t {
|
||||||
|
return OperandHashes(operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
|
||||||
|
}
|
||||||
|
),
|
||||||
|
shapes_(std::move(shapes)) {
|
||||||
|
|
||||||
|
for (auto& operand : operands) {
|
||||||
|
// Ideally, optional operands should be filtered by the leaf node classes,
|
||||||
|
// but it's just much easier to do it here.
|
||||||
|
if (!operand) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
AddOperand(operand.node, operand.index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirNode::MlirNode(
|
||||||
|
OpKind op, OpList operands,
|
||||||
|
const std::function<Shape()>& shape_fn,
|
||||||
|
size_t num_outputs, hash_t hash_seed
|
||||||
|
) : MlirNode(
|
||||||
|
op, operands, std::vector<Shape>{}, num_outputs, hash_seed
|
||||||
|
) {
|
||||||
|
shapes_.push_back(GetOpShape(shape_fn));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delegates to the explicit-shapes constructor with an empty shape list; no
// shape is added here.
MlirNode::MlirNode(
    OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
    : MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
}
|
||||||
|
|
||||||
|
void MlirNode::SetShapeDeferred(
|
||||||
|
const std::function<Shape()>& shape_fn
|
||||||
|
) {
|
||||||
|
shapes_.push_back(GetOpShape(shape_fn));
|
||||||
|
}
|
||||||
|
|
||||||
|
MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
|
||||||
|
: Node(
|
||||||
|
op, num_outputs,
|
||||||
|
[&](bool bakeInSizes) -> hash_t {
|
||||||
|
return GetOpHash(op, shape, hash_seed, bakeInSizes);
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
shapes_.push_back(std::move(shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
using ShapeCache = Cache<hash_t, Shape, HashReducer>;
|
||||||
|
|
||||||
|
constexpr const int torch_lazy_shape_cache_size = 4096;
|
||||||
|
|
||||||
|
ShapeCache* GetShapeCache() {
|
||||||
|
static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
|
||||||
|
return cache;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the shape cached under this node's hash. On a cache miss, `shape_fn`
// is invoked and its result is added to the cache first.
Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
  ShapeCache* cache = GetShapeCache();
  auto cached = cache->Get(hash());
  if (cached != nullptr) {
    return *cached;
  }
  cached = cache->Add(hash(), std::make_shared<Shape>(shape_fn()));
  return *cached;
}
|
||||||
|
|
||||||
|
|
||||||
|
const std::vector<Output>& MlirNode::operands() const {
|
||||||
|
return operands_as_outputs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Output& MlirNode::operand(size_t i) const {
|
||||||
|
return operands_as_outputs_.at(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers output `index` of `node` as an operand. Ownership of the node is
// kept in operands_; operands_as_outputs_ gets a raw-pointer view of it.
void MlirNode::AddOperand(NodePtr node, size_t index) {
  CHECK_LT(index, node->num_outputs());
  operands_.push_back(std::move(node));
  auto* owned_node = operands_.back().get();
  operands_as_outputs_.emplace_back(owned_node, index);
}
|
||||||
|
|
||||||
|
} // namespace lazy
|
||||||
|
} // namespace torch
|
|
@ -0,0 +1,71 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/core/interned_strings.h>
|
||||||
|
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||||
|
#include <torch/csrc/lazy/core/shape.h>
|
||||||
|
#include <torch/csrc/lazy/core/ir.h>
|
||||||
|
|
||||||
|
#include "mlir_lowering_context.h"
|
||||||
|
#include "../utils/exception.h"
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace lazy {
|
||||||
|
|
||||||
|
typedef std::vector<NodePtr> MlirOpVector;
|
||||||
|
typedef NodePtr MlirFunction;
|
||||||
|
|
||||||
|
|
||||||
|
class MlirNode : public torch::lazy::Node {
|
||||||
|
|
||||||
|
public:
|
||||||
|
MlirNode(
|
||||||
|
OpKind op, OpList operands, std::vector<Shape>&& shapes,
|
||||||
|
size_t num_outputs = 1, hash_t hash_seed = kHashSeed
|
||||||
|
);
|
||||||
|
|
||||||
|
// Same as the constructor above, but the shape is generated by a function,
|
||||||
|
// only if needed (shape cache miss).
|
||||||
|
MlirNode(
|
||||||
|
OpKind op, OpList operands,
|
||||||
|
const std::function<Shape()>& shape_fn,
|
||||||
|
size_t num_outputs = 1, hash_t hash_seed = kHashSeed
|
||||||
|
);
|
||||||
|
|
||||||
|
// The shape is set later.
|
||||||
|
MlirNode(
|
||||||
|
OpKind op, OpList operands, size_t num_outputs = 1,
|
||||||
|
hash_t hash_seed = kHashSeed
|
||||||
|
);
|
||||||
|
|
||||||
|
void SetShapeDeferred(const std::function<Shape()>& shape_fn);
|
||||||
|
|
||||||
|
// Contructor used to create leaf nodes.
|
||||||
|
MlirNode(
|
||||||
|
OpKind op, Shape shape, size_t num_outputs = 1, hash_t hash_seed = kHashSeed
|
||||||
|
);
|
||||||
|
|
||||||
|
Shape GetOpShape(const std::function<Shape()>& shape_fn) const;
|
||||||
|
|
||||||
|
const std::vector<Output>& operands() const override;
|
||||||
|
|
||||||
|
const Output& operand(size_t i) const override;
|
||||||
|
|
||||||
|
virtual MlirOpVector Lower(
|
||||||
|
MlirFunction function,
|
||||||
|
MlirLoweringContext* loctx
|
||||||
|
) const = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Adds node's index output number as operand.
|
||||||
|
void AddOperand(NodePtr node, size_t index = 0);
|
||||||
|
|
||||||
|
std::vector<Shape> shapes_;
|
||||||
|
// A node holds a real reference to its operands.
|
||||||
|
std::vector<NodePtr> operands_;
|
||||||
|
// Outputs do not hold references on the nodes, and neither do the uses, since
|
||||||
|
// otherwise we get into circular reference counting.
|
||||||
|
std::vector<Output> operands_as_outputs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace lazy
|
||||||
|
} // namespace torch
|
|
@ -0,0 +1,20 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <exception>
#include <sstream>
#include <stdexcept>
#include <string>
|
||||||
|
|
||||||
|
// Throws std::runtime_error with the text "Unimplemented Error: " followed by
// `msg`. `msg` is spliced into a stream expression, so it may itself chain
// `<<` (e.g. UNIMPLEMENTED_ERROR("f( " << i << " )")); for that reason it is
// deliberately not parenthesized.
//
// The body is wrapped in do { } while (0) so the macro behaves as a single
// statement and `if (c) UNIMPLEMENTED_ERROR(x); else ...` parses correctly.
// Call sites must end the invocation with a semicolon.
#define UNIMPLEMENTED_ERROR(msg)                                               \
  do {                                                                         \
    std::ostringstream unimplemented_error_stream_;                            \
    unimplemented_error_stream_ << "Unimplemented Error: " << msg;             \
    throw std::runtime_error(unimplemented_error_stream_.str());               \
  } while (0)
|
||||||
|
|
||||||
|
|
||||||
|
// Throws std::runtime_error with the text "Unsupported Error: " followed by
// `msg`. `msg` is spliced into a stream expression, so it may itself chain
// `<<`; for that reason it is deliberately not parenthesized.
//
// The body is wrapped in do { } while (0) so the macro behaves as a single
// statement and `if (c) UNSUPPORTED_ERROR(x); else ...` parses correctly.
// Call sites must end the invocation with a semicolon.
#define UNSUPPORTED_ERROR(msg)                                                 \
  do {                                                                         \
    std::ostringstream unsupported_error_stream_;                              \
    unsupported_error_stream_ << "Unsupported Error: " << msg;                 \
    throw std::runtime_error(unsupported_error_stream_.str());                 \
  } while (0)
|
|
@ -2,7 +2,7 @@
|
||||||
# Setup PyTorch
|
# Setup PyTorch
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
|
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
|
||||||
include(TorchMLIRPyTorch)
|
include(TorchMLIRPyTorch)
|
||||||
|
|
||||||
TorchMLIRProbeForPyTorchInstall()
|
TorchMLIRProbeForPyTorchInstall()
|
||||||
|
|
Loading…
Reference in New Issue