Torch-MLIR LTC Backend Lowering Codegen (#621)

* Codegen and build LTC lowering

* Add LazyShapeInference header
pull/1125/head
Jae Hoon (Antonio) Kim 2022-02-25 19:50:09 -05:00 committed by Henry Tu
parent 2f22e2ef40
commit 58338f79a1
9 changed files with 78 additions and 384 deletions

13
.gitignore vendored
View File

@ -22,3 +22,16 @@ __pycache__
# Bazel
bazel-*
# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/backend/LazyLazyIr.h
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/backend/LazyShapeInference.cpp
/python/torch_mlir/csrc/backend/RegisterLazy.cpp
# Libraries
*.so
*.a

View File

@ -20,9 +20,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(torch_mlir_ltc_backend SHARED
backend/backend_impl.cc
backend/mlir_lowering_context.cc
backend/mlir_node.cc
backend/aten_eager_fallback.cpp
backend/aten_ltc_mlir_type.cpp
backend/backend_impl.cpp
backend/LazyNativeFunctions.cpp
backend/LazyShapeInference.cpp
backend/mlir_lowering_context.cpp
backend/mlir_node.cpp
backend/RegisterLazy.cpp
)
target_link_libraries(torch_mlir_ltc_backend
@ -32,12 +37,13 @@ target_link_libraries(torch_mlir_ltc_backend
torch_python
)
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wpedantic")
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(torch_mlir_ltc_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
OUTPUT_NAME _MLIR_LTC
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wpedantic"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
)

View File

@ -1,159 +0,0 @@
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>
#include "backend_impl.h"
#include "mlir_lowering_context.h"
#include "../utils/exception.h"
namespace torch {
namespace lazy {
struct MlirBackendData::Info : public BackendData::Info {
at::Tensor tensor;
c10::optional<at::Scalar> scalar;
Info() {}
Info(const Info& other) :
tensor{other.tensor}, scalar{other.scalar} {}
Info(const at::Tensor& tensor) : tensor{tensor} {}
Info(const at::Scalar& scalar) : scalar{scalar} {}
};
MlirBackendData::MlirBackendData(BackendDevice device, Shape shape) :
BackendData(device, shape) {
auto info = std::make_shared<MlirBackendData::Info>();
SetInfo(info);
}
MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device) :
BackendData(device, torch::lazy::Shape(scalar.type(), {})) {
auto info = std::make_shared<MlirBackendData::Info>(scalar);
SetInfo(info);
}
MlirBackendData::MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape) :
BackendData(device, shape) {
auto info = std::make_shared<MlirBackendData::Info>(tensor);
SetInfo(info);
}
BackendData::Handle MlirBackendData::GetHandle() { return reinterpret_cast<int64_t>(this); }
void MlirBackendData::Assign(const BackendData& data) {
MlirBackendData::Info* info =
dynamic_cast<MlirBackendData::Info*>(data.info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
);
auto new_info = std::make_shared<MlirBackendData::Info>(*info);
SetInfo(new_info);
}
bool MlirBackendData::HasValue() const {
return bool(info());
}
/**
* Initialization/Teardown
* */
void MlirBackendImpl::PrepareToExit() const {}
/**
* Data Transfer
* */
BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor(
const at::Tensor& tensor,
const Shape& shape,
const BackendDevice& device
) const {
return std::make_shared<MlirBackendData>(tensor, device, shape);
}
BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar(
const at::Scalar& scalar,
const torch::lazy::BackendDevice& device
) const {
return std::make_shared<MlirBackendData>(scalar, device);
}
BackendDataPtr MlirBackendImpl::CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape
) const {
return std::make_shared<MlirBackendData>(device, shape);
}
at::Tensor MlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type
) const {
MlirBackendData::Info* info =
dynamic_cast<MlirBackendData::Info*>(data->info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."
);
return info->tensor;
}
/**
* Lowering, Compilation, Execution
* */
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
const std::string& name,
BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
Util::EmissionMap emit_status
) const {
return std::make_unique<MlirLoweringContext>(
name,
std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)
);
}
std::unique_ptr<LoweringContext> MlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device
) const {
return std::make_unique<MlirLoweringContext>(
name, std::forward<BackendDevice>(device)
);
}
/**
* Device Configuration
* */
// Set or get the default device type.
// For backends used with virtual c10:: Devices, this configures what real
// device type the backend should use, and matters if the backend supports
// more than one type of real device.
// Specify which aten device should be used for eager fallback
// may change depending on current 'Default' DeviceType
at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const {
return at::DeviceType::CPU;
}
// Query all available backend devices
std::vector<BackendDevice> MlirBackendImpl::GetBackendDevices() const {
return {
GetBackendDevice(c10::Device(c10::kCPU, 0)),
GetBackendDevice(c10::Device(c10::kLazy, 0))
};
}
// Map a particular c10:: device to a concrete backend device
// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are
// virtual devices, meaning they may map to a gpu, tpu, etc. behind the
// scenes. In the future, non-virtual c10:: devices may also use lazy tensors
// through a mode, in which case these APIs should still work, but should be
// identity mappings.
BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const {
return torch::lazy::BackendDevice(GetDefaultDeviceType(), device.index());
}
} // lazy
} // torch

View File

@ -1,3 +1,18 @@
//===- backend_impl.h -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// The Torch-MLIR backend class API that handles lowering LTC ATen ops to MLIR
// using the Torch-MLIR ATen dialect
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/lazy/backend/backend_data.h>

View File

