Refactor Node Lowering (#914)

pull/1125/head
Jae Hoon (Antonio) Kim 2022-06-09 15:26:09 -04:00 committed by Henry Tu
parent d9aee0d7a7
commit 8312fa535b
5 changed files with 322 additions and 375 deletions

View File

@ -36,8 +36,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
RegisterMlirDialects();
}
@ -49,16 +50,31 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
CHECK(ok) << "Failed to lower: " << *node;
Lower(node);
}
RegisterMlirDialects();
}
void TorchMlirLoweringContext::Lower(const Node* node) {
if (auto* torch_mlir_node =
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
CHECK(!ops.empty()) << "Failed to lower: " << *node;
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
} else {
throw std::runtime_error(
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
}
}
void TorchMlirLoweringContext::SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
const std::vector<int64_t>& param_index, bool must_alias) {
@ -136,8 +152,7 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
if (it == emitted_outputs_.end()) {
auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
Lower(node);
}
// At this point the output better be present, otherwise there is an issue
// with the lowering code.

View File

@ -23,22 +23,6 @@
namespace torch {
namespace lazy {
class TORCH_API TorchMlirNodeLoweringInterface {
/**
* This interface is only needed for legacy ops, and can be removed once all
* ops implement LtcMlirNode->lower().
* */
public:
TorchMlirNodeLoweringInterface() = default;
virtual ~TorchMlirNodeLoweringInterface() = default;
virtual bool Lower(const Node* node) = 0;
static std::unique_ptr<TorchMlirNodeLoweringInterface>
Create(LoweringContext* loctx);
};
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
// Describes an input/output alias as inserted by the SetUpAlias() API.
@ -61,6 +45,8 @@ public:
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);
void Lower(const Node* node);
// Adds a new input/output alias.
void SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
@ -120,11 +106,11 @@ private:
// Holds the input/output alias information populated by the SetUpAlias() API.
InputOutputAliases input_output_aliases_;
std::shared_ptr<torch::jit::Graph> graph_;
std::shared_ptr<torch::jit::GraphFunction> function_;
MlirContext mlir_context_;
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
std::vector<torch::jit::Value*> root_tuple_;
OutputMap<torch::jit::Value*> emitted_outputs_;
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
};
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {

View File

@ -71,12 +71,6 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
TorchMlirOpVector TorchMlirNode::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
return {};
}
OpKind TorchMlirTensorList::ClassOpKind() {
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise

View File

@ -32,7 +32,7 @@ namespace torch {
namespace lazy {
TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
TorchMlirFunction function, c10::Symbol sym,
const std::vector<c10::TypePtr> tensor_types,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
@ -81,7 +81,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
}
TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
TorchMlirFunction function, c10::Symbol sym,
const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
@ -101,6 +101,29 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
function, sym, tensor_types, arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
const torch::lazy::Node* node, TorchMlirFunction function,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function, node->op().op, node->shapes(), arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
TorchMlirFunction function,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function, sym, result_shapes, arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<c10::TypePtr> types,
TorchMlirFunction function,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments);
}
c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
auto tensor_type = value_type->cast<c10::TensorType>();
TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");
@ -174,334 +197,32 @@ std::vector<torch::lazy::Shape> compute_shape_slice(
return {Shape(scalar_type.value(), dims)};
}
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
TorchMlirNodeLowering(
const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx)
: loctx_(loctx), function_(
loctx ? std::make_shared<torch::jit::GraphFunction>(
name, loctx->graph(), nullptr)
: nullptr) {}
torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; }
bool Lower(const torch::lazy::Node* node) override {
if (auto* torch_mlir_node =
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
// First, we call the node lowering function, which exists for newly
// codegenned or refactored nodes
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx());
if (ops.empty()) {
// Then fall back to legacy lowering code, which should be gradually
// removed
ops = LowerNonCodegenOps(node);
}
if (ops.empty()) {
return false;
}
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
return true;
} else {
TorchMlirOpVector ops = LowerNonCodegenOps(node);
if (!ops.empty()) {
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
return true;
}
}
throw std::runtime_error(
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
}
// TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops
// to codegen we should delete this and put all the lowering logic into Node
// classes
TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
if (node->op().op == at::aten::as_strided) {
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
node, torch::lazy::OpKind(at::aten::as_strided)));
}
if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
return LowerAsStridedViewUpdate(
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
node, *torch::lazy::ltc_as_strided_view_update));
}
if (node->op() == *torch::lazy::ltc_cast) {
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
node, *torch::lazy::ltc_cast));
}
if (node->op() == *torch::lazy::ltc_select_view_update) {
return LowerSelectViewUpdate(
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
node, *torch::lazy::ltc_select_view_update));
}
if (node->op() == *torch::lazy::ltc_narrow_view_update) {
return LowerNarrowViewUpdate(
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
node, *torch::lazy::ltc_narrow_view_update));
}
if (node->op().op == at::prim::Constant) {
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
node, torch::lazy::OpKind(at::prim::Constant)));
}
if (node->op().op == at::aten::bernoulli) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
return LowerBuiltin(node, arguments);
}
if (node->op().op == at::aten::expand) {
return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
node, torch::lazy::OpKind(at::aten::expand)));
}
if (node->op().op == at::aten::narrow) {
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
node, torch::lazy::OpKind(at::aten::narrow)));
}
if (node->op().op == at::aten::permute) {
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
node, torch::lazy::OpKind(at::aten::permute)));
}
if (node->op().op == at::aten::select) {
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
node, torch::lazy::OpKind(at::aten::select)));
}
if (node->op().op == at::aten::squeeze) {
return LowerSqueeze(torch::lazy::NodeCast<torch::lazy::Squeeze>(
node, torch::lazy::OpKind(at::aten::squeeze)));
}
if (node->op().op == at::aten::unsqueeze) {
return LowerUnsqueeze(torch::lazy::NodeCast<torch::lazy::Unsqueeze>(
node, torch::lazy::OpKind(at::aten::unsqueeze)));
}
if (node->op().op == at::aten::view) {
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
node, torch::lazy::OpKind(at::aten::view)));
}
if (node->op() == *torch::lazy::ltc_device_data) {
const torch::lazy::DeviceData* device_data_node =
torch::lazy::NodeCast<torch::lazy::DeviceData>(
node, *torch::lazy::ltc_device_data);
auto infoptr = device_data_node->data()->info();
auto deviceDataInfoPtr =
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
if (GRAPH_DUMP_ENABLED) {
LOG(ERROR) << "Lowering device data node, tensor id "
<< deviceDataInfoPtr->tensor_id << std::endl;
}
return {loctx()->GetParameter(device_data_node->data())};
}
std::vector<torch::jit::NamedValue> arguments;
for (const torch::lazy::Output& output : node->operands()) {
arguments.emplace_back(loctx()->GetOutputOp(output));
}
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerBuiltin(
const torch::lazy::Node* node,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function_, node->op().op, node->shapes(), arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function_, sym, result_shapes, arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<c10::TypePtr> types,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
}
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size);
arguments.emplace_back(node->stride);
arguments.emplace_back(node->storage_offset);
TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
CHECK_EQ(as_strided_out.size(), 1);
return {GenerateClone(as_strided_out.front())};
}
TorchMlirOpVector
LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) {
torch::jit::Value* destination =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
const torch::lazy::Output& input_op = node->operand(1);
const torch::lazy::Shape& input_shape = input_op.shape();
const auto input_dimensions = input_shape.sizes();
std::vector<torch::jit::NamedValue> dest_arguments;
dest_arguments.emplace_back(destination);
dest_arguments.emplace_back(
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
dest_arguments.emplace_back(node->stride);
dest_arguments.emplace_back(node->storage_offset);
TorchMlirOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
torch::jit::Value* as_strided = as_strided_out.front();
GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
return {destination};
}
TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dtype);
return LowerBuiltin(at::aten::to, node->shapes(), arguments);
}
TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->size);
auto expand_out = LowerBuiltin(node, arguments);
if (node->is_scalar_expand) {
// The aten::expand operations sets all strides to 0 when the original
// of rank 0. This leads to false positives when checking for internal
// memory overlap, because at::has_internal_overlap returns
// MemOverlap::YES when a stride is set to 0.
CHECK_EQ(expand_out.size(), 1);
return {GenerateClone(expand_out.front())};
}
return expand_out;
}
TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
const torch::lazy::Output& input = node->operand(0);
torch::jit::Value* base = loctx()->GetOutputOp(input);
const auto& base_indices = node->base_indices;
const auto& sizes = node->sizes;
const torch::lazy::Shape& input_shape = input.shape();
CHECK_EQ(sizes.size(), base_indices.size());
CHECK_EQ(input_shape.dim(), base_indices.size());
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
int64_t start = base_indices[dim];
base = GenerateSlice(
/*base=*/base, /*dim=*/dim, /*start=*/start,
/*end=*/start + sizes[dim], /*step=*/1);
}
return {base};
}
TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->dims);
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
const at::Scalar& value = node->value;
const torch::lazy::Shape& shape = node->shape();
auto options =
at::TensorOptions()
.device(torch::lazy::getBackend()->EagerFallbackDeviceType())
.dtype(shape.scalar_type());
return {
loctx()->graph()->insertConstant(at::scalar_tensor(value, options))};
}
TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
return {GenerateSlice(
/*base=*/base, /*dim=*/node->dim,
/*start=*/node->start, /*end=*/node->end,
/*step=*/step)};
}
TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
if (node->dim != -1) {
arguments.push_back(node->dim);
}
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector
LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
torch::jit::Value* selected = GenerateSlice(
/*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
/*end=*/node->end, /*step=*/step);
GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
return {dest};
}
TorchMlirOpVector
LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
torch::jit::Value* dest =
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
const auto& base_indices = node->base_indices;
const torch::lazy::Output& source_argument = node->operand(1);
const torch::lazy::Shape& source_shape = source_argument.shape();
CHECK_EQ(source_shape.dim(), base_indices.size());
torch::jit::Value* base = dest;
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
int64_t start = base_indices[dim];
base = GenerateSlice(
/*base=*/base, /*dim=*/dim, /*start=*/start,
/*end=*/start + source_shape.size(dim),
/*step=*/1);
}
GenerateCopy(base, loctx()->GetOutputOp(source_argument));
return {dest};
}
TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->dim);
return LowerBuiltin(node, arguments);
}
TorchMlirOpVector LowerView(const torch::lazy::View* node) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->output_size);
return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
}
torch::jit::Value* GenerateClone(torch::jit::Value* val) {
torch::jit::Value*
GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
std::vector<torch::jit::NamedValue> clone_arguments;
clone_arguments.emplace_back(val);
// Type of cloned value should be identical to the original one.
TorchMlirOpVector cloned =
LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments);
CHECK_EQ(cloned.size(), 1);
return cloned.front();
}
}
void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) {
void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source, TorchMlirFunction function) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(source);
LowerBuiltin(
at::aten::copy_,
c10::ArrayRef<Shape>(compute_shape_copy(source->type())), arguments);
}
c10::ArrayRef<Shape>(compute_shape_copy(source->type())), function, arguments);
}
torch::jit::Value* GenerateSlice(
torch::jit::Value* GenerateSlice(
torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
int64_t step) {
int64_t step, TorchMlirFunction function) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(base);
arguments.emplace_back(dim);
@ -513,19 +234,247 @@ public:
at::aten::slice,
c10::ArrayRef<Shape>(
compute_shape_slice(base->type(), dim, start, end, step)),
function,
arguments);
CHECK_EQ(selected.size(), 1);
return selected.front();
}
torch::lazy::TorchMlirLoweringContext* loctx_;
std::shared_ptr<torch::jit::GraphFunction> function_;
};
std::unique_ptr<TorchMlirNodeLoweringInterface>
TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
return std::make_unique<TorchMlirNodeLowering>(
"TorchMlirNodeLowering",
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
}
// Node Lowerings
// Default Node Lowering
TorchMlirOpVector TorchMlirNode::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
for (const torch::lazy::Output& output : operands()) {
arguments.emplace_back(loctx->GetOutputOp(output));
}
return LowerBuiltin(this, function, arguments);
}
// TorchMlir specific nodes
// Non-native nodes
TorchMlirOpVector
Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(dtype);
return LowerBuiltin(at::aten::to, shapes(), function, arguments);
}
TorchMlirOpVector DeviceData::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
auto infoptr = data_->info();
auto deviceDataInfoPtr =
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
if (GRAPH_DUMP_ENABLED) {
LOG(ERROR) << "Lowering device data node, tensor id "
<< deviceDataInfoPtr->tensor_id << std::endl;
}
return {loctx->GetParameter(data_)};
}
TorchMlirOpVector Expand::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(size);
auto expand_out = LowerBuiltin(this, function, arguments);
if (is_scalar_expand) {
// The aten::expand operations sets all strides to 0 when the original is
// of rank 0. This leads to false positives when checking for internal
// memory overlap, because at::has_internal_overlap returns
// MemOverlap::YES when a stride is set to 0.
CHECK_EQ(expand_out.size(), 1);
return {GenerateClone(expand_out.front(), function)};
}
return expand_out;
}
TorchMlirOpVector Scalar::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
auto options =
at::TensorOptions()
.device(torch::lazy::getBackend()->EagerFallbackDeviceType())
.dtype(shape().scalar_type());
return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
}
// View Ops
TorchMlirOpVector AsStrided::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(size);
arguments.emplace_back(stride);
arguments.emplace_back(storage_offset);
TorchMlirOpVector as_strided_out = LowerBuiltin(this, function, arguments);
CHECK_EQ(as_strided_out.size(), 1);
return {GenerateClone(as_strided_out.front(), function)};
}
TorchMlirOpVector AsStridedViewUpdate::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
torch::jit::Value* destination =
GenerateClone(loctx->GetOutputOp(operand(0)), function);
const torch::lazy::Output& input_op = operand(1);
const torch::lazy::Shape& input_shape = input_op.shape();
const auto input_dimensions = input_shape.sizes();
std::vector<torch::jit::NamedValue> dest_arguments;
dest_arguments.emplace_back(destination);
dest_arguments.emplace_back(
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
dest_arguments.emplace_back(stride);
dest_arguments.emplace_back(storage_offset);
TorchMlirOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, shapes(), function, dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
torch::jit::Value* as_strided = as_strided_out.front();
GenerateCopy(as_strided, loctx->GetOutputOp(input_op), function);
return {destination};
}
TorchMlirOpVector Diagonal::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(offset);
arguments.emplace_back(dim1);
arguments.emplace_back(dim2);
return LowerBuiltin(this, function, arguments);
}
TorchMlirOpVector DiagonalViewUpdate::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
// Since we promise the backends that we never generate any aliased
// inplace update IR, therefore we clone the target first and then
// update the clone inplace instead. Since the clone is transient,
// it will never be aliased, and therefore it's safe.
torch::jit::Value* destination =
GenerateClone(loctx->GetOutputOp(operand(0)), function);
// Replay the diagonal.
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(offset);
arguments.emplace_back(dim1);
arguments.emplace_back(dim2);
auto diag = LowerBuiltin(at::aten::diagonal, shapes(), function, arguments);
// Update the replayed diagonal view with the input.
GenerateCopy(diag.front(), loctx->GetOutputOp(operand(1)), function);
// Destination's diag view should be updated.
return {destination};
}
TorchMlirOpVector Narrow::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
const torch::lazy::Output& input = operand(0);
torch::jit::Value* base = loctx->GetOutputOp(input);
const torch::lazy::Shape& input_shape = input.shape();
CHECK_EQ(sizes.size(), base_indices.size());
CHECK_EQ(input_shape.dim(), base_indices.size());
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
int64_t start = base_indices[dim];
base = GenerateSlice(
/*base=*/base, /*dim=*/dim, /*start=*/start,
/*end=*/start + sizes[dim], /*step=*/1,
/*function=*/function);
}
return {base};
}
TorchMlirOpVector NarrowViewUpdate::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
torch::jit::Value* dest =
GenerateClone(loctx->GetOutputOp(operand(0)), function);
const torch::lazy::Output& source_argument = operand(1);
const torch::lazy::Shape& source_shape = source_argument.shape();
CHECK_EQ(source_shape.dim(), base_indices.size());
torch::jit::Value* base = dest;
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
int64_t start = base_indices[dim];
base = GenerateSlice(
/*base=*/base, /*dim=*/dim, /*start=*/start,
/*end=*/start + source_shape.size(dim), /*step=*/1,
/*function=*/function);
}
GenerateCopy(base, loctx->GetOutputOp(source_argument), function);
return {dest};
}
TorchMlirOpVector Permute::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(dims);
return LowerBuiltin(this, function, arguments);
}
TorchMlirOpVector Resize::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
for (const torch::lazy::Output& output : operands()) {
arguments.emplace_back(loctx->GetOutputOp(output));
}
return LowerBuiltin(this, function, arguments);
}
TorchMlirOpVector Select::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
int64_t step = torch::lazy::GetStride(start, end, stride);
torch::jit::Value* base = loctx->GetOutputOp(operand(0));
return {GenerateSlice(
/*base=*/base, /*dim=*/dim,
/*start=*/start, /*end=*/end,
/*step=*/step, /*function=*/function)};
}
TorchMlirOpVector SelectViewUpdate::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
torch::jit::Value* dest =
GenerateClone(loctx->GetOutputOp(operand(0)), function);
int64_t step = torch::lazy::GetStride(start, end, stride);
torch::jit::Value* selected = GenerateSlice(
/*base=*/dest, /*dim=*/dim, /*start=*/start,
/*end=*/end, /*step=*/step, /*function=*/function);
GenerateCopy(selected, loctx->GetOutputOp(operand(1)), function);
return {dest};
}
TorchMlirOpVector Squeeze::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
if (dim != -1) {
arguments.emplace_back(dim);
}
return LowerBuiltin(this, function, arguments);
}
TorchMlirOpVector Unsqueeze::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(dim);
return LowerBuiltin(this, function, arguments);
}
TorchMlirOpVector
View::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(output_size);
return LowerBuiltin(at::aten::reshape, shapes(), function, arguments);
}
} // namespace lazy
} // namespace torch

View File

@ -1,5 +1,6 @@
#pragma once
#include "../mlir_lowering_context.h"
#include "../mlir_node.h"
#include <torch/csrc/lazy/backend/backend_data.h>
@ -34,6 +35,8 @@ class TORCH_API DeviceData : public TorchMlirNode {
data_ = data;
}
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override;
static const DeviceData* Cast(const Node* node);
// To reuse IR nodes, use this method to create DeviceData nodes