mirror of https://github.com/llvm/torch-mlir
Added JIT to MLIR lowering (#724)
* Added JIT to MLIR lowering Lowering to JIT is performed in a way similar to how it's done in the TS LTC backend. After a jit::Graph is constructed, it gets converted to a jit::Function, which is fed into the existing utility to generate an MlirModule in torch-mlir. * Renamed `csrc/backend` to `csrc/base_lazy_backend`pull/1125/head
parent
65cf1465ef
commit
3e9b1cbd36
|
@ -23,15 +23,15 @@ __pycache__
|
|||
# Bazel
|
||||
bazel-*
|
||||
|
||||
# Autogenerated files
|
||||
/generated_native_functions.yaml
|
||||
/generated_backend.hash
|
||||
/python/torch_mlir/csrc/backend/LazyIr.h
|
||||
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
|
||||
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
|
||||
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
|
||||
/python/torch_mlir/csrc/backend/RegisterLazy.cpp
|
||||
|
||||
# Libraries
|
||||
*.so
|
||||
*.a
|
||||
|
||||
# Autogenerated files
|
||||
/generated_native_functions.yaml
|
||||
/generated_backend.hash
|
||||
/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
|
||||
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
|
||||
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
|
||||
/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
|
||||
/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
|
||||
|
|
|
@ -300,7 +300,7 @@ def main(args):
|
|||
"generated_native_functions.yaml"
|
||||
)
|
||||
backend_path = TORCH_MLIR_DIR.joinpath(
|
||||
"python", "torch_mlir", "csrc", "backend"
|
||||
"python", "torch_mlir", "csrc", "base_lazy_backend"
|
||||
)
|
||||
assert backend_path.is_dir()
|
||||
|
||||
|
|
|
@ -20,18 +20,23 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
|||
|
||||
|
||||
add_library(torch_mlir_ltc_backend SHARED
|
||||
backend/backend_impl.cpp
|
||||
backend/LazyNativeFunctions.cpp
|
||||
backend/LazyShapeInference.cpp
|
||||
backend/GenLazyShapeInference.cpp
|
||||
backend/mlir_lowering_context.cpp
|
||||
backend/mlir_native_functions.cpp
|
||||
backend/mlir_node.cpp
|
||||
backend/RegisterLazy.cpp
|
||||
base_lazy_backend/backend_impl.cpp
|
||||
base_lazy_backend/LazyNativeFunctions.cpp
|
||||
base_lazy_backend/LazyShapeInference.cpp
|
||||
base_lazy_backend/GenLazyShapeInference.cpp
|
||||
base_lazy_backend/mlir_lowering_context.cpp
|
||||
base_lazy_backend/mlir_native_functions.cpp
|
||||
base_lazy_backend/mlir_node.cpp
|
||||
base_lazy_backend/mlir_node_lowering.cpp
|
||||
base_lazy_backend/RegisterLazy.cpp
|
||||
)
|
||||
|
||||
add_dependencies(torch_mlir_ltc_backend
|
||||
TorchMLIRJITIRImporter
|
||||
)
|
||||
target_link_libraries(torch_mlir_ltc_backend
|
||||
TorchMLIRAggregateCAPI
|
||||
TorchMLIRJITIRImporter
|
||||
${TORCH_LIBRARIES}
|
||||
${Python3_LIBRARIES}
|
||||
torch_python
|
||||
|
|
|
@ -1,107 +0,0 @@
|
|||
//===- mlir_lowering_context.cpp ------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "../utils/debug.h"
|
||||
#include "../utils/exception.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||
const std::string& name, BackendDevice device)
|
||||
: LoweringContext(name, std::forward<BackendDevice>(device)) {}
|
||||
|
||||
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||
const std::string& name, BackendDevice device,
|
||||
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
|
||||
: LoweringContext(
|
||||
name, std::forward<BackendDevice>(device),
|
||||
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
|
||||
std::forward<Util::EmissionMap>(emit_status)) {}
|
||||
|
||||
int TorchMlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); }
|
||||
|
||||
const std::vector<torch::lazy::Shape>&
|
||||
TorchMlirComputation::parameter_shapes() const {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
std::string TorchMlirComputation::to_string() const {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
// Get the shape of the result tuple component, given by index.
|
||||
torch::lazy::Shape TorchMlirLoweringContext::GetResultShape(size_t index) const {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
// Adds the given output as a component of the result tuple and returns its
|
||||
// assigned position within the tuple.
|
||||
size_t TorchMlirLoweringContext::AddResult(const torch::lazy::Output& output) {
|
||||
PRINT_FUNCTION();
|
||||
const torch::lazy::Node* node;
|
||||
auto it = emitted_outputs_.find(output);
|
||||
if (it == emitted_outputs_.end()) {
|
||||
node = output.node;
|
||||
|
||||
auto post_order = Util::ComputePostOrder(node, &emit_status_);
|
||||
for (auto po_node : post_order) {
|
||||
// TODO: uncomment after lowering is implemented
|
||||
// bool ok = lowering_->Lower(node);
|
||||
// TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
|
||||
}
|
||||
emitted_outputs_[output] = node;
|
||||
} else {
|
||||
node = it->second;
|
||||
}
|
||||
result_tuple_.emplace_back(node);
|
||||
return result_tuple_.size() - 1;
|
||||
}
|
||||
|
||||
// Associates the given output with the input parameter of the given index and
|
||||
// shape. Only used for the operator-by-operator execution, mostly for
|
||||
// debugging purposes.
|
||||
void TorchMlirLoweringContext::AddParameter(
|
||||
const torch::lazy::Output& output, size_t index,
|
||||
const torch::lazy::Shape& shape, const std::string& name) {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
// Build the computation capturing all the operations created with the
|
||||
// embedded builder (returned by the builder() API).
|
||||
ComputationPtr TorchMlirLoweringContext::Build() {
|
||||
PRINT_FUNCTION()
|
||||
for (const torch::lazy::Node* output : result_tuple_) {
|
||||
}
|
||||
return std::make_shared<TorchMlirComputation>();
|
||||
}
|
||||
|
||||
// Retrieves the lowered operation for an output. If the requested output is
|
||||
// not available yet, the graph behind the output's Node is lowered, and the
|
||||
// corresponding MLIR operation returned.
|
||||
torch::jit::Value* GetOutputOp(const Output& output) {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -1,72 +0,0 @@
|
|||
//===- mlir_lowering_context.h --------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
|
||||
public:
|
||||
int parameters_size() const override;
|
||||
|
||||
virtual const std::vector<torch::lazy::Shape>&
|
||||
parameter_shapes() const override;
|
||||
|
||||
virtual const std::vector<std::string>& parameter_names() const override;
|
||||
|
||||
virtual const torch::lazy::Shape& result_shape() const override;
|
||||
};
|
||||
|
||||
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
|
||||
public:
|
||||
TorchMlirLoweringContext(
|
||||
const std::string& name, torch::lazy::BackendDevice device);
|
||||
TorchMlirLoweringContext(
|
||||
const std::string& name, torch::lazy::BackendDevice device,
|
||||
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||
torch::lazy::Util::EmissionMap emit_status);
|
||||
|
||||
// Get the shape of the result tuple component, given by index.
|
||||
virtual torch::lazy::Shape GetResultShape(size_t index) const override;
|
||||
|
||||
// Adds the given output as a component of the result tuple and returns its
|
||||
// assigned position within the tuple.
|
||||
virtual size_t AddResult(const torch::lazy::Output& output) override;
|
||||
|
||||
// Associates the given output with the input parameter of the given index and
|
||||
// shape. Only used for the operator-by-operator execution, mostly for
|
||||
// debugging purposes.
|
||||
virtual void AddParameter(
|
||||
const torch::lazy::Output& output, size_t index,
|
||||
const torch::lazy::Shape& shape, const std::string& name) override;
|
||||
|
||||
// Build the computation capturing all the operations created with the
|
||||
// embedded builder (returned by the builder() API).
|
||||
virtual torch::lazy::ComputationPtr Build() override;
|
||||
|
||||
// Retrieves the lowered operation for an output. If the requested output is
|
||||
// not available yet, the graph behind the output's Node is lowered, and the
|
||||
// corresponding MLIR operation returned.
|
||||
torch::jit::Value* GetOutputOp(const Output& output);
|
||||
|
||||
private:
|
||||
std::vector<const torch::lazy::Node*> result_tuple_;
|
||||
torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -95,7 +95,8 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
|
|||
TorchMlirBackendData::Info* info =
|
||||
dynamic_cast<TorchMlirBackendData::Info*>(data->info());
|
||||
TORCH_CHECK(
|
||||
info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
|
||||
info,
|
||||
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
|
||||
return info->tensor;
|
||||
}
|
||||
|
|
@ -41,7 +41,8 @@ public:
|
|||
|
||||
TorchMlirBackendData(BackendDevice device, Shape shape);
|
||||
TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
|
||||
TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
|
||||
TorchMlirBackendData(
|
||||
const at::Tensor& tensor, BackendDevice device, Shape shape);
|
||||
|
||||
virtual BackendData::Handle GetHandle() override;
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
//===- mlir_lowering_context.cpp ------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||
|
||||
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
||||
#include "../utils/debug.h"
|
||||
#include "../utils/exception.h"
|
||||
#include "backend_impl.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
#include "mlir_node.h"
|
||||
#include "torch-mlir-c/Registration.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TorchMlir Computation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TorchMlirComputation::TorchMlirComputation(
|
||||
MlirOperation func_op, MlirContext mlir_context,
|
||||
const std::shared_ptr<torch::jit::Graph>& graph)
|
||||
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
|
||||
graph_(graph), num_results_(graph_->outputs().size()) {
|
||||
|
||||
// TODO(henrytu): Save parameter shape information.
|
||||
|
||||
for (torch::jit::Value* input : graph_->inputs()) {
|
||||
parameter_names_.push_back(input->debugName());
|
||||
}
|
||||
}
|
||||
|
||||
int TorchMlirComputation::parameters_size() const {
|
||||
return parameter_names_.size();
|
||||
}
|
||||
|
||||
const std::vector<torch::lazy::Shape>&
|
||||
TorchMlirComputation::parameter_shapes() const {
|
||||
throw std::runtime_error(
|
||||
"todo(whc) implement ts computation shapes or change interface");
|
||||
return parameter_shapes_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
|
||||
return parameter_names_;
|
||||
}
|
||||
|
||||
const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
|
||||
throw std::runtime_error(
|
||||
"todo(whc) implement ts computation shapes or change interface");
|
||||
return result_shape_;
|
||||
}
|
||||
|
||||
unsigned TorchMlirComputation::num_results() const { return num_results_; }
|
||||
|
||||
MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
|
||||
|
||||
std::string TorchMlirComputation::to_string() const {
|
||||
// Since we use the C-MLIR API, we need to use a callback to print.
|
||||
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
|
||||
// user_data is a void ptr to some data structure of our choice -- in this
|
||||
// case, the string stream where we'll be accumulating the strings.
|
||||
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
|
||||
*ss_ptr << std::string(part.data, part.length);
|
||||
};
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "JIT Graph: \n"
|
||||
<< graph_->toString() << "\n\n"
|
||||
<< "MLIR: \n";
|
||||
mlirOperationPrint(func_op_, print_callback, &ss);
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TorchMlir Lowering Context
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||
const std::string& name, BackendDevice device)
|
||||
: LoweringContext(name, std::forward<BackendDevice>(device)),
|
||||
graph_(std::make_shared<torch::jit::Graph>()),
|
||||
mlir_context_(mlirContextCreate()) {
|
||||
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
|
||||
RegisterMlirDialects();
|
||||
}
|
||||
|
||||
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||
const std::string& name, BackendDevice device,
|
||||
c10::ArrayRef<torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
|
||||
: LoweringContext(
|
||||
name, std::forward<BackendDevice>(device),
|
||||
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
|
||||
std::forward<Util::EmissionMap>(emit_status)),
|
||||
graph_(std::make_shared<torch::jit::Graph>()),
|
||||
mlir_context_(mlirContextCreate()) {
|
||||
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
|
||||
for (auto node : post_order) {
|
||||
bool ok = lowering_->Lower(node);
|
||||
CHECK(ok) << "Failed to lower: " << *node;
|
||||
}
|
||||
|
||||
RegisterMlirDialects();
|
||||
}
|
||||
|
||||
// Get the shape of the result tuple component, given by index.
|
||||
torch::lazy::Shape
|
||||
TorchMlirLoweringContext::GetResultShape(size_t index) const {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
size_t TorchMlirLoweringContext::AddResult(const Output& output) {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
return AddResult(GetOutputOp(output));
|
||||
}
|
||||
|
||||
// Associates the given output with the input parameter of the given index and
|
||||
// shape. Only used for the operator-by-operator execution, mostly for
|
||||
// debugging purposes.
|
||||
void TorchMlirLoweringContext::AddParameter(
|
||||
const torch::lazy::Output& output, size_t index,
|
||||
const torch::lazy::Shape& shape, const std::string& name) {
|
||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||
}
|
||||
|
||||
// Build the computation capturing all the operations created with the
|
||||
// embedded builder (returned by the builder() API).
|
||||
ComputationPtr TorchMlirLoweringContext::Build() {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
for (torch::jit::Value* output : root_tuple_) {
|
||||
graph_->block()->registerOutput(output);
|
||||
}
|
||||
|
||||
// Create jit::Function from jit::Graph.
|
||||
c10::QualifiedName name("graph");
|
||||
auto cu = std::make_shared<torch::jit::CompilationUnit>();
|
||||
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
|
||||
// may get mutated in the process.
|
||||
auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));
|
||||
|
||||
// Generate MLIR.
|
||||
MlirOperation func_op =
|
||||
torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
|
||||
|
||||
// TODO(henrytu): Inject tensor shapes into func_op
|
||||
return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
|
||||
}
|
||||
|
||||
torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
auto it = emitted_outputs_.find(output);
|
||||
if (it == emitted_outputs_.end()) {
|
||||
auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
|
||||
for (auto node : post_order) {
|
||||
bool ok = lowering_->Lower(node);
|
||||
TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
|
||||
}
|
||||
// At this point the output better be present, otherwise there is an issue
|
||||
// with the lowering code.
|
||||
it = emitted_outputs_.find(output);
|
||||
TORCH_CHECK(
|
||||
it != emitted_outputs_.end(),
|
||||
"No MLIR operation emitted for output: ", output.ToString());
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void TorchMlirLoweringContext::AssignOutputOp(
|
||||
const Output& output, torch::jit::Value* op) {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
auto torch_mlir_node =
|
||||
NodeCast<TorchMlirNode>(output.node, output.node->op());
|
||||
if (!torch_mlir_node->getPythonStacktrace().empty()) {
|
||||
op->node()->s_(
|
||||
c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
|
||||
}
|
||||
emitted_outputs_[output] = std::move(op);
|
||||
}
|
||||
|
||||
torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
if (!dynamic_cast<TorchMlirBackendData*>(data.get())) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Expected TorchMlirBackendData. Got some other BackendData type");
|
||||
}
|
||||
const auto mlir_data = std::static_pointer_cast<TorchMlirBackendData>(data);
|
||||
|
||||
BackendData::Handle handle = mlir_data->GetHandle();
|
||||
auto it = parameters_map_.find(handle);
|
||||
|
||||
if (it == parameters_map_.end()) {
|
||||
torch::jit::Value* param =
|
||||
graph_->addInput(c10::str("p", parameters_.size()));
|
||||
|
||||
auto info = mlir_data->mlir_info();
|
||||
if (info->scalar.has_value()) {
|
||||
auto& scalar = info->scalar.value();
|
||||
if (scalar.isFloatingPoint()) {
|
||||
param->setType(c10::FloatType::get());
|
||||
} else if (scalar.isIntegral(true)) {
|
||||
param->setType(c10::IntType::get());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
|
||||
}
|
||||
}
|
||||
|
||||
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
|
||||
.first;
|
||||
parameters_.push_back(mlir_data);
|
||||
}
|
||||
|
||||
parameter_sequence_.push_back(it->second.index);
|
||||
return it->second.param;
|
||||
}
|
||||
|
||||
std::shared_ptr<torch::jit::Graph> TorchMlirLoweringContext::graph() const {
|
||||
return graph_;
|
||||
}
|
||||
|
||||
size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
|
||||
PRINT_FUNCTION();
|
||||
root_tuple_.push_back(std::move(op));
|
||||
return root_tuple_.size() - 1;
|
||||
}
|
||||
|
||||
void TorchMlirLoweringContext::RegisterMlirDialects() {
|
||||
// https://reviews.llvm.org/D88162
|
||||
mlirRegisterAllDialects(mlir_context_);
|
||||
torchMlirRegisterAllDialects(mlir_context_);
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,136 @@
|
|||
//===- mlir_lowering_context.h --------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <torch/csrc/api/include/torch/jit.h>
|
||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir_node_lowering.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class TORCH_API TorchMlirNodeLoweringInterface {
|
||||
/**
|
||||
* This interface is only needed for legacy ops, and can be removed once all
|
||||
* ops implement LtcMlirNode->lower().
|
||||
* */
|
||||
public:
|
||||
TorchMlirNodeLoweringInterface() = default;
|
||||
|
||||
virtual ~TorchMlirNodeLoweringInterface() = default;
|
||||
|
||||
virtual bool Lower(const Node* node) = 0;
|
||||
|
||||
static std::unique_ptr<TorchMlirNodeLoweringInterface>
|
||||
Create(LoweringContext* loctx);
|
||||
};
|
||||
|
||||
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
|
||||
public:
|
||||
TorchMlirComputation(
|
||||
MlirOperation func_op, MlirContext mlir_context,
|
||||
const std::shared_ptr<torch::jit::Graph>& graph);
|
||||
|
||||
int parameters_size() const override;
|
||||
|
||||
const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
|
||||
|
||||
const std::vector<std::string>& parameter_names() const override;
|
||||
|
||||
const torch::lazy::Shape& result_shape() const override;
|
||||
|
||||
unsigned num_results() const;
|
||||
|
||||
MlirOperation func_op() const;
|
||||
|
||||
std::string to_string() const;
|
||||
|
||||
private:
|
||||
std::vector<std::string> parameter_names_;
|
||||
std::vector<Shape> parameter_shapes_;
|
||||
Shape result_shape_;
|
||||
|
||||
MlirOperation func_op_;
|
||||
MlirContext mlir_context_;
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
unsigned num_results_;
|
||||
};
|
||||
|
||||
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
|
||||
public:
|
||||
TorchMlirLoweringContext(
|
||||
const std::string& name, torch::lazy::BackendDevice device);
|
||||
TorchMlirLoweringContext(
|
||||
const std::string& name, torch::lazy::BackendDevice device,
|
||||
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||
torch::lazy::Util::EmissionMap emit_status);
|
||||
|
||||
// Get the shape of the result tuple component, given by index.
|
||||
torch::lazy::Shape GetResultShape(size_t index) const override;
|
||||
|
||||
// Adds the given output as a component of the result tuple and returns its
|
||||
// assigned position within the tuple.
|
||||
size_t AddResult(const torch::lazy::Output& output) override;
|
||||
|
||||
// Associates the given output with the input parameter of the given index and
|
||||
// shape. Only used for the operator-by-operator execution, mostly for
|
||||
// debugging purposes.
|
||||
void AddParameter(
|
||||
const torch::lazy::Output& output, size_t index,
|
||||
const torch::lazy::Shape& shape, const std::string& name) override;
|
||||
|
||||
// Build the computation capturing all the operations created with the
|
||||
// embedded builder (returned by the builder() API).
|
||||
torch::lazy::ComputationPtr Build() override;
|
||||
|
||||
// Retrieves the lowered operation for an output. If the requested output is
|
||||
// not available yet, the graph behind the output's Node is lowered, and the
|
||||
// corresponding TS operation returned.
|
||||
torch::jit::Value* GetOutputOp(const Output& output);
|
||||
|
||||
// Assigns the given TS operation to the specified output. As outputs are
|
||||
// lowered in a post-order fashion, later nodes should always find their
|
||||
// operands among the emitted outputs.
|
||||
void AssignOutputOp(const Output& output, torch::jit::Value* op);
|
||||
|
||||
// If a parameter associated with data has already been declared, it will be
|
||||
// returned. Otherwise a new one will be created, associated with the tensor
|
||||
// held in data.
|
||||
torch::jit::Value* GetParameter(BackendDataPtr data);
|
||||
|
||||
std::shared_ptr<torch::jit::Graph> graph() const;
|
||||
|
||||
private:
|
||||
struct Parameter {
|
||||
torch::jit::Value* param;
|
||||
size_t index = 0;
|
||||
};
|
||||
|
||||
size_t AddResult(torch::jit::Value* op);
|
||||
|
||||
void RegisterMlirDialects();
|
||||
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
MlirContext mlir_context_;
|
||||
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
|
||||
std::vector<torch::jit::Value*> root_tuple_;
|
||||
OutputMap<torch::jit::Value*> emitted_outputs_;
|
||||
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -18,8 +18,8 @@
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
TorchMlirOpVector
|
||||
TorchMlirNode::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
TorchMlirOpVector TorchMlirNode::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
return {};
|
||||
}
|
||||
|
|
@ -25,9 +25,6 @@
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
|
||||
typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
|
||||
|
||||
class TORCH_API TorchMlirNode : public torch::lazy::Node {
|
||||
public:
|
||||
using torch::lazy::Node::Node;
|
|
@ -0,0 +1,452 @@
|
|||
//===- mlir_node_lowering.cpp ---------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir_node_lowering.h"
|
||||
#include "mlir_lowering_context.h"
|
||||
#include "mlir_node.h"
|
||||
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/csrc/jit/frontend/sugared_value.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
#include <torch/csrc/lazy/backend/backend_interface.h>
|
||||
#include <torch/csrc/lazy/core/helpers.h>
|
||||
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
|
||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||
#include <torch/csrc/lazy/core/permutation_util.h>
|
||||
|
||||
#include <torch/csrc/lazy/core/internal_ops/cast.h>
|
||||
#include <torch/csrc/lazy/core/internal_ops/device_data.h>
|
||||
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
|
||||
#include <torch/csrc/lazy/core/ops/batch_norm_ops.h>
|
||||
#include <torch/csrc/lazy/core/ops/expand.h>
|
||||
#include <torch/csrc/lazy/core/ops/scalar.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/as_strided.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/as_strided_view_update.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/narrow.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/narrow_view_update.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/permute.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/select.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/select_view_update.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/squeeze.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/unsqueeze.h>
|
||||
#include <torch/csrc/lazy/core/view_ops/view.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
// Lowers lazy-tensor IR nodes into a torch::jit::GraphFunction so that the
// accumulated JIT graph can later be converted into an MLIR module. Adapted
// from the TorchScript lazy backend's TSNodeLowering.
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
  // `loctx` owns the jit::Graph being built. A GraphFunction wrapping that
  // graph is created so the JIT builtin-call machinery can be reused for
  // emitting ops. A null `loctx` yields a null function_.
  // NOTE(review): the null-loctx path presumably serves callers that never
  // invoke Lower() — confirm against call sites.
  TorchMlirNodeLowering(
      const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx)
      : loctx_(loctx), function_(
                           loctx ? std::make_shared<torch::jit::GraphFunction>(
                                       name, loctx->graph(), nullptr)
                                 : nullptr) {}

  // Accessor for the lowering context this instance emits into.
  torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; }

  // Lowers `node` and registers every produced jit::Value as an output of the
  // corresponding lazy Output in the lowering context. Returns false when a
  // TorchMlirNode produces no ops from either lowering path; throws when a
  // non-TorchMlirNode also has no legacy lowering.
  bool Lower(const torch::lazy::Node* node) override {
    if (auto* torch_mlir_node =
            dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
      // First, we call the node lowering function, which exists for newly
      // codegenned or refactored nodes
      TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx());
      if (ops.empty()) {
        // Then fall back to legacy lowering code, which should be gradually
        // removed
        ops = LowerNonCodegenOps(node);
      }
      if (ops.empty()) {
        return false;
      }
      // Each lazy output must map to exactly one emitted jit::Value.
      CHECK_EQ(node->num_outputs(), ops.size());
      for (size_t i = 0; i < ops.size(); ++i) {
        loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
      }
      return true;
    } else {
      // Non-TorchMlirNode: only the legacy lowering path is available.
      TorchMlirOpVector ops = LowerNonCodegenOps(node);
      if (!ops.empty()) {
        CHECK_EQ(node->num_outputs(), ops.size());
        for (size_t i = 0; i < ops.size(); ++i) {
          loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
        }
        return true;
      }
    }
    throw std::runtime_error(
        "Expected torch::lazy::TorchMlirNode but could not dynamic cast");
  }

  // TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops
  // to codegen we should delete this and put all the lowering logic into Node
  // classes
  // Dispatches hand-written lowerings for view/shape ops and a few special
  // ops; anything unmatched falls through to a generic builtin call built
  // from the node's operands.
  TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {

    if (node->op().op == at::aten::as_strided) {
      return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
          node, torch::lazy::OpKind(at::aten::as_strided)));
    }
    if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
      return LowerAsStridedViewUpdate(
          torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
              node, *torch::lazy::ltc_as_strided_view_update));
    }
    if (node->op() == *torch::lazy::ltc_cast) {
      return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
          node, *torch::lazy::ltc_cast));
    }
    if (node->op() == *torch::lazy::ltc_select_view_update) {
      return LowerSelectViewUpdate(
          torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
              node, *torch::lazy::ltc_select_view_update));
    }
    if (node->op() == *torch::lazy::ltc_narrow_view_update) {
      return LowerNarrowViewUpdate(
          torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
              node, *torch::lazy::ltc_narrow_view_update));
    }
    if (node->op().op == at::prim::Constant) {
      return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
          node, torch::lazy::OpKind(at::prim::Constant)));
    }
    if (node->op().op == at::aten::bernoulli) {
      // bernoulli only forwards its tensor operand to the builtin; the
      // generator operand (if any) is deliberately not forwarded here.
      std::vector<torch::jit::NamedValue> arguments;
      arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
      return LowerBuiltin(node, arguments);
    }
    if (node->op().op == at::aten::native_batch_norm) {
      return LowerBatchNorm(
          torch::lazy::NodeCast<torch::lazy::NativeBatchNormForward>(
              node, torch::lazy::OpKind(at::aten::native_batch_norm)));
    }
    if (node->op().op == at::aten::native_batch_norm_backward) {
      return LowerBatchNormBackward(
          torch::lazy::NodeCast<torch::lazy::NativeBatchNormBackward>(
              node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
    }
    if (node->op().op == at::aten::expand) {
      return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
          node, torch::lazy::OpKind(at::aten::expand)));
    }
    if (node->op().op == at::aten::narrow) {
      return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
          node, torch::lazy::OpKind(at::aten::narrow)));
    }
    if (node->op().op == at::aten::permute) {
      return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
          node, torch::lazy::OpKind(at::aten::permute)));
    }
    if (node->op().op == at::aten::select) {
      return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
          node, torch::lazy::OpKind(at::aten::select)));
    }
    if (node->op().op == at::aten::squeeze) {
      return LowerSqueeze(torch::lazy::NodeCast<torch::lazy::Squeeze>(
          node, torch::lazy::OpKind(at::aten::squeeze)));
    }
    if (node->op().op == at::aten::unsqueeze) {
      return LowerUnsqueeze(torch::lazy::NodeCast<torch::lazy::Unsqueeze>(
          node, torch::lazy::OpKind(at::aten::unsqueeze)));
    }
    if (node->op().op == at::aten::view) {
      return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
          node, torch::lazy::OpKind(at::aten::view)));
    }
    if (node->op() == *torch::lazy::ltc_device_data) {
      // Device data lowers to a graph parameter rather than an op.
      const torch::lazy::DeviceData* device_data_node =
          torch::lazy::NodeCast<torch::lazy::DeviceData>(
              node, *torch::lazy::ltc_device_data);
      auto infoptr = device_data_node->data()->info();
      auto deviceDataInfoPtr =
          (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
      if (GRAPH_DUMP_ENABLED) {
        LOG(ERROR) << "Lowering device data node, tensor id "
                   << deviceDataInfoPtr->tensor_id << std::endl;
      }
      return {loctx()->GetParameter(device_data_node->data())};
    }

    // Generic fallback: forward every operand to the same-named JIT builtin.
    std::vector<torch::jit::NamedValue> arguments;
    for (const torch::lazy::Output& output : node->operands()) {
      arguments.emplace_back(loctx()->GetOutputOp(output));
    }
    return LowerBuiltin(node, arguments);
  }

  // Emits a call to the JIT builtin named after `node`'s op symbol.
  TorchMlirOpVector LowerBuiltin(
      const torch::lazy::Node* node,
      const std::vector<torch::jit::NamedValue>& arguments,
      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
    return LowerTorchMlirBuiltin(
        function_, node->op().op, arguments, kwarguments);
  }
  // Emits a call to an arbitrary JIT builtin symbol (used when the lowered
  // op differs from the node's own op, e.g. Cast -> aten::to).
  TorchMlirOpVector LowerBuiltin(
      c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
    return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
  }

  // as_strided(input, size, stride, storage_offset), cloned so the result
  // owns its memory rather than aliasing the input view.
  TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->size());
    arguments.emplace_back(node->stride());
    arguments.emplace_back(node->storage_offset());
    TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
    CHECK_EQ(as_strided_out.size(), 1);
    return {GenerateClone(as_strided_out.front())};
  }

  // Writes operand(1) into an as_strided window of a clone of operand(0),
  // returning the updated clone.
  TorchMlirOpVector
  LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) {
    torch::jit::Value* destination =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
    const torch::lazy::Output& input_op = node->operand(1);
    const torch::lazy::Shape& input_shape = input_op.shape();
    const auto input_dimensions = input_shape.sizes();
    std::vector<torch::jit::NamedValue> dest_arguments;
    dest_arguments.emplace_back(destination);
    dest_arguments.emplace_back(
        std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
    dest_arguments.emplace_back(node->stride());
    dest_arguments.emplace_back(node->storage_offset());
    TorchMlirOpVector as_strided_out =
        LowerBuiltin(at::aten::as_strided, dest_arguments);
    CHECK_EQ(as_strided_out.size(), 1);
    torch::jit::Value* as_strided = as_strided_out.front();
    GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
    return {destination};
  }

  // native_batch_norm(input, weight, bias, running_mean, running_var,
  //                   training, momentum, eps)
  TorchMlirOpVector
  LowerBatchNorm(const torch::lazy::NativeBatchNormForward* node) {
    std::vector<torch::jit::NamedValue> arguments;
    for (size_t i = 0; i < 5; ++i) {
      arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
    }
    arguments.emplace_back(node->training());
    arguments.emplace_back(node->momentum());
    arguments.emplace_back(node->eps());
    return LowerBuiltin(node, arguments);
  }

  TorchMlirOpVector
  LowerBatchNormBackward(const torch::lazy::NativeBatchNormBackward* node) {
    std::vector<torch::jit::NamedValue> arguments;
    for (size_t i = 0; i < 3; ++i) {
      arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
    }
    const auto& operands = node->operands();
    c10::optional<at::Tensor> null_arg;
    // With only 5 operands, two optional tensor inputs are absent and are
    // passed as nulls. NOTE(review): presumably the running-stats slots of
    // native_batch_norm_backward — confirm against node construction.
    if (operands.size() == 5) {
      arguments.emplace_back(null_arg);
      arguments.emplace_back(null_arg);
    }
    for (size_t i = 3; i < operands.size(); ++i) {
      arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
    }
    arguments.emplace_back(node->training());
    arguments.emplace_back(node->eps());
    arguments.emplace_back(node->output_mask());
    return LowerBuiltin(node, arguments);
  }

  // Dtype cast lowers to aten::to(input, dtype).
  TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->dtype());
    return LowerBuiltin(at::aten::to, arguments);
  }

  TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.emplace_back(node->size());
    auto expand_out = LowerBuiltin(node, arguments);
    if (node->is_scalar_expand()) {
      // The aten::expand operations sets all strides to 0 when the original
      // is of rank 0. This leads to false positives when checking for
      // internal memory overlap, because at::has_internal_overlap returns
      // MemOverlap::YES when a stride is set to 0.
      CHECK_EQ(expand_out.size(), 1);
      return {GenerateClone(expand_out.front())};
    }
    return expand_out;
  }

  // narrow lowers to one aten::slice per dimension, each starting at the
  // corresponding base index and spanning the requested size.
  TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
    const torch::lazy::Output& input = node->operand(0);
    torch::jit::Value* base = loctx()->GetOutputOp(input);
    const auto& base_indices = node->base_indices();
    const auto& sizes = node->sizes();
    const torch::lazy::Shape& input_shape = input.shape();
    CHECK_EQ(sizes.size(), base_indices.size());
    CHECK_EQ(input_shape.dim(), base_indices.size());
    for (size_t dim = 0; dim < base_indices.size(); ++dim) {
      int64_t start = base_indices[dim];
      base = GenerateSlice(
          /*base=*/base, /*dim=*/dim, /*start=*/start,
          /*end=*/start + sizes[dim], /*step=*/1);
    }
    return {base};
  }

  TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.push_back(node->dims());
    return LowerBuiltin(node, arguments);
  }

  // Materializes a scalar as a constant tensor, placed on the backend's
  // eager fallback device with the node's scalar type.
  TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
    const at::Scalar& value = node->value();
    const torch::lazy::Shape& shape = node->shape();
    auto options =
        at::TensorOptions()
            .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
            .dtype(shape.scalar_type());
    return {
        loctx()->graph()->insertConstant(at::scalar_tensor(value, options))};
  }

  // select lowers to a single aten::slice on the selected dimension.
  TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
    int64_t step = torch::lazy::Select::GetStride(
        node->start(), node->end(), node->stride());
    torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
    return {GenerateSlice(
        /*base=*/base, /*dim=*/node->dim(),
        /*start=*/node->start(), /*end=*/node->end(),
        /*step=*/step)};
  }

  // dim == -1 means "squeeze all dims": the dim argument is omitted so the
  // no-argument aten::squeeze overload is called.
  TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    if (node->dim() != -1) {
      arguments.push_back(node->dim());
    }
    return LowerBuiltin(node, arguments);
  }

  // Copies operand(1) into a sliced window of a clone of operand(0).
  TorchMlirOpVector
  LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
    torch::jit::Value* dest =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
    int64_t step = torch::lazy::Select::GetStride(
        node->start(), node->end(), node->stride());
    torch::jit::Value* selected = GenerateSlice(
        /*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
        /*end=*/node->end(), /*step=*/step);
    GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
    return {dest};
  }

  // Copies operand(1) into the narrowed region (one slice per dimension) of
  // a clone of operand(0).
  TorchMlirOpVector
  LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
    torch::jit::Value* dest =
        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
    const auto& base_indices = node->base_indices();
    const torch::lazy::Output& source_argument = node->operand(1);
    const torch::lazy::Shape& source_shape = source_argument.shape();
    CHECK_EQ(source_shape.dim(), base_indices.size());
    torch::jit::Value* base = dest;
    for (size_t dim = 0; dim < base_indices.size(); ++dim) {
      int64_t start = base_indices[dim];
      base = GenerateSlice(
          /*base=*/base, /*dim=*/dim, /*start=*/start,
          /*end=*/start + source_shape.size(dim),
          /*step=*/1);
    }
    GenerateCopy(base, loctx()->GetOutputOp(source_argument));
    return {dest};
  }

  TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.push_back(node->dim());
    return LowerBuiltin(node, arguments);
  }

  // view lowers to aten::reshape(input, output_size).
  TorchMlirOpVector LowerView(const torch::lazy::View* node) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
    arguments.push_back(node->output_size());
    return LowerBuiltin(at::aten::reshape, arguments);
  }

  // Emits aten::clone(val) and returns the cloned value.
  torch::jit::Value* GenerateClone(torch::jit::Value* val) {
    std::vector<torch::jit::NamedValue> clone_arguments;
    clone_arguments.emplace_back(val);
    TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);
    CHECK_EQ(cloned.size(), 1);
    return cloned.front();
  }

  // Emits an in-place aten::copy_(destination, source).
  void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(destination);
    arguments.emplace_back(source);
    LowerBuiltin(at::aten::copy_, arguments);
  }

  // Emits aten::slice(base, dim, start, end, step) and returns the result.
  torch::jit::Value* GenerateSlice(
      torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
      int64_t step) {
    std::vector<torch::jit::NamedValue> arguments;
    arguments.emplace_back(base);
    arguments.emplace_back(dim);
    arguments.emplace_back(start);
    arguments.emplace_back(end);
    arguments.emplace_back(step);
    TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
    CHECK_EQ(selected.size(), 1);
    return selected.front();
  }
  // Non-owning; the lowering context outlives this object.
  torch::lazy::TorchMlirLoweringContext* loctx_;
  // JIT function wrapping loctx_'s graph; null when constructed without a
  // context.
  std::shared_ptr<torch::jit::GraphFunction> function_;
};
|
||||
|
||||
std::unique_ptr<TorchMlirNodeLoweringInterface>
|
||||
TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
|
||||
return std::make_unique<TorchMlirNodeLowering>(
|
||||
"TorchMlirNodeLowering",
|
||||
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
|
||||
}
|
||||
|
||||
// Appends a call to the JIT builtin `sym` to `function`'s graph, reusing the
// TorchScript sugared-value machinery for overload resolution.
//
// Returns one jit::Value per result: a multi-result builtin (tuple return)
// is unpacked into its components; a single-result builtin yields a
// one-element vector.
//
// Fix: CHECK each tuple component's dynamic_cast result before dereferencing
// it — previously only the top-level SimpleValue was checked, so a
// non-SimpleValue tuple component would cause a null-pointer dereference
// instead of a diagnosable failure.
TorchMlirOpVector LowerTorchMlirBuiltin(
    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments) {
  auto builtin =
      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
  auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
  CHECK(sv);
  if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
    // Multi-result op: unpack the returned tuple into individual values.
    const auto tuple_call_result = sv->asTuple({}, *function);
    TorchMlirOpVector tuple_result;
    for (const auto& tuple_component : tuple_call_result) {
      auto tuple_component_sv =
          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
      CHECK(tuple_component_sv);
      tuple_result.push_back(tuple_component_sv->getValue());
    }
    return tuple_result;
  }
  return {sv->getValue()};
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,30 @@
|
|||
//===- mlir_node_lowering.h -----------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file is adapted from pytorch/pytorch
|
||||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/api/include/torch/jit.h>
|
||||
#include <torch/csrc/lazy/backend/lowering_context.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
|
||||
typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
|
||||
|
||||
TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
TorchMlirFunction function, c10::Symbol sym,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {});
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -21,3 +21,7 @@ static const bool verbose_print_function =
|
|||
std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__ \
|
||||
<< ")" << std::endl; \
|
||||
}
|
||||
|
||||
// Prints `msg` to stdout together with the call site's file and line.
// NOTE: the expansion already ends in `;` (it is a complete statement), so
// existing call sites may invoke it without a trailing semicolon.
#define PRINT_DEBUG(msg) \
  std::cout << msg << " (" << __FILE__ << ":" << __LINE__ << ")" \
            << std::endl;
|
|
@ -10,7 +10,7 @@ include_directories(BEFORE
|
|||
)
|
||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
add_library(TorchMLIRJITIRImporter MODULE
|
||||
add_library(TorchMLIRJITIRImporter SHARED
|
||||
class_annotator.cpp
|
||||
class_annotator_pybind.cpp
|
||||
get_registered_ops.cpp
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace torch_mlir {
|
|||
/// will be attached as an argument attribute to the func op's argument. If a
|
||||
/// null MlirAttribute is returned, no attribute will be attached to that
|
||||
/// argument.
|
||||
MlirOperation importJitFunctionAsFuncOp(
|
||||
TORCH_API MlirOperation importJitFunctionAsFuncOp(
|
||||
MlirContext context, torch::jit::Function *function,
|
||||
std::function<MlirAttribute(int)> getArgAttribute =
|
||||
[](int) -> MlirAttribute { return {nullptr}; });
|
||||
|
|
Loading…
Reference in New Issue