Added JIT to MLIR lowering (#724)

* Added JIT to MLIR lowering

Lowering through JIT IR is performed much as it is in the TS LTC backend: after a jit::Graph is constructed, it is converted to a jit::Function, which is fed into torch-mlir's existing JIT IR importer to generate an MlirModule.
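In code, the new pipeline boils down to the following condensed sketch of TorchMlirLoweringContext::Build() (simplified from the full listing later in this diff; shape plumbing and error handling omitted):

    // Condensed sketch: JIT graph -> JIT function -> MLIR func op.
    ComputationPtr TorchMlirLoweringContext::Build() {
      // Register every lowered output on the underlying jit::Graph.
      for (torch::jit::Value* output : root_tuple_) {
        graph_->block()->registerOutput(output);
      }
      // Wrap a COPY of the graph in a jit::Function, since create_function
      // may mutate the graph it receives.
      auto cu = std::make_shared<torch::jit::CompilationUnit>();
      torch::jit::Function* fn =
          cu->create_function(c10::QualifiedName("graph"), graph_->copy());
      // Reuse torch-mlir's existing JIT IR importer to emit the MLIR function.
      MlirOperation func_op =
          torch_mlir::importJitFunctionAsFuncOp(mlir_context_, fn);
      return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
    }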

* Renamed `csrc/backend` to `csrc/base_lazy_backend`
Henry Tu, 2022-04-14 12:40:10 -04:00 (committed by Henry Tu)
parent 65cf1465ef
commit 3e9b1cbd36
19 changed files with 908 additions and 206 deletions

.gitignore

@@ -23,15 +23,15 @@ __pycache__
 # Bazel
 bazel-*
-# Autogenerated files
-/generated_native_functions.yaml
-/generated_backend.hash
-/python/torch_mlir/csrc/backend/LazyIr.h
-/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
-/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
-/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
-/python/torch_mlir/csrc/backend/RegisterLazy.cpp
 # Libraries
 *.so
 *.a
+# Autogenerated files
+/generated_native_functions.yaml
+/generated_backend.hash
+/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
+/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
+/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
+/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
+/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp

build_tools/autogen_ltc_backend.py

@@ -300,7 +300,7 @@ def main(args):
         "generated_native_functions.yaml"
     )
     backend_path = TORCH_MLIR_DIR.joinpath(
-        "python", "torch_mlir", "csrc", "backend"
+        "python", "torch_mlir", "csrc", "base_lazy_backend"
     )
     assert backend_path.is_dir()

python/torch_mlir/csrc/CMakeLists.txt

@@ -20,18 +20,23 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
 add_library(torch_mlir_ltc_backend SHARED
-  backend/backend_impl.cpp
-  backend/LazyNativeFunctions.cpp
-  backend/LazyShapeInference.cpp
-  backend/GenLazyShapeInference.cpp
-  backend/mlir_lowering_context.cpp
-  backend/mlir_native_functions.cpp
-  backend/mlir_node.cpp
-  backend/RegisterLazy.cpp
+  base_lazy_backend/backend_impl.cpp
+  base_lazy_backend/LazyNativeFunctions.cpp
+  base_lazy_backend/LazyShapeInference.cpp
+  base_lazy_backend/GenLazyShapeInference.cpp
+  base_lazy_backend/mlir_lowering_context.cpp
+  base_lazy_backend/mlir_native_functions.cpp
+  base_lazy_backend/mlir_node.cpp
+  base_lazy_backend/mlir_node_lowering.cpp
+  base_lazy_backend/RegisterLazy.cpp
 )
+add_dependencies(torch_mlir_ltc_backend
+  TorchMLIRJITIRImporter
+)
 target_link_libraries(torch_mlir_ltc_backend
   TorchMLIRAggregateCAPI
+  TorchMLIRJITIRImporter
   ${TORCH_LIBRARIES}
   ${Python3_LIBRARIES}
   torch_python

python/torch_mlir/csrc/backend/mlir_lowering_context.cpp (deleted)

@@ -1,107 +0,0 @@
//===- mlir_lowering_context.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
//===----------------------------------------------------------------------===//
#include <iostream>
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "mlir_lowering_context.h"
namespace torch {
namespace lazy {
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)) {}
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
: LoweringContext(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)) {}
int TorchMlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
const std::vector<torch::lazy::Shape>&
TorchMlirComputation::parameter_shapes() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
std::string TorchMlirComputation::to_string() const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Get the shape of the result tuple component, given by index.
torch::lazy::Shape TorchMlirLoweringContext::GetResultShape(size_t index) const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
size_t TorchMlirLoweringContext::AddResult(const torch::lazy::Output& output) {
PRINT_FUNCTION();
const torch::lazy::Node* node;
auto it = emitted_outputs_.find(output);
if (it == emitted_outputs_.end()) {
node = output.node;
auto post_order = Util::ComputePostOrder(node, &emit_status_);
for (auto po_node : post_order) {
// TODO: uncomment after lowering is implemented
// bool ok = lowering_->Lower(node);
// TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
}
emitted_outputs_[output] = node;
} else {
node = it->second;
}
result_tuple_.emplace_back(node);
return result_tuple_.size() - 1;
}
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void TorchMlirLoweringContext::AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION()
for (const torch::lazy::Node* output : result_tuple_) {
}
return std::make_shared<TorchMlirComputation>();
}
// Retrieves the lowered operation for an output. If the requested output is
// not available yet, the graph behind the output's Node is lowered, and the
// corresponding MLIR operation returned.
torch::jit::Value* GetOutputOp(const Output& output) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
} // namespace lazy
} // namespace torch

python/torch_mlir/csrc/backend/mlir_lowering_context.h (deleted)

@@ -1,72 +0,0 @@
//===- mlir_lowering_context.h --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.h
//===----------------------------------------------------------------------===//
#pragma once
#include <vector>
#include <torch/csrc/lazy/backend/lowering_context.h>
namespace torch {
namespace lazy {
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
public:
int parameters_size() const override;
virtual const std::vector<torch::lazy::Shape>&
parameter_shapes() const override;
virtual const std::vector<std::string>& parameter_names() const override;
virtual const torch::lazy::Shape& result_shape() const override;
};
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device);
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
// Get the shape of the result tuple component, given by index.
virtual torch::lazy::Shape GetResultShape(size_t index) const override;
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
virtual size_t AddResult(const torch::lazy::Output& output) override;
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
virtual void AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) override;
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
virtual torch::lazy::ComputationPtr Build() override;
// Retrieves the lowered operation for an output. If the requested output is
// not available yet, the graph behind the output's Node is lowered, and the
// corresponding MLIR operation returned.
torch::jit::Value* GetOutputOp(const Output& output);
private:
std::vector<const torch::lazy::Node*> result_tuple_;
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
};
} // namespace lazy
} // namespace torch

python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp

@@ -95,7 +95,8 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
   TorchMlirBackendData::Info* info =
       dynamic_cast<TorchMlirBackendData::Info*>(data->info());
   TORCH_CHECK(
-      info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
+      info,
+      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
   return info->tensor;
 }

python/torch_mlir/csrc/base_lazy_backend/backend_impl.h

@@ -41,7 +41,8 @@ public:
   TorchMlirBackendData(BackendDevice device, Shape shape);
   TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
-  TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
+  TorchMlirBackendData(
+      const at::Tensor& tensor, BackendDevice device, Shape shape);
   virtual BackendData::Handle GetHandle() override;

python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp (new file)

@@ -0,0 +1,255 @@
//===- mlir_lowering_context.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
//===----------------------------------------------------------------------===//
#include <iostream>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "backend_impl.h"
#include "mlir-c/Registration.h"
#include "mlir_lowering_context.h"
#include "mlir_node.h"
#include "torch-mlir-c/Registration.h"
namespace torch {
namespace lazy {
///////////////////////////////////////////////////////////////////////////////
// TorchMlir Computation
///////////////////////////////////////////////////////////////////////////////
TorchMlirComputation::TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph)
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
graph_(graph), num_results_(graph_->outputs().size()) {
// TODO(henrytu): Save parameter shape information.
for (torch::jit::Value* input : graph_->inputs()) {
parameter_names_.push_back(input->debugName());
}
}
int TorchMlirComputation::parameters_size() const {
return parameter_names_.size();
}
const std::vector<torch::lazy::Shape>&
TorchMlirComputation::parameter_shapes() const {
throw std::runtime_error(
"todo(whc) implement ts computation shapes or change interface");
return parameter_shapes_;
}
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
return parameter_names_;
}
const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
throw std::runtime_error(
"todo(whc) implement ts computation shapes or change interface");
return result_shape_;
}
unsigned TorchMlirComputation::num_results() const { return num_results_; }
MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
std::string TorchMlirComputation::to_string() const {
// Since we use the C-MLIR API, we need to use a callback to print.
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
// user_data is a void ptr to some data structure of our choice -- in this
// case, the string stream where we'll be accumulating the strings.
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
*ss_ptr << std::string(part.data, part.length);
};
std::stringstream ss;
ss << "JIT Graph: \n"
<< graph_->toString() << "\n\n"
<< "MLIR: \n";
mlirOperationPrint(func_op_, print_callback, &ss);
return ss.str();
}
///////////////////////////////////////////////////////////////////////////////
// TorchMlir Lowering Context
///////////////////////////////////////////////////////////////////////////////
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)),
graph_(std::make_shared<torch::jit::Graph>()),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
RegisterMlirDialects();
}
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
: LoweringContext(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)),
graph_(std::make_shared<torch::jit::Graph>()),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
CHECK(ok) << "Failed to lower: " << *node;
}
RegisterMlirDialects();
}
// Get the shape of the result tuple component, given by index.
torch::lazy::Shape
TorchMlirLoweringContext::GetResultShape(size_t index) const {
UNIMPLEMENTED_FUNCTION_ERROR();
}
size_t TorchMlirLoweringContext::AddResult(const Output& output) {
PRINT_FUNCTION();
return AddResult(GetOutputOp(output));
}
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void TorchMlirLoweringContext::AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION();
for (torch::jit::Value* output : root_tuple_) {
graph_->block()->registerOutput(output);
}
// Create jit::Function from jit::Graph.
c10::QualifiedName name("graph");
auto cu = std::make_shared<torch::jit::CompilationUnit>();
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
// may get mutated in the process.
auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));
// Generate MLIR.
MlirOperation func_op =
torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
// TODO(henrytu): Inject tensor shapes into func_op
return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
}
torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
PRINT_FUNCTION();
auto it = emitted_outputs_.find(output);
if (it == emitted_outputs_.end()) {
auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
}
// At this point the output better be present, otherwise there is an issue
// with the lowering code.
it = emitted_outputs_.find(output);
TORCH_CHECK(
it != emitted_outputs_.end(),
"No MLIR operation emitted for output: ", output.ToString());
}
return it->second;
}
void TorchMlirLoweringContext::AssignOutputOp(
const Output& output, torch::jit::Value* op) {
PRINT_FUNCTION();
auto torch_mlir_node =
NodeCast<TorchMlirNode>(output.node, output.node->op());
if (!torch_mlir_node->getPythonStacktrace().empty()) {
op->node()->s_(
c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
}
emitted_outputs_[output] = std::move(op);
}
torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
PRINT_FUNCTION();
if (!dynamic_cast<TorchMlirBackendData*>(data.get())) {
TORCH_CHECK(
false,
"Expected TorchMlirBackendData. Got some other BackendData type");
}
const auto mlir_data = std::static_pointer_cast<TorchMlirBackendData>(data);
BackendData::Handle handle = mlir_data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
torch::jit::Value* param =
graph_->addInput(c10::str("p", parameters_.size()));
auto info = mlir_data->mlir_info();
if (info->scalar.has_value()) {
auto& scalar = info->scalar.value();
if (scalar.isFloatingPoint()) {
param->setType(c10::FloatType::get());
} else if (scalar.isIntegral(true)) {
param->setType(c10::IntType::get());
} else {
TORCH_CHECK(
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
}
}
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
.first;
parameters_.push_back(mlir_data);
}
parameter_sequence_.push_back(it->second.index);
return it->second.param;
}
std::shared_ptr<torch::jit::Graph> TorchMlirLoweringContext::graph() const {
return graph_;
}
size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
PRINT_FUNCTION();
root_tuple_.push_back(std::move(op));
return root_tuple_.size() - 1;
}
void TorchMlirLoweringContext::RegisterMlirDialects() {
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(mlir_context_);
torchMlirRegisterAllDialects(mlir_context_);
}
} // namespace lazy
} // namespace torch
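As an aside, the to_string() implementation above uses the C MLIR API's callback-based printing; the same pattern in a minimal, self-contained form (a sketch assuming only the mlir-c headers) looks like:

    #include <sstream>
    #include <string>
    #include "mlir-c/IR.h"

    // Accumulate the chunks the C API hands to the callback into a string
    // stream (the same pattern as TorchMlirComputation::to_string()).
    static std::string printOperation(MlirOperation op) {
      std::stringstream ss;
      MlirStringCallback callback = [](MlirStringRef part, void* user_data) {
        auto* out = static_cast<std::stringstream*>(user_data);
        *out << std::string(part.data, part.length);
      };
      mlirOperationPrint(op, callback, &ss);
      return ss.str();
    }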

python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h (new file)

@@ -0,0 +1,136 @@
//===- mlir_lowering_context.h --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.h
//===----------------------------------------------------------------------===//
#pragma once
#include <vector>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include "mlir-c/IR.h"
#include "mlir_node_lowering.h"
namespace torch {
namespace lazy {
class TORCH_API TorchMlirNodeLoweringInterface {
/**
* This interface is only needed for legacy ops, and can be removed once all
* ops implement LtcMlirNode->lower().
* */
public:
TorchMlirNodeLoweringInterface() = default;
virtual ~TorchMlirNodeLoweringInterface() = default;
virtual bool Lower(const Node* node) = 0;
static std::unique_ptr<TorchMlirNodeLoweringInterface>
Create(LoweringContext* loctx);
};
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
public:
TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph);
int parameters_size() const override;
const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
const std::vector<std::string>& parameter_names() const override;
const torch::lazy::Shape& result_shape() const override;
unsigned num_results() const;
MlirOperation func_op() const;
std::string to_string() const;
private:
std::vector<std::string> parameter_names_;
std::vector<Shape> parameter_shapes_;
Shape result_shape_;
MlirOperation func_op_;
MlirContext mlir_context_;
std::shared_ptr<torch::jit::Graph> graph_;
unsigned num_results_;
};
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device);
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
// Get the shape of the result tuple component, given by index.
torch::lazy::Shape GetResultShape(size_t index) const override;
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
size_t AddResult(const torch::lazy::Output& output) override;
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) override;
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
torch::lazy::ComputationPtr Build() override;
// Retrieves the lowered operation for an output. If the requested output is
// not available yet, the graph behind the output's Node is lowered, and the
// corresponding TS operation returned.
torch::jit::Value* GetOutputOp(const Output& output);
// Assigns the given TS operation to the specified output. As outputs are
// lowered in a post-order fashion, later nodes should always find their
// operands among the emitted outputs.
void AssignOutputOp(const Output& output, torch::jit::Value* op);
// If a parameter associated with data has already been declared, it will be
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
torch::jit::Value* GetParameter(BackendDataPtr data);
std::shared_ptr<torch::jit::Graph> graph() const;
private:
struct Parameter {
torch::jit::Value* param;
size_t index = 0;
};
size_t AddResult(torch::jit::Value* op);
void RegisterMlirDialects();
std::shared_ptr<torch::jit::Graph> graph_;
MlirContext mlir_context_;
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
std::vector<torch::jit::Value*> root_tuple_;
OutputMap<torch::jit::Value*> emitted_outputs_;
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
};
} // namespace lazy
} // namespace torch
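Taken together, the lazy tensor framework would drive this class roughly as follows (a hedged sketch only; the real driver is torch::lazy's graph executor, which is not part of this diff, and `device`, `post_order`, and `emit_status` are assumed to come from it):

    // Sketch: lowering a post-order node list and building the computation.
    torch::lazy::TorchMlirLoweringContext loctx(
        "MlirComputation", device, post_order, emit_status);
    // The constructor lowers each node in post_order; Build() then converts
    // the accumulated jit::Graph into an MLIR function.
    torch::lazy::ComputationPtr computation = loctx.Build();
    auto* mlir_computation =
        static_cast<torch::lazy::TorchMlirComputation*>(computation.get());
    std::cout << mlir_computation->to_string() << std::endl;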

python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp

@@ -18,8 +18,8 @@
 namespace torch {
 namespace lazy {
-TorchMlirOpVector
-TorchMlirNode::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+TorchMlirOpVector TorchMlirNode::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
   return {};
 }

python/torch_mlir/csrc/base_lazy_backend/mlir_node.h

@@ -25,9 +25,6 @@
 namespace torch {
 namespace lazy {
-typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
-typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
 class TORCH_API TorchMlirNode : public torch::lazy::Node {
 public:
   using torch::lazy::Node::Node;

python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp (new file)

@@ -0,0 +1,452 @@
//===- mlir_node_lowering.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp
//===----------------------------------------------------------------------===//
#include "mlir_node_lowering.h"
#include "mlir_lowering_context.h"
#include "mlir_node.h"
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/permutation_util.h>
#include <torch/csrc/lazy/core/internal_ops/cast.h>
#include <torch/csrc/lazy/core/internal_ops/device_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/core/ops/expand.h>
#include <torch/csrc/lazy/core/ops/scalar.h>
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
#include <torch/csrc/lazy/core/view_ops/as_strided_view_update.h>
#include <torch/csrc/lazy/core/view_ops/narrow.h>
#include <torch/csrc/lazy/core/view_ops/narrow_view_update.h>
#include <torch/csrc/lazy/core/view_ops/permute.h>
#include <torch/csrc/lazy/core/view_ops/select.h>
#include <torch/csrc/lazy/core/view_ops/select_view_update.h>
#include <torch/csrc/lazy/core/view_ops/squeeze.h>
#include <torch/csrc/lazy/core/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/core/view_ops/view.h>
namespace torch {
namespace lazy {
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
TorchMlirNodeLowering(
const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx)
: loctx_(loctx), function_(
loctx ? std::make_shared<torch::jit::GraphFunction>(
name, loctx->graph(), nullptr)
: nullptr) {}
torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; }
bool Lower(const torch::lazy::Node* node) override {
if (auto* torch_mlir_node =
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
// First, we call the node lowering function, which exists for newly
// codegenned or refactored nodes
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx());
if (ops.empty()) {
// Then fall back to legacy lowering code, which should be gradually
// removed
ops = LowerNonCodegenOps(node);
}
if (ops.empty()) {
return false;
}
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
return true;
} else {
TorchMlirOpVector ops = LowerNonCodegenOps(node);
if (!ops.empty()) {
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
return true;
}
}
throw std::runtime_error(
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
}
// TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops
// to codegen we should delete this and put all the lowering logic into Node
// classes
TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
if (node->op().op == at::aten::as_strided) {
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
node, torch::lazy::OpKind(at::aten::as_strided)));
}
if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
return LowerAsStridedViewUpdate(
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
node, *torch::lazy::ltc_as_strided_view_update));
}
if (node->op() == *torch::lazy::ltc_cast) {
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
node, *torch::lazy::ltc_cast));
}
if (node->op() == *torch::lazy::ltc_select_view_update) {
return LowerSelectViewUpdate(
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
node, *torch::lazy::ltc_select_view_update));
}
if (node->op() == *torch::lazy::ltc_narrow_view_update) {
return LowerNarrowViewUpdate(
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
node, *torch::lazy::ltc_narrow_view_update));
}
if (node->op().op == at::prim::Constant) {
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
node, torch::lazy::OpKind(at::prim::Constant)));
}
if (node->op().op == at::aten::bernoulli) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
return LowerBuiltin(node, arguments);
}
if (node->op().op == at::aten::native_batch_norm) {
return LowerBatchNorm(
torch::lazy::NodeCast<torch::lazy::NativeBatchNormForward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm)));
}
if (node->op().op == at::aten::native_batch_norm_backward) {
return LowerBatchNormBackward(
torch::lazy::NodeCast<torch::lazy::NativeBatchNormBackward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
}
if (node->op().op == at::aten::expand) {
return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
node, torch::lazy::OpKind(at::aten::expand)));
}
if (node->op().op == at::aten::narrow) {
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
node, torch::lazy::OpKind(at::aten::narrow)));
}
if (node->op().op == at::aten::permute) {
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
node, torch::lazy::OpKind(at::aten::permute)));
}
if (node->op().op == at::aten::select) {
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
node, torch::lazy::OpKind(at::aten::select)));
}
if (node->op().op == at::aten::squeeze) {
return LowerSqueeze(torch::lazy::NodeCast<torch::lazy::Squeeze>(
node, torch::lazy::OpKind(at::aten::squeeze)));
}
if (node->op().op == at::aten::unsqueeze) {
return LowerUnsqueeze(torch::lazy::NodeCast<torch::lazy::Unsqueeze>(
node, torch::lazy::OpKind(at::aten::unsqueeze)));
}
if (node->op().op == at::aten::view) {
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
node, torch::lazy::OpKind(at::aten::view)));
}
if (node->op() == *torch::lazy::ltc_device_data) {
const torch::lazy::DeviceData* device_data_node =
torch::lazy::NodeCast<torch::lazy::DeviceData>(
node, *torch::lazy::ltc_device_data);
auto infoptr = device_data_node->data()->info();
auto deviceDataInfoPtr =
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
if (GRAPH_DUMP_ENABLED) {
LOG(ERROR) << "Lowering device data node, tensor id "
<< deviceDataInfoPtr->tensor_id << std::endl;
}
return {loctx()->GetParameter(device_data_node->data())};
}
std::vector<torch::jit::NamedValue> arguments;
for (const torch::lazy::Output& output : node->operands()) {
arguments.emplace_back(loctx()->GetOutputOp(output));
}
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerBuiltin(
const torch::lazy::Node* node,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function_, node->op().op, arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
}
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size());
arguments.emplace_back(node->stride());
arguments.emplace_back(node->storage_offset());
TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
CHECK_EQ(as_strided_out.size(), 1);
return {GenerateClone(as_strided_out.front())};
}
TorchMlirOpVector
LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) {
torch::jit::Value* destination =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
const torch::lazy::Output& input_op = node->operand(1);
const torch::lazy::Shape& input_shape = input_op.shape();
const auto input_dimensions = input_shape.sizes();
std::vector<torch::jit::NamedValue> dest_arguments;
dest_arguments.emplace_back(destination);
dest_arguments.emplace_back(
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
dest_arguments.emplace_back(node->stride());
dest_arguments.emplace_back(node->storage_offset());
TorchMlirOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
torch::jit::Value* as_strided = as_strided_out.front();
GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
return {destination};
}
TorchMlirOpVector
LowerBatchNorm(const torch::lazy::NativeBatchNormForward* node) {
std::vector<torch::jit::NamedValue> arguments;
for (size_t i = 0; i < 5; ++i) {
arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
}
arguments.emplace_back(node->training());
arguments.emplace_back(node->momentum());
arguments.emplace_back(node->eps());
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector
LowerBatchNormBackward(const torch::lazy::NativeBatchNormBackward* node) {
std::vector<torch::jit::NamedValue> arguments;
for (size_t i = 0; i < 3; ++i) {
arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
}
const auto& operands = node->operands();
c10::optional<at::Tensor> null_arg;
if (operands.size() == 5) {
arguments.emplace_back(null_arg);
arguments.emplace_back(null_arg);
}
for (size_t i = 3; i < operands.size(); ++i) {
arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
}
arguments.emplace_back(node->training());
arguments.emplace_back(node->eps());
arguments.emplace_back(node->output_mask());
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dtype());
return LowerBuiltin(at::aten::to, arguments);
}
TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size());
auto expand_out = LowerBuiltin(node, arguments);
if (node->is_scalar_expand()) {
// The aten::expand operation sets all strides to 0 when the original
// tensor is of rank 0. This leads to false positives when checking for
// internal memory overlap, because at::has_internal_overlap returns
// MemOverlap::YES when a stride is set to 0.
CHECK_EQ(expand_out.size(), 1);
return {GenerateClone(expand_out.front())};
}
return expand_out;
}
TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
const torch::lazy::Output& input = node->operand(0);
torch::jit::Value* base = loctx()->GetOutputOp(input);
const auto& base_indices = node->base_indices();
const auto& sizes = node->sizes();
const torch::lazy::Shape& input_shape = input.shape();
CHECK_EQ(sizes.size(), base_indices.size());
CHECK_EQ(input_shape.dim(), base_indices.size());
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
int64_t start = base_indices[dim];
base = GenerateSlice(
/*base=*/base, /*dim=*/dim, /*start=*/start,
/*end=*/start + sizes[dim], /*step=*/1);
}
return {base};
}
TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->dims());
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
const at::Scalar& value = node->value();
const torch::lazy::Shape& shape = node->shape();
auto options =
at::TensorOptions()
.device(torch::lazy::getBackend()->EagerFallbackDeviceType())
.dtype(shape.scalar_type());
return {
loctx()->graph()->insertConstant(at::scalar_tensor(value, options))};
}
TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
int64_t step = torch::lazy::Select::GetStride(
node->start(), node->end(), node->stride());
torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
return {GenerateSlice(
/*base=*/base, /*dim=*/node->dim(),
/*start=*/node->start(), /*end=*/node->end(),
/*step=*/step)};
}
TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
if (node->dim() != -1) {
arguments.push_back(node->dim());
}
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector
LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
int64_t step = torch::lazy::Select::GetStride(
node->start(), node->end(), node->stride());
torch::jit::Value* selected = GenerateSlice(
/*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
/*end=*/node->end(), /*step=*/step);
GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
return {dest};
}
TorchMlirOpVector
LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
const auto& base_indices = node->base_indices();
const torch::lazy::Output& source_argument = node->operand(1);
const torch::lazy::Shape& source_shape = source_argument.shape();
CHECK_EQ(source_shape.dim(), base_indices.size());
torch::jit::Value* base = dest;
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
int64_t start = base_indices[dim];
base = GenerateSlice(
/*base=*/base, /*dim=*/dim, /*start=*/start,
/*end=*/start + source_shape.size(dim),
/*step=*/1);
}
GenerateCopy(base, loctx()->GetOutputOp(source_argument));
return {dest};
}
TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->dim());
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerView(const torch::lazy::View* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->output_size());
return LowerBuiltin(at::aten::reshape, arguments);
}
torch::jit::Value* GenerateClone(torch::jit::Value* val) {
std::vector<torch::jit::NamedValue> clone_arguments;
clone_arguments.emplace_back(val);
TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);
CHECK_EQ(cloned.size(), 1);
return cloned.front();
}
void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(source);
LowerBuiltin(at::aten::copy_, arguments);
}
torch::jit::Value* GenerateSlice(
torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
int64_t step) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(base);
arguments.emplace_back(dim);
arguments.emplace_back(start);
arguments.emplace_back(end);
arguments.emplace_back(step);
TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
CHECK_EQ(selected.size(), 1);
return selected.front();
}
torch::lazy::TorchMlirLoweringContext* loctx_;
std::shared_ptr<torch::jit::GraphFunction> function_;
};
std::unique_ptr<TorchMlirNodeLoweringInterface>
TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
return std::make_unique<TorchMlirNodeLowering>(
"TorchMlirNodeLowering",
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
}
TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
auto builtin =
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
CHECK(sv);
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
const auto tuple_call_result = sv->asTuple({}, *function);
TorchMlirOpVector tuple_result;
for (const auto& tuple_component : tuple_call_result) {
auto tuple_component_sv =
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
tuple_result.push_back(tuple_component_sv->getValue());
}
return tuple_result;
}
return {sv->getValue()};
}
} // namespace lazy
} // namespace torch
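For a sense of how the shared entry point at the bottom of this file is used, a hypothetical caller that emits a two-operand builtin into the function under construction might look like the following (`function` is the shared_ptr<torch::jit::GraphFunction> being built, and `lhs`/`rhs` are torch::jit::Value* operands already present in the graph):

    // Hypothetical: lower an aten::mul via LowerTorchMlirBuiltin.
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(lhs);
    arguments.emplace_back(rhs);
    torch::lazy::TorchMlirOpVector outputs =
        torch::lazy::LowerTorchMlirBuiltin(function, at::aten::mul, arguments);
    // outputs holds one jit::Value* per result of the emitted JIT node.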

python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h (new file)

@@ -0,0 +1,30 @@
//===- mlir_node_lowering.h -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.h
//===----------------------------------------------------------------------===//
#pragma once
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
namespace torch {
namespace lazy {
typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
TorchMlirFunction function, c10::Symbol sym,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {});
} // namespace lazy
} // namespace torch

python/torch_mlir/csrc/utils/debug.h

@@ -21,3 +21,7 @@ static const bool verbose_print_function =
     std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__ \
               << ")" << std::endl;                                          \
   }
+#define PRINT_DEBUG(msg)                                                    \
+  std::cout << msg << " (" << __FILE__ << ":" << __LINE__ << ")"            \
+            << std::endl;
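A hypothetical call site for the new macro, to show the output format:

    // Prints e.g.: Lowering device data node (mlir_node_lowering.cpp:123)
    PRINT_DEBUG("Lowering device data node");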

python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt

@@ -10,7 +10,7 @@ include_directories(BEFORE
 )
 link_directories("${TORCH_INSTALL_PREFIX}/lib")
-add_library(TorchMLIRJITIRImporter MODULE
+add_library(TorchMLIRJITIRImporter SHARED
   class_annotator.cpp
   class_annotator_pybind.cpp
   get_registered_ops.cpp

python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h

@@ -38,7 +38,7 @@ namespace torch_mlir {
 /// will be attached as an argument attribute to the func op's argument. If a
 /// null MlirAttribute is returned, no attribute will be attached to that
 /// argument.
-MlirOperation importJitFunctionAsFuncOp(
+TORCH_API MlirOperation importJitFunctionAsFuncOp(
     MlirContext context, torch::jit::Function *function,
     std::function<MlirAttribute(int)> getArgAttribute =
         [](int) -> MlirAttribute { return {nullptr}; });