mirror of https://github.com/llvm/torch-mlir
Refactor Node Lowering (#914)
parent d9aee0d7a7
commit 8312fa535b
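Summary of the change: the legacy TorchMlirNodeLoweringInterface / TorchMlirNodeLowering indirection is removed. TorchMlirLoweringContext now owns the torch::jit::GraphFunction directly (function_) and exposes a Lower(const Node*) member, and each TorchMlirNode subclass lowers itself through an override of Lower(TorchMlirFunction, TorchMlirLoweringContext*). The helpers (LowerBuiltin, GenerateClone, GenerateCopy, GenerateSlice) become free functions that take the function explicitly. A minimal sketch of the new per-node pattern, modeled on Permute::Lower below; the Relu node here is hypothetical and not part of this commit:

// Hypothetical node, for illustration only: gather the operands as
// torch::jit::NamedValues, then emit the matching JIT builtin.
TorchMlirOpVector Relu::Lower(
    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  return LowerBuiltin(this, function, arguments);
}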
@@ -36,8 +36,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
     const std::string& name, BackendDevice device)
     : LoweringContext(name, std::forward<BackendDevice>(device)),
       graph_(std::make_shared<torch::jit::Graph>()),
+      function_(
+          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
       mlir_context_(mlirContextCreate()) {
-  lowering_ = TorchMlirNodeLoweringInterface::Create(this);
   RegisterMlirDialects();
 }

@@ -49,16 +50,31 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
           std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
           std::forward<Util::EmissionMap>(emit_status)),
       graph_(std::make_shared<torch::jit::Graph>()),
+      function_(
+          std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
       mlir_context_(mlirContextCreate()) {
-  lowering_ = TorchMlirNodeLoweringInterface::Create(this);
   for (auto node : post_order) {
-    bool ok = lowering_->Lower(node);
-    CHECK(ok) << "Failed to lower: " << *node;
+    Lower(node);
   }

   RegisterMlirDialects();
 }

+void TorchMlirLoweringContext::Lower(const Node* node) {
+  if (auto* torch_mlir_node =
+          dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
+    TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
+    CHECK(!ops.empty()) << "Failed to lower: " << *node;
+    CHECK_EQ(node->num_outputs(), ops.size());
+    for (size_t i = 0; i < ops.size(); ++i) {
+      AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
+    }
+  } else {
+    throw std::runtime_error(
+        "Expected torch::lazy::TorchMlirNode but could not dynamic cast");
+  }
+}
+
 void TorchMlirLoweringContext::SetUpAlias(
     const std::vector<int64_t>& output_index, int64_t param_number,
     const std::vector<int64_t>& param_index, bool must_alias) {
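With the interface gone, TorchMlirLoweringContext::Lower above is the single dispatch point: it downcasts to TorchMlirNode, invokes the node's own Lower, and registers every produced torch::jit::Value* through AssignOutputOp. A hedged usage sketch (the loop mirrors the constructor above; post_order and loctx are illustrative names, with post_order assumed to come from torch::lazy::Util::ComputePostOrder):

// Drive lowering over a post-order trace of the lazy graph.
for (auto node : post_order) {
  loctx.Lower(node); // throws std::runtime_error for non-TorchMlirNode nodes
}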
@@ -136,8 +152,7 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
   if (it == emitted_outputs_.end()) {
     auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
     for (auto node : post_order) {
-      bool ok = lowering_->Lower(node);
-      TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
+      Lower(node);
     }
     // At this point the output better be present, otherwise there is an issue
     // with the lowering code.

@@ -23,22 +23,6 @@
 namespace torch {
 namespace lazy {

-class TORCH_API TorchMlirNodeLoweringInterface {
-  /**
-   * This interface is only needed for legacy ops, and can be removed once all
-   * ops implement LtcMlirNode->lower().
-   * */
-public:
-  TorchMlirNodeLoweringInterface() = default;
-
-  virtual ~TorchMlirNodeLoweringInterface() = default;
-
-  virtual bool Lower(const Node* node) = 0;
-
-  static std::unique_ptr<TorchMlirNodeLoweringInterface>
-  Create(LoweringContext* loctx);
-};
-
 class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
 public:
   // Describes an input/output alias as inserted by the SetUpAlias() API.
@@ -61,6 +45,8 @@ public:
       c10::ArrayRef<torch::lazy::Node*> post_order,
       torch::lazy::Util::EmissionMap emit_status);

+  void Lower(const Node* node);
+
   // Adds a new input/output alias.
   void SetUpAlias(
       const std::vector<int64_t>& output_index, int64_t param_number,
@@ -120,11 +106,11 @@ private:
   // Holds the input/output alias information populated by the SetUpAlias() API.
   InputOutputAliases input_output_aliases_;
   std::shared_ptr<torch::jit::Graph> graph_;
+  std::shared_ptr<torch::jit::GraphFunction> function_;
   MlirContext mlir_context_;
   std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
   std::vector<torch::jit::Value*> root_tuple_;
   OutputMap<torch::jit::Value*> emitted_outputs_;
-  std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
 };

 class TORCH_API TorchMlirComputation : public torch::lazy::Computation {

@@ -71,12 +71,6 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }

 hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }

-TorchMlirOpVector TorchMlirNode::Lower(
-    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
-  return {};
-}
-
-
 OpKind TorchMlirTensorList::ClassOpKind() {
   // Note: this OpKind is separate from ltc_ops.h since it would be a circular
   // import otherwise

@@ -32,7 +32,7 @@ namespace torch {
 namespace lazy {

 TorchMlirOpVector LowerTorchMlirBuiltin(
-    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
+    TorchMlirFunction function, c10::Symbol sym,
     const std::vector<c10::TypePtr> tensor_types,
     const std::vector<torch::jit::NamedValue>& arguments,
     const std::vector<torch::jit::NamedValue>& kwarguments) {
@@ -81,7 +81,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
 }

 TorchMlirOpVector LowerTorchMlirBuiltin(
-    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
+    TorchMlirFunction function, c10::Symbol sym,
     const c10::ArrayRef<Shape> result_shapes,
     const std::vector<torch::jit::NamedValue>& arguments,
     const std::vector<torch::jit::NamedValue>& kwarguments) {
@@ -101,6 +101,29 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
       function, sym, tensor_types, arguments, kwarguments);
 }

+TorchMlirOpVector LowerBuiltin(
+    const torch::lazy::Node* node, TorchMlirFunction function,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+  return LowerTorchMlirBuiltin(
+      function, node->op().op, node->shapes(), arguments, kwarguments);
+}
+TorchMlirOpVector LowerBuiltin(
+    c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
+    TorchMlirFunction function,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+  return LowerTorchMlirBuiltin(
+      function, sym, result_shapes, arguments, kwarguments);
+}
+TorchMlirOpVector LowerBuiltin(
+    c10::Symbol sym, const std::vector<c10::TypePtr> types,
+    TorchMlirFunction function,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+  return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments);
+}
+
 c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
   auto tensor_type = value_type->cast<c10::TensorType>();
   TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");
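These free LowerBuiltin overloads replace the former member overloads of the deleted TorchMlirNodeLowering class; the only signature change is the explicit TorchMlirFunction parameter. For example, a call that was previously LowerBuiltin(node, arguments) inside the class now threads the function through (sketch, mirroring TorchMlirNode::Lower later in this diff):

// node is a const torch::lazy::Node*; function is the GraphFunction the
// lowering context now owns; arguments is a vector of NamedValues.
TorchMlirOpVector ops = LowerBuiltin(node, function, arguments);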
@@ -174,358 +197,284 @@ std::vector<torch::lazy::Shape> compute_shape_slice(
   return {Shape(scalar_type.value(), dims)};
 }

-class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
-public:
-  TorchMlirNodeLowering(
-      const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx)
-      : loctx_(loctx), function_(
-                           loctx ? std::make_shared<torch::jit::GraphFunction>(
-                                       name, loctx->graph(), nullptr)
-                                 : nullptr) {}
-
-  torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; }
-
-  bool Lower(const torch::lazy::Node* node) override {
-    if (auto* torch_mlir_node =
-            dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
-      // First, we call the node lowering function, which exists for newly
-      // codegenned or refactored nodes
-      TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx());
-      if (ops.empty()) {
-        // Then fall back to legacy lowering code, which should be gradually
-        // removed
-        ops = LowerNonCodegenOps(node);
-      }
-      if (ops.empty()) {
-        return false;
-      }
-      CHECK_EQ(node->num_outputs(), ops.size());
-      for (size_t i = 0; i < ops.size(); ++i) {
-        loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
-      }
-      return true;
-    } else {
-      TorchMlirOpVector ops = LowerNonCodegenOps(node);
-      if (!ops.empty()) {
-        CHECK_EQ(node->num_outputs(), ops.size());
-        for (size_t i = 0; i < ops.size(); ++i) {
-          loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
-        }
-        return true;
-      }
-    }
-    throw std::runtime_error(
-        "Expected torch::lazy::TorchMlirNode but could not dynamic cast");
-  }
-
-  // TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops
-  // to codegen we should delete this and put all the lowering logic into Node
-  // classes
-  TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
-
-    if (node->op().op == at::aten::as_strided) {
-      return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
-          node, torch::lazy::OpKind(at::aten::as_strided)));
-    }
-    if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
-      return LowerAsStridedViewUpdate(
-          torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
-              node, *torch::lazy::ltc_as_strided_view_update));
-    }
-    if (node->op() == *torch::lazy::ltc_cast) {
-      return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
-          node, *torch::lazy::ltc_cast));
-    }
-    if (node->op() == *torch::lazy::ltc_select_view_update) {
-      return LowerSelectViewUpdate(
-          torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
-              node, *torch::lazy::ltc_select_view_update));
-    }
-    if (node->op() == *torch::lazy::ltc_narrow_view_update) {
-      return LowerNarrowViewUpdate(
-          torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
-              node, *torch::lazy::ltc_narrow_view_update));
-    }
-    if (node->op().op == at::prim::Constant) {
-      return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
-          node, torch::lazy::OpKind(at::prim::Constant)));
-    }
-    if (node->op().op == at::aten::bernoulli) {
-      std::vector<torch::jit::NamedValue> arguments;
-      arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-      return LowerBuiltin(node, arguments);
-    }
-    if (node->op().op == at::aten::expand) {
-      return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
-          node, torch::lazy::OpKind(at::aten::expand)));
-    }
-    if (node->op().op == at::aten::narrow) {
-      return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
-          node, torch::lazy::OpKind(at::aten::narrow)));
-    }
-    if (node->op().op == at::aten::permute) {
-      return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
-          node, torch::lazy::OpKind(at::aten::permute)));
-    }
-    if (node->op().op == at::aten::select) {
-      return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
-          node, torch::lazy::OpKind(at::aten::select)));
-    }
-    if (node->op().op == at::aten::squeeze) {
-      return LowerSqueeze(torch::lazy::NodeCast<torch::lazy::Squeeze>(
-          node, torch::lazy::OpKind(at::aten::squeeze)));
-    }
-    if (node->op().op == at::aten::unsqueeze) {
-      return LowerUnsqueeze(torch::lazy::NodeCast<torch::lazy::Unsqueeze>(
-          node, torch::lazy::OpKind(at::aten::unsqueeze)));
-    }
-    if (node->op().op == at::aten::view) {
-      return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
-          node, torch::lazy::OpKind(at::aten::view)));
-    }
-    if (node->op() == *torch::lazy::ltc_device_data) {
-      const torch::lazy::DeviceData* device_data_node =
-          torch::lazy::NodeCast<torch::lazy::DeviceData>(
-              node, *torch::lazy::ltc_device_data);
-      auto infoptr = device_data_node->data()->info();
-      auto deviceDataInfoPtr =
-          (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
-      if (GRAPH_DUMP_ENABLED) {
-        LOG(ERROR) << "Lowering device data node, tensor id "
-                   << deviceDataInfoPtr->tensor_id << std::endl;
-      }
-      return {loctx()->GetParameter(device_data_node->data())};
-    }
-
-    std::vector<torch::jit::NamedValue> arguments;
-    for (const torch::lazy::Output& output : node->operands()) {
-      arguments.emplace_back(loctx()->GetOutputOp(output));
-    }
-    return LowerBuiltin(node, arguments);
-  }
-
-  TorchMlirOpVector LowerBuiltin(
-      const torch::lazy::Node* node,
-      const std::vector<torch::jit::NamedValue>& arguments,
-      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
-    return LowerTorchMlirBuiltin(
-        function_, node->op().op, node->shapes(), arguments, kwarguments);
-  }
-  TorchMlirOpVector LowerBuiltin(
-      c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
-      const std::vector<torch::jit::NamedValue>& arguments,
-      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
-    return LowerTorchMlirBuiltin(
-        function_, sym, result_shapes, arguments, kwarguments);
-  }
-  TorchMlirOpVector LowerBuiltin(
-      c10::Symbol sym, const std::vector<c10::TypePtr> types,
-      const std::vector<torch::jit::NamedValue>& arguments,
-      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
-    return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
-  }
-
-  TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    arguments.emplace_back(node->size);
-    arguments.emplace_back(node->stride);
-    arguments.emplace_back(node->storage_offset);
-    TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
-    CHECK_EQ(as_strided_out.size(), 1);
-    return {GenerateClone(as_strided_out.front())};
-  }
-
-  TorchMlirOpVector
-  LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) {
-    torch::jit::Value* destination =
-        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
-    const torch::lazy::Output& input_op = node->operand(1);
-    const torch::lazy::Shape& input_shape = input_op.shape();
-    const auto input_dimensions = input_shape.sizes();
-    std::vector<torch::jit::NamedValue> dest_arguments;
-    dest_arguments.emplace_back(destination);
-    dest_arguments.emplace_back(
-        std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
-    dest_arguments.emplace_back(node->stride);
-    dest_arguments.emplace_back(node->storage_offset);
-    TorchMlirOpVector as_strided_out =
-        LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
-    CHECK_EQ(as_strided_out.size(), 1);
-    torch::jit::Value* as_strided = as_strided_out.front();
-    GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
-    return {destination};
-  }
-
-  TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    arguments.emplace_back(node->dtype);
-    return LowerBuiltin(at::aten::to, node->shapes(), arguments);
-  }
-
-  TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    arguments.emplace_back(node->size);
-    auto expand_out = LowerBuiltin(node, arguments);
-    if (node->is_scalar_expand) {
-      // The aten::expand operations sets all strides to 0 when the original
-      // of rank 0. This leads to false positives when checking for internal
-      // memory overlap, because at::has_internal_overlap returns
-      // MemOverlap::YES when a stride is set to 0.
-      CHECK_EQ(expand_out.size(), 1);
-      return {GenerateClone(expand_out.front())};
-    }
-    return expand_out;
-  }
-
-  TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
-    const torch::lazy::Output& input = node->operand(0);
-    torch::jit::Value* base = loctx()->GetOutputOp(input);
-    const auto& base_indices = node->base_indices;
-    const auto& sizes = node->sizes;
-    const torch::lazy::Shape& input_shape = input.shape();
-    CHECK_EQ(sizes.size(), base_indices.size());
-    CHECK_EQ(input_shape.dim(), base_indices.size());
-    for (size_t dim = 0; dim < base_indices.size(); ++dim) {
-      int64_t start = base_indices[dim];
-      base = GenerateSlice(
-          /*base=*/base, /*dim=*/dim, /*start=*/start,
-          /*end=*/start + sizes[dim], /*step=*/1);
-    }
-    return {base};
-  }
-
-  TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    arguments.push_back(node->dims);
-    return LowerBuiltin(node, arguments);
-  }
-
-  TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
-    const at::Scalar& value = node->value;
-    const torch::lazy::Shape& shape = node->shape();
-    auto options =
-        at::TensorOptions()
-            .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
-            .dtype(shape.scalar_type());
-    return {
-        loctx()->graph()->insertConstant(at::scalar_tensor(value, options))};
-  }
-
-  TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
-    int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
-    torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
-    return {GenerateSlice(
-        /*base=*/base, /*dim=*/node->dim,
-        /*start=*/node->start, /*end=*/node->end,
-        /*step=*/step)};
-  }
-
-  TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    if (node->dim != -1) {
-      arguments.push_back(node->dim);
-    }
-    return LowerBuiltin(node, arguments);
-  }
-
-  TorchMlirOpVector
-  LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
-    torch::jit::Value* dest =
-        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
-    int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
-    torch::jit::Value* selected = GenerateSlice(
-        /*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
-        /*end=*/node->end, /*step=*/step);
-    GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
-    return {dest};
-  }
-
-  TorchMlirOpVector
-  LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
-    torch::jit::Value* dest =
-        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
-    const auto& base_indices = node->base_indices;
-    const torch::lazy::Output& source_argument = node->operand(1);
-    const torch::lazy::Shape& source_shape = source_argument.shape();
-    CHECK_EQ(source_shape.dim(), base_indices.size());
-    torch::jit::Value* base = dest;
-    for (size_t dim = 0; dim < base_indices.size(); ++dim) {
-      int64_t start = base_indices[dim];
-      base = GenerateSlice(
-          /*base=*/base, /*dim=*/dim, /*start=*/start,
-          /*end=*/start + source_shape.size(dim),
-          /*step=*/1);
-    }
-    GenerateCopy(base, loctx()->GetOutputOp(source_argument));
-    return {dest};
-  }
-
-  TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    arguments.push_back(node->dim);
-    return LowerBuiltin(node, arguments);
-  }
-
-  TorchMlirOpVector LowerView(const torch::lazy::View* node) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
-    arguments.push_back(node->output_size);
-    return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
-  }
-
-  torch::jit::Value* GenerateClone(torch::jit::Value* val) {
-    std::vector<torch::jit::NamedValue> clone_arguments;
-    clone_arguments.emplace_back(val);
-
-    // Type of cloned value should be identical to the original one.
-    TorchMlirOpVector cloned =
-        LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
-    CHECK_EQ(cloned.size(), 1);
-    return cloned.front();
-  }
-
-  void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(destination);
-    arguments.emplace_back(source);
-    LowerBuiltin(
-        at::aten::copy_,
-        c10::ArrayRef<Shape>(compute_shape_copy(source->type())), arguments);
-  }
-
-  torch::jit::Value* GenerateSlice(
-      torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
-      int64_t step) {
-    std::vector<torch::jit::NamedValue> arguments;
-    arguments.emplace_back(base);
-    arguments.emplace_back(dim);
-    arguments.emplace_back(start);
-    arguments.emplace_back(end);
-    arguments.emplace_back(step);
-
-    TorchMlirOpVector selected = LowerBuiltin(
-        at::aten::slice,
-        c10::ArrayRef<Shape>(
-            compute_shape_slice(base->type(), dim, start, end, step)),
-        arguments);
-    CHECK_EQ(selected.size(), 1);
-    return selected.front();
-  }
-  torch::lazy::TorchMlirLoweringContext* loctx_;
-  std::shared_ptr<torch::jit::GraphFunction> function_;
-};
-
-std::unique_ptr<TorchMlirNodeLoweringInterface>
-TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
-  return std::make_unique<TorchMlirNodeLowering>(
-      "TorchMlirNodeLowering",
-      static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
-}
+torch::jit::Value*
+GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
+  std::vector<torch::jit::NamedValue> clone_arguments;
+  clone_arguments.emplace_back(val);
+
+  // Type of cloned value should be identical to the original one.
+  TorchMlirOpVector cloned =
+      LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments);
+  CHECK_EQ(cloned.size(), 1);
+  return cloned.front();
+}
+
+void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source, TorchMlirFunction function) {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(destination);
+  arguments.emplace_back(source);
+  LowerBuiltin(
+      at::aten::copy_,
+      c10::ArrayRef<Shape>(compute_shape_copy(source->type())), function, arguments);
+}
+
+torch::jit::Value* GenerateSlice(
+    torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
+    int64_t step, TorchMlirFunction function) {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(base);
+  arguments.emplace_back(dim);
+  arguments.emplace_back(start);
+  arguments.emplace_back(end);
+  arguments.emplace_back(step);
+
+  TorchMlirOpVector selected = LowerBuiltin(
+      at::aten::slice,
+      c10::ArrayRef<Shape>(
+          compute_shape_slice(base->type(), dim, start, end, step)),
+      function,
+      arguments);
+  CHECK_EQ(selected.size(), 1);
+  return selected.front();
+}
+
+// Node Lowerings
+
+// Default Node Lowering
+TorchMlirOpVector TorchMlirNode::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  for (const torch::lazy::Output& output : operands()) {
+    arguments.emplace_back(loctx->GetOutputOp(output));
+  }
+  return LowerBuiltin(this, function, arguments);
+}
+
+// TorchMlir specific nodes
+
+// Non-native nodes
+
+TorchMlirOpVector
+Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(dtype);
+  return LowerBuiltin(at::aten::to, shapes(), function, arguments);
+}
+
+TorchMlirOpVector DeviceData::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  auto infoptr = data_->info();
+  auto deviceDataInfoPtr =
+      (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
+  if (GRAPH_DUMP_ENABLED) {
+    LOG(ERROR) << "Lowering device data node, tensor id "
+               << deviceDataInfoPtr->tensor_id << std::endl;
+  }
+  return {loctx->GetParameter(data_)};
+}
+
+TorchMlirOpVector Expand::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(size);
+  auto expand_out = LowerBuiltin(this, function, arguments);
+  if (is_scalar_expand) {
+    // The aten::expand operations sets all strides to 0 when the original is
+    // of rank 0. This leads to false positives when checking for internal
+    // memory overlap, because at::has_internal_overlap returns
+    // MemOverlap::YES when a stride is set to 0.
+    CHECK_EQ(expand_out.size(), 1);
+    return {GenerateClone(expand_out.front(), function)};
+  }
+  return expand_out;
+}
+
+TorchMlirOpVector Scalar::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  auto options =
+      at::TensorOptions()
+          .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
+          .dtype(shape().scalar_type());
+  return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
+}
+
+// View Ops
+
+TorchMlirOpVector AsStrided::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(size);
+  arguments.emplace_back(stride);
+  arguments.emplace_back(storage_offset);
+  TorchMlirOpVector as_strided_out = LowerBuiltin(this, function, arguments);
+  CHECK_EQ(as_strided_out.size(), 1);
+  return {GenerateClone(as_strided_out.front(), function)};
+}
+
+TorchMlirOpVector AsStridedViewUpdate::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+
+  torch::jit::Value* destination =
+      GenerateClone(loctx->GetOutputOp(operand(0)), function);
+  const torch::lazy::Output& input_op = operand(1);
+  const torch::lazy::Shape& input_shape = input_op.shape();
+  const auto input_dimensions = input_shape.sizes();
+  std::vector<torch::jit::NamedValue> dest_arguments;
+  dest_arguments.emplace_back(destination);
+  dest_arguments.emplace_back(
+      std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
+  dest_arguments.emplace_back(stride);
+  dest_arguments.emplace_back(storage_offset);
+  TorchMlirOpVector as_strided_out =
+      LowerBuiltin(at::aten::as_strided, shapes(), function, dest_arguments);
+  CHECK_EQ(as_strided_out.size(), 1);
+  torch::jit::Value* as_strided = as_strided_out.front();
+  GenerateCopy(as_strided, loctx->GetOutputOp(input_op), function);
+  return {destination};
+}
+
+TorchMlirOpVector Diagonal::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(offset);
+  arguments.emplace_back(dim1);
+  arguments.emplace_back(dim2);
+  return LowerBuiltin(this, function, arguments);
+}
+
+TorchMlirOpVector DiagonalViewUpdate::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  // Since we promise the backends that we never generate any aliased
+  // inplace update IR, therefore we clone the target first and then
+  // update the clone inplace instead. Since the clone is transient,
+  // it will never be aliased, and therefore it's safe.
+  torch::jit::Value* destination =
+      GenerateClone(loctx->GetOutputOp(operand(0)), function);
+
+  // Replay the diagonal.
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(destination);
+  arguments.emplace_back(offset);
+  arguments.emplace_back(dim1);
+  arguments.emplace_back(dim2);
+  auto diag = LowerBuiltin(at::aten::diagonal, shapes(), function, arguments);
+
+  // Update the replayed diagonal view with the input.
+  GenerateCopy(diag.front(), loctx->GetOutputOp(operand(1)), function);
+
+  // Destination's diag view should be updated.
+  return {destination};
+}
+
+TorchMlirOpVector Narrow::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  const torch::lazy::Output& input = operand(0);
+  torch::jit::Value* base = loctx->GetOutputOp(input);
+  const torch::lazy::Shape& input_shape = input.shape();
+  CHECK_EQ(sizes.size(), base_indices.size());
+  CHECK_EQ(input_shape.dim(), base_indices.size());
+  for (size_t dim = 0; dim < base_indices.size(); ++dim) {
+    int64_t start = base_indices[dim];
+    base = GenerateSlice(
+        /*base=*/base, /*dim=*/dim, /*start=*/start,
+        /*end=*/start + sizes[dim], /*step=*/1,
+        /*function=*/function);
+  }
+  return {base};
+}
+
+TorchMlirOpVector NarrowViewUpdate::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  torch::jit::Value* dest =
+      GenerateClone(loctx->GetOutputOp(operand(0)), function);
+  const torch::lazy::Output& source_argument = operand(1);
+  const torch::lazy::Shape& source_shape = source_argument.shape();
+  CHECK_EQ(source_shape.dim(), base_indices.size());
+  torch::jit::Value* base = dest;
+  for (size_t dim = 0; dim < base_indices.size(); ++dim) {
+    int64_t start = base_indices[dim];
+    base = GenerateSlice(
+        /*base=*/base, /*dim=*/dim, /*start=*/start,
+        /*end=*/start + source_shape.size(dim), /*step=*/1,
+        /*function=*/function);
+  }
+  GenerateCopy(base, loctx->GetOutputOp(source_argument), function);
+  return {dest};
+}
+
+TorchMlirOpVector Permute::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(dims);
+  return LowerBuiltin(this, function, arguments);
+}
+
+TorchMlirOpVector Resize::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+
+  std::vector<torch::jit::NamedValue> arguments;
+  for (const torch::lazy::Output& output : operands()) {
+    arguments.emplace_back(loctx->GetOutputOp(output));
+  }
+  return LowerBuiltin(this, function, arguments);
+}
+
+TorchMlirOpVector Select::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  int64_t step = torch::lazy::GetStride(start, end, stride);
+  torch::jit::Value* base = loctx->GetOutputOp(operand(0));
+  return {GenerateSlice(
+      /*base=*/base, /*dim=*/dim,
+      /*start=*/start, /*end=*/end,
+      /*step=*/step, /*function=*/function)};
+}
+
+TorchMlirOpVector SelectViewUpdate::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  torch::jit::Value* dest =
+      GenerateClone(loctx->GetOutputOp(operand(0)), function);
+  int64_t step = torch::lazy::GetStride(start, end, stride);
+  torch::jit::Value* selected = GenerateSlice(
+      /*base=*/dest, /*dim=*/dim, /*start=*/start,
+      /*end=*/end, /*step=*/step, /*function=*/function);
+  GenerateCopy(selected, loctx->GetOutputOp(operand(1)), function);
+  return {dest};
+}
+
+TorchMlirOpVector Squeeze::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  if (dim != -1) {
+    arguments.emplace_back(dim);
+  }
+  return LowerBuiltin(this, function, arguments);
+}
+
+TorchMlirOpVector Unsqueeze::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(dim);
+  return LowerBuiltin(this, function, arguments);
+}
+
+TorchMlirOpVector
+View::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::NamedValue> arguments;
+  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
+  arguments.emplace_back(output_size);
+  return LowerBuiltin(at::aten::reshape, shapes(), function, arguments);
+}

 } // namespace lazy
 } // namespace torch
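Note on the view-update lowerings above: as the comment in DiagonalViewUpdate::Lower spells out, they all share one clone-then-copy discipline so that no aliased in-place update IR is ever emitted. A condensed skeleton of that invariant (names follow the helpers in this diff; the view replay shown is the slice case from SelectViewUpdate::Lower):

// Clone the destination (transient, never aliased), replay the view on the
// clone, copy the source into the replayed view, and return the clone.
torch::jit::Value* dest =
    GenerateClone(loctx->GetOutputOp(operand(0)), function);
torch::jit::Value* selected = GenerateSlice(
    /*base=*/dest, /*dim=*/dim, /*start=*/start,
    /*end=*/end, /*step=*/step, /*function=*/function);
GenerateCopy(selected, loctx->GetOutputOp(operand(1)), function);
return {dest};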
@@ -1,5 +1,6 @@
 #pragma once

+#include "../mlir_lowering_context.h"
 #include "../mlir_node.h"

 #include <torch/csrc/lazy/backend/backend_data.h>

@@ -34,6 +35,8 @@ class TORCH_API DeviceData : public TorchMlirNode {
     data_ = data;
   }

+  TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override;
+
   static const DeviceData* Cast(const Node* node);

   // To reuse IR nodes, use this method to create DeviceData nodes