@ -1,93 +0,0 @@
#include <iostream>
#include "mlir_lowering_context.h"
#include "../utils/exception.h"
namespace torch {
namespace lazy {
MlirLoweringContext::MlirLoweringContext(
const std::string& name, BackendDevice device
) : LoweringContext(name, std::forward<BackendDevice>(device)) {}
MlirLoweringContext::MlirLoweringContext(
const std::string& name,
BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
Util::EmissionMap emit_status
) : LoweringContext(
name,
std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)
) {}
int MlirComputation::parameters_size() const {
UNIMPLEMENTED_ERROR("MlirComputation::parameters_size");
}
const std::vector<torch::lazy::Shape>& MlirComputation::parameter_shapes() const {
UNIMPLEMENTED_ERROR("MlirComputation::parameter_shapes");
}
const std::vector<std::string>& MlirComputation::parameter_names() const {
UNIMPLEMENTED_ERROR("MlirComputation::parameter_names");
}
const torch::lazy::Shape& MlirComputation::result_shape() const {
UNIMPLEMENTED_ERROR("MlirComputation::result_shape");
}
// Get the shape of the result tuple component, given by index.
torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const {
UNIMPLEMENTED_ERROR("MlirLoweringContext::GetResultShape( " << index << " )");
}
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) {
const torch::lazy::Node* node;
auto it = emitted_outputs_.find(output);
if (it == emitted_outputs_.end()) {
node = output.node;
auto post_order = Util::ComputePostOrder(node, &emit_status_);
for (auto po_node : post_order) {
// TODO: uncomment after lowering is implemented
// bool ok = lowering_->Lower(node);
// TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
}
emitted_outputs_[output] = node;
} else {
node = it->second;
}
result_tuple_.emplace_back(node);
return result_tuple_.size() - 1;
}
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void MlirLoweringContext::AddParameter(
const torch::lazy::Output& output,
size_t index,
const torch::lazy::Shape& shape,
const std::string& name
) {
UNIMPLEMENTED_ERROR("MlirLoweringContext::AddParameter");
}
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
ComputationPtr MlirLoweringContext::Build() {
for (const torch::lazy::Node* output : result_tuple_) {
}
return std::make_shared<MlirComputation>();
}
} // namespace lazy
} // namespace torch

View File

@ -1,9 +1,23 @@
//===- mlir_lowering_context.h --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h
//===----------------------------------------------------------------------===//
#pragma once
#include <vector>
#include <torch/csrc/lazy/backend/lowering_context.h>
namespace torch {
namespace lazy {

View File

@ -1,124 +0,0 @@
#include <torch/csrc/lazy/core/cache.h>
#include "mlir_node.h"
#include "../utils/exception.h"
namespace torch {
namespace lazy {
namespace {
hash_t OperandHashes(const OpList& operands, const hash_t& seed, const bool bakeInSizes) {
hash_t hash = seed;
for (auto& operand : operands) {
if (!operand) {
hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
continue;
}
auto operand_hash = bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes();
hash = HashCombine(hash, operand_hash);
}
return hash;
}
hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) {
hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes));
return HashCombine(h, hash_seed);
}
} // namespace
MlirNode::MlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes,
size_t num_outputs, hash_t hash_seed
) : Node(
op, num_outputs,
/* node_hash */ HashCombine(op.hash(), hash_seed),
/* dag_hash */
[&](bool bakeInSizes) -> hash_t {
return OperandHashes(operands, HashCombine(op.hash(), hash_seed), bakeInSizes);
}
),
shapes_(std::move(shapes)) {
for (auto& operand : operands) {
// Ideally, optional operands should be filtered by the leaf node classes,
// but it's just much easier to do it here.
if (!operand) {
continue;
}
AddOperand(operand.node, operand.index);
}
}
MlirNode::MlirNode(
OpKind op, OpList operands,
const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed
) : MlirNode(
op, operands, std::vector<Shape>{}, num_outputs, hash_seed
) {
shapes_.push_back(GetOpShape(shape_fn));
}
MlirNode::MlirNode(
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed
) : MlirNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
void MlirNode::SetShapeDeferred(
const std::function<Shape()>& shape_fn
) {
shapes_.push_back(GetOpShape(shape_fn));
}
MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: Node(
op, num_outputs,
[&](bool bakeInSizes) -> hash_t {
return GetOpHash(op, shape, hash_seed, bakeInSizes);
}
) {
shapes_.push_back(std::move(shape));
}
using ShapeCache = Cache<hash_t, Shape, HashReducer>;
constexpr const int torch_lazy_shape_cache_size = 4096;
ShapeCache* GetShapeCache() {
static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size);
return cache;
}
Shape MlirNode::GetOpShape(const std::function<Shape()>& shape_fn) const {
ShapeCache* shape_cache = GetShapeCache();
auto shape = shape_cache->Get(hash());
if (shape == nullptr) {
shape = shape_cache->Add(
hash(), std::make_shared<Shape>(shape_fn())
);
}
return *shape;
}
const std::vector<Output>& MlirNode::operands() const {
return operands_as_outputs_;
}
const Output& MlirNode::operand(size_t i) const {
return operands_as_outputs_.at(i);
}
void MlirNode::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_as_outputs_.emplace_back(operands_.back().get(), index);
}
} // namespace lazy
} // namespace torch

View File

@ -1,3 +1,15 @@
//===- mlir_node.h --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.h
//===----------------------------------------------------------------------===//
#pragma once
#include <ATen/core/interned_strings.h>
@ -5,6 +17,7 @@
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/ir.h>
#include "aten_eager_fallback.h"
#include "mlir_lowering_context.h"
#include "../utils/exception.h"

View File

@ -1,3 +1,12 @@
//===- exception.h --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <exception>