diff --git a/.gitignore b/.gitignore index dc506413e..5adbf95d6 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,16 @@ __pycache__ # Bazel bazel-* + +# Autogenerated files +/generated_native_functions.yaml +/generated_backend.hash +/python/torch_mlir/csrc/backend/LazyLazyIr.h +/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp +/python/torch_mlir/csrc/backend/LazyNativeFunctions.h +/python/torch_mlir/csrc/backend/LazyShapeInference.cpp +/python/torch_mlir/csrc/backend/RegisterLazy.cpp + +# Libraries +*.so +*.a diff --git a/python/torch_mlir/csrc/CMakeLists.txt b/python/torch_mlir/csrc/CMakeLists.txt index 05f34040c..a7ca1dba3 100644 --- a/python/torch_mlir/csrc/CMakeLists.txt +++ b/python/torch_mlir/csrc/CMakeLists.txt @@ -20,9 +20,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(torch_mlir_ltc_backend SHARED - backend/backend_impl.cc - backend/mlir_lowering_context.cc - backend/mlir_node.cc + backend/aten_eager_fallback.cpp + backend/aten_ltc_mlir_type.cpp + backend/backend_impl.cpp + backend/LazyNativeFunctions.cpp + backend/LazyShapeInference.cpp + backend/mlir_lowering_context.cpp + backend/mlir_node.cpp + backend/RegisterLazy.cpp ) target_link_libraries(torch_mlir_ltc_backend @@ -32,12 +37,13 @@ target_link_libraries(torch_mlir_ltc_backend torch_python ) -message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wpedantic") +message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") set_target_properties(torch_mlir_ltc_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/" OUTPUT_NAME _MLIR_LTC PREFIX "${PYTHON_MODULE_PREFIX}" SUFFIX "${PYTHON_MODULE_EXTENSION}" CXX_VISIBILITY_PRESET "hidden" - COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wpedantic" + COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" ) + diff --git a/python/torch_mlir/csrc/backend/backend_impl.cc b/python/torch_mlir/csrc/backend/backend_impl.cc deleted file mode 100644 index 96ebbcda9..000000000 --- a/python/torch_mlir/csrc/backend/backend_impl.cc +++ /dev/null @@ -1,159 +0,0 @@ -#include -#include -#include -#include - -#include "backend_impl.h" -#include "mlir_lowering_context.h" -#include "../utils/exception.h" - -namespace torch { -namespace lazy { - -struct MlirBackendData::Info : public BackendData::Info { - at::Tensor tensor; - c10::optional scalar; - - Info() {} - Info(const Info& other) : - tensor{other.tensor}, scalar{other.scalar} {} - Info(const at::Tensor& tensor) : tensor{tensor} {} - Info(const at::Scalar& scalar) : scalar{scalar} {} -}; - -MlirBackendData::MlirBackendData(BackendDevice device, Shape shape) : - BackendData(device, shape) { - auto info = std::make_shared(); - SetInfo(info); -} -MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device) : - BackendData(device, torch::lazy::Shape(scalar.type(), {})) { - auto info = std::make_shared(scalar); - SetInfo(info); -} -MlirBackendData::MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape) : - BackendData(device, shape) { - auto info = std::make_shared(tensor); - SetInfo(info); -} - -BackendData::Handle MlirBackendData::GetHandle() { return reinterpret_cast(this); } - -void MlirBackendData::Assign(const BackendData& data) { - MlirBackendData::Info* info = - dynamic_cast(data.info()); - TORCH_CHECK( - info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info." - ); - auto new_info = std::make_shared(*info); - SetInfo(new_info); -} - -bool MlirBackendData::HasValue() const { - return bool(info()); -} - -/** - * Initialization/Teardown - * */ -void MlirBackendImpl::PrepareToExit() const {} - -/** - * Data Transfer - * */ - -BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor( - const at::Tensor& tensor, - const Shape& shape, - const BackendDevice& device -) const { - return std::make_shared(tensor, device, shape); -} - -BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar( - const at::Scalar& scalar, - const torch::lazy::BackendDevice& device -) const { - return std::make_shared(scalar, device); -} - -BackendDataPtr MlirBackendImpl::CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape -) const { - return std::make_shared(device, shape); -} - -at::Tensor MlirBackendImpl::MakeTensorFromComputationData( - const BackendDataPtr data, - c10::optional logical_scalar_type -) const { - MlirBackendData::Info* info = - dynamic_cast(data->info()); - TORCH_CHECK( - info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info." - ); - return info->tensor; -} - -/** - * Lowering, Compilation, Execution - * */ - -std::unique_ptr MlirBackendImpl::CreateLoweringContext( - const std::string& name, - BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status -) const { - return std::make_unique( - name, - std::forward(device), - std::forward>(post_order), - std::forward(emit_status) - ); -} - -std::unique_ptr MlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device -) const { - return std::make_unique( - name, std::forward(device) - ); -} - -/** - * Device Configuration - * */ - -// Set or get the default device type. -// For backends used with virtual c10:: Devices, this configures what real -// device type the backend should use, and matters if the backend supports -// more than one type of real device. - -// Specify which aten device should be used for eager fallback -// may change depending on current 'Default' DeviceType -at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const { - return at::DeviceType::CPU; -} - - -// Query all available backend devices -std::vector MlirBackendImpl::GetBackendDevices() const { - return { - GetBackendDevice(c10::Device(c10::kCPU, 0)), - GetBackendDevice(c10::Device(c10::kLazy, 0)) - }; -} - -// Map a particular c10:: device to a concrete backend device -// Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are -// virtual devices, meaning they may map to a gpu, tpu, etc. behind the -// scenes. In the future, non-virtual c10:: devices may also use lazy tensors -// through a mode, in which case these APIs should still work, but should be -// identity mappings. -BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const { - return torch::lazy::BackendDevice(GetDefaultDeviceType(), device.index()); -} - -} // lazy -} // torch diff --git a/python/torch_mlir/csrc/backend/backend_impl.h b/python/torch_mlir/csrc/backend/backend_impl.h index 52055f460..d016c3670 100644 --- a/python/torch_mlir/csrc/backend/backend_impl.h +++ b/python/torch_mlir/csrc/backend/backend_impl.h @@ -1,3 +1,18 @@ +//===- backend_impl.h -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// The Torch-MLIR backend class API that handles lowering LTC ATen ops to MLIR +// using the Torch-MLIR ATen dialect +// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h +//===----------------------------------------------------------------------===// + #pragma once #include diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.cc b/python/torch_mlir/csrc/backend/mlir_lowering_context.cc deleted file mode 100644 index e122843de..000000000 --- a/python/torch_mlir/csrc/backend/mlir_lowering_context.cc +++ /dev/null @@ -1,93 +0,0 @@ -#include - -#include "mlir_lowering_context.h" -#include "../utils/exception.h" - - -namespace torch { -namespace lazy { - -MlirLoweringContext::MlirLoweringContext( - const std::string& name, BackendDevice device -) : LoweringContext(name, std::forward(device)) {} - -MlirLoweringContext::MlirLoweringContext( - const std::string& name, - BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status -) : LoweringContext( - name, - std::forward(device), - std::forward>(post_order), - std::forward(emit_status) -) {} - -int MlirComputation::parameters_size() const { - UNIMPLEMENTED_ERROR("MlirComputation::parameters_size"); -} - -const std::vector& MlirComputation::parameter_shapes() const { - UNIMPLEMENTED_ERROR("MlirComputation::parameter_shapes"); -} - -const std::vector& MlirComputation::parameter_names() const { - UNIMPLEMENTED_ERROR("MlirComputation::parameter_names"); -} - -const torch::lazy::Shape& MlirComputation::result_shape() const { - UNIMPLEMENTED_ERROR("MlirComputation::result_shape"); -} - - -// Get the shape of the result tuple component, given by index. -torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const { - UNIMPLEMENTED_ERROR("MlirLoweringContext::GetResultShape( " << index << " )"); -} - -// Adds the given output as a component of the result tuple and returns its -// assigned position within the tuple. -size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) { - const torch::lazy::Node* node; - auto it = emitted_outputs_.find(output); - if (it == emitted_outputs_.end()) { - node = output.node; - - auto post_order = Util::ComputePostOrder(node, &emit_status_); - for (auto po_node : post_order) { - // TODO: uncomment after lowering is implemented - // bool ok = lowering_->Lower(node); - // TORCH_CHECK(ok, "Failed to lower: ", node->ToString()); - } - emitted_outputs_[output] = node; - } else { - node = it->second; - } - result_tuple_.emplace_back(node); - return result_tuple_.size() - 1; -} - -// Associates the given output with the input parameter of the given index and -// shape. Only used for the operator-by-operator execution, mostly for -// debugging purposes. -void MlirLoweringContext::AddParameter( - const torch::lazy::Output& output, - size_t index, - const torch::lazy::Shape& shape, - const std::string& name -) { - UNIMPLEMENTED_ERROR("MlirLoweringContext::AddParameter"); -} - -// Build the computation capturing all the operations created with the -// embedded builder (returned by the builder() API). -ComputationPtr MlirLoweringContext::Build() { - for (const torch::lazy::Node* output : result_tuple_) { - - } - return std::make_shared(); -} - - -} // namespace lazy -} // namespace torch diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.h b/python/torch_mlir/csrc/backend/mlir_lowering_context.h index 0fac168ec..6ba3034a3 100644 --- a/python/torch_mlir/csrc/backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/backend/mlir_lowering_context.h @@ -1,9 +1,23 @@ +//===- mlir_lowering_context.h --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h +//===----------------------------------------------------------------------===// + + #pragma once #include #include + namespace torch { namespace lazy { @@ -38,9 +52,9 @@ class MlirLoweringContext : public torch::lazy::LoweringContext { // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. virtual void AddParameter(const torch::lazy::Output& output, - size_t index, - const torch::lazy::Shape& shape, - const std::string& name) override; + size_t index, + const torch::lazy::Shape& shape, + const std::string& name) override; // Build the computation capturing all the operations created with the // embedded builder (returned by the builder() API). diff --git a/python/torch_mlir/csrc/backend/mlir_node.cc b/python/torch_mlir/csrc/backend/mlir_node.cc deleted file mode 100644 index 0db1b5898..000000000 --- a/python/torch_mlir/csrc/backend/mlir_node.cc +++ /dev/null @@ -1,124 +0,0 @@ -#include - -#include "mlir_node.h" -#include "../utils/exception.h" - - -namespace torch { -namespace lazy { - -namespace { - -hash_t OperandHashes(const OpList& operands, const hash_t& seed, const bool bakeInSizes) { - hash_t hash = seed; - for (auto& operand : operands) { - if (!operand) { - hash = HashCombine(hash, static_cast(kNullOpt)); - continue; - } - auto operand_hash = bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes(); - hash = HashCombine(hash, operand_hash); - } - return hash; -} - -hash_t GetOpHash(OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) { - hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes)); - return HashCombine(h, hash_seed); -} - -} // namespace - - -MlirNode::MlirNode( - OpKind op, OpList operands, std::vector&& shapes, - size_t num_outputs, hash_t hash_seed -) : Node( - op, num_outputs, - /* node_hash */ HashCombine(op.hash(), hash_seed), - /* dag_hash */ - [&](bool bakeInSizes) -> hash_t { - return OperandHashes(operands, HashCombine(op.hash(), hash_seed), bakeInSizes); - } - ), - shapes_(std::move(shapes)) { - - for (auto& operand : operands) { - // Ideally, optional operands should be filtered by the leaf node classes, - // but it's just much easier to do it here. - if (!operand) { - continue; - } - - AddOperand(operand.node, operand.index); - } -} - -MlirNode::MlirNode( - OpKind op, OpList operands, - const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed -) : MlirNode( - op, operands, std::vector{}, num_outputs, hash_seed - ) { - shapes_.push_back(GetOpShape(shape_fn)); -} - -MlirNode::MlirNode( - OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed -) : MlirNode(op, operands, std::vector{}, num_outputs, hash_seed) {} - -void MlirNode::SetShapeDeferred( - const std::function& shape_fn -) { - shapes_.push_back(GetOpShape(shape_fn)); -} - -MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) - : Node( - op, num_outputs, - [&](bool bakeInSizes) -> hash_t { - return GetOpHash(op, shape, hash_seed, bakeInSizes); - } - ) { - shapes_.push_back(std::move(shape)); -} - - -using ShapeCache = Cache; - -constexpr const int torch_lazy_shape_cache_size = 4096; - -ShapeCache* GetShapeCache() { - static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size); - return cache; -} - -Shape MlirNode::GetOpShape(const std::function& shape_fn) const { - ShapeCache* shape_cache = GetShapeCache(); - auto shape = shape_cache->Get(hash()); - if (shape == nullptr) { - shape = shape_cache->Add( - hash(), std::make_shared(shape_fn()) - ); - } - return *shape; -} - - -const std::vector& MlirNode::operands() const { - return operands_as_outputs_; -} - -const Output& MlirNode::operand(size_t i) const { - return operands_as_outputs_.at(i); -} - -void MlirNode::AddOperand(NodePtr node, size_t index) { - CHECK_LT(index, node->num_outputs()); - operands_.push_back(std::move(node)); - operands_as_outputs_.emplace_back(operands_.back().get(), index); -} - -} // namespace lazy -} // namespace torch diff --git a/python/torch_mlir/csrc/backend/mlir_node.h b/python/torch_mlir/csrc/backend/mlir_node.h index 48b70fe28..91be16ef8 100644 --- a/python/torch_mlir/csrc/backend/mlir_node.h +++ b/python/torch_mlir/csrc/backend/mlir_node.h @@ -1,3 +1,15 @@ +//===- mlir_node.h --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// This file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.h +//===----------------------------------------------------------------------===// + #pragma once #include @@ -5,6 +17,7 @@ #include #include +#include "aten_eager_fallback.h" #include "mlir_lowering_context.h" #include "../utils/exception.h" diff --git a/python/torch_mlir/csrc/utils/exception.h b/python/torch_mlir/csrc/utils/exception.h index a9dafdcfd..5f5d790b0 100644 --- a/python/torch_mlir/csrc/utils/exception.h +++ b/python/torch_mlir/csrc/utils/exception.h @@ -1,3 +1,12 @@ +//===- exception.h --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + #pragma once #include