mirror of https://github.com/llvm/torch-mlir
Refactor Node Lowering (#914)
parent
d9aee0d7a7
commit
8312fa535b
|
@ -36,8 +36,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
|
|||
const std::string& name, BackendDevice device)
|
||||
: LoweringContext(name, std::forward<BackendDevice>(device)),
|
||||
graph_(std::make_shared<torch::jit::Graph>()),
|
||||
function_(
|
||||
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
|
||||
mlir_context_(mlirContextCreate()) {
|
||||
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
|
||||
RegisterMlirDialects();
|
||||
}
|
||||
|
||||
|
@ -49,16 +50,31 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
|
|||
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
|
||||
std::forward<Util::EmissionMap>(emit_status)),
|
||||
graph_(std::make_shared<torch::jit::Graph>()),
|
||||
function_(
|
||||
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
|
||||
mlir_context_(mlirContextCreate()) {
|
||||
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
|
||||
for (auto node : post_order) {
|
||||
bool ok = lowering_->Lower(node);
|
||||
CHECK(ok) << "Failed to lower: " << *node;
|
||||
Lower(node);
|
||||
}
|
||||
|
||||
RegisterMlirDialects();
|
||||
}
|
||||
|
||||
void TorchMlirLoweringContext::Lower(const Node* node) {
|
||||
if (auto* torch_mlir_node =
|
||||
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
|
||||
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
|
||||
CHECK(!ops.empty()) << "Failed to lower: " << *node;
|
||||
CHECK_EQ(node->num_outputs(), ops.size());
|
||||
for (size_t i = 0; i < ops.size(); ++i) {
|
||||
AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
|
||||
}
|
||||
}
|
||||
|
||||
void TorchMlirLoweringContext::SetUpAlias(
|
||||
const std::vector<int64_t>& output_index, int64_t param_number,
|
||||
const std::vector<int64_t>& param_index, bool must_alias) {
|
||||
|
@ -136,8 +152,7 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
|
|||
if (it == emitted_outputs_.end()) {
|
||||
auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
|
||||
for (auto node : post_order) {
|
||||
bool ok = lowering_->Lower(node);
|
||||
TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
|
||||
Lower(node);
|
||||
}
|
||||
// At this point the output better be present, otherwise there is an issue
|
||||
// with the lowering code.
|
||||
|
|
|
@ -23,22 +23,6 @@
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class TORCH_API TorchMlirNodeLoweringInterface {
|
||||
/**
|
||||
* This interface is only needed for legacy ops, and can be removed once all
|
||||
* ops implement LtcMlirNode->lower().
|
||||
* */
|
||||
public:
|
||||
TorchMlirNodeLoweringInterface() = default;
|
||||
|
||||
virtual ~TorchMlirNodeLoweringInterface() = default;
|
||||
|
||||
virtual bool Lower(const Node* node) = 0;
|
||||
|
||||
static std::unique_ptr<TorchMlirNodeLoweringInterface>
|
||||
Create(LoweringContext* loctx);
|
||||
};
|
||||
|
||||
class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
|
||||
public:
|
||||
// Describes an input/output alias as inserted by the SetUpAlias() API.
|
||||
|
@ -61,6 +45,8 @@ public:
|
|||
c10::ArrayRef<torch::lazy::Node*> post_order,
|
||||
torch::lazy::Util::EmissionMap emit_status);
|
||||
|
||||
void Lower(const Node* node);
|
||||
|
||||
// Adds a new input/output alias.
|
||||
void SetUpAlias(
|
||||
const std::vector<int64_t>& output_index, int64_t param_number,
|
||||
|
@ -120,11 +106,11 @@ private:
|
|||
// Holds the input/output alias information populated by the SetUpAlias() API.
|
||||
InputOutputAliases input_output_aliases_;
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
std::shared_ptr<torch::jit::GraphFunction> function_;
|
||||
MlirContext mlir_context_;
|
||||
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
|
||||
std::vector<torch::jit::Value*> root_tuple_;
|
||||
OutputMap<torch::jit::Value*> emitted_outputs_;
|
||||
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
|
||||
};
|
||||
|
||||
class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
|
||||
|
|
|
@ -71,12 +71,6 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }
|
|||
|
||||
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
|
||||
|
||||
TorchMlirOpVector TorchMlirNode::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
OpKind TorchMlirTensorList::ClassOpKind() {
|
||||
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
|
||||
// import otherwise
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace torch {
|
|||
namespace lazy {
|
||||
|
||||
TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
|
||||
TorchMlirFunction function, c10::Symbol sym,
|
||||
const std::vector<c10::TypePtr> tensor_types,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments) {
|
||||
|
@ -81,7 +81,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
|
|||
}
|
||||
|
||||
TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
|
||||
TorchMlirFunction function, c10::Symbol sym,
|
||||
const c10::ArrayRef<Shape> result_shapes,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments) {
|
||||
|
@ -101,6 +101,29 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
|
|||
function, sym, tensor_types, arguments, kwarguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
const torch::lazy::Node* node, TorchMlirFunction function,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(
|
||||
function, node->op().op, node->shapes(), arguments, kwarguments);
|
||||
}
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
|
||||
TorchMlirFunction function,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(
|
||||
function, sym, result_shapes, arguments, kwarguments);
|
||||
}
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
c10::Symbol sym, const std::vector<c10::TypePtr> types,
|
||||
TorchMlirFunction function,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments);
|
||||
}
|
||||
|
||||
c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
|
||||
auto tensor_type = value_type->cast<c10::TensorType>();
|
||||
TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");
|
||||
|
@ -174,334 +197,32 @@ std::vector<torch::lazy::Shape> compute_shape_slice(
|
|||
return {Shape(scalar_type.value(), dims)};
|
||||
}
|
||||
|
||||
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
|
||||
public:
|
||||
TorchMlirNodeLowering(
|
||||
const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx)
|
||||
: loctx_(loctx), function_(
|
||||
loctx ? std::make_shared<torch::jit::GraphFunction>(
|
||||
name, loctx->graph(), nullptr)
|
||||
: nullptr) {}
|
||||
|
||||
torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; }
|
||||
|
||||
bool Lower(const torch::lazy::Node* node) override {
|
||||
if (auto* torch_mlir_node =
|
||||
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
|
||||
// First, we call the node lowering function, which exists for newly
|
||||
// codegenned or refactored nodes
|
||||
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx());
|
||||
if (ops.empty()) {
|
||||
// Then fall back to legacy lowering code, which should be gradually
|
||||
// removed
|
||||
ops = LowerNonCodegenOps(node);
|
||||
}
|
||||
if (ops.empty()) {
|
||||
return false;
|
||||
}
|
||||
CHECK_EQ(node->num_outputs(), ops.size());
|
||||
for (size_t i = 0; i < ops.size(); ++i) {
|
||||
loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
TorchMlirOpVector ops = LowerNonCodegenOps(node);
|
||||
if (!ops.empty()) {
|
||||
CHECK_EQ(node->num_outputs(), ops.size());
|
||||
for (size_t i = 0; i < ops.size(); ++i) {
|
||||
loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
|
||||
}
|
||||
|
||||
// TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops
|
||||
// to codegen we should delete this and put all the lowering logic into Node
|
||||
// classes
|
||||
TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
|
||||
|
||||
if (node->op().op == at::aten::as_strided) {
|
||||
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
|
||||
node, torch::lazy::OpKind(at::aten::as_strided)));
|
||||
}
|
||||
if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
|
||||
return LowerAsStridedViewUpdate(
|
||||
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
|
||||
node, *torch::lazy::ltc_as_strided_view_update));
|
||||
}
|
||||
if (node->op() == *torch::lazy::ltc_cast) {
|
||||
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
|
||||
node, *torch::lazy::ltc_cast));
|
||||
}
|
||||
if (node->op() == *torch::lazy::ltc_select_view_update) {
|
||||
return LowerSelectViewUpdate(
|
||||
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
|
||||
node, *torch::lazy::ltc_select_view_update));
|
||||
}
|
||||
if (node->op() == *torch::lazy::ltc_narrow_view_update) {
|
||||
return LowerNarrowViewUpdate(
|
||||
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
|
||||
node, *torch::lazy::ltc_narrow_view_update));
|
||||
}
|
||||
if (node->op().op == at::prim::Constant) {
|
||||
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
|
||||
node, torch::lazy::OpKind(at::prim::Constant)));
|
||||
}
|
||||
if (node->op().op == at::aten::bernoulli) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
return LowerBuiltin(node, arguments);
|
||||
}
|
||||
if (node->op().op == at::aten::expand) {
|
||||
return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
|
||||
node, torch::lazy::OpKind(at::aten::expand)));
|
||||
}
|
||||
if (node->op().op == at::aten::narrow) {
|
||||
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
|
||||
node, torch::lazy::OpKind(at::aten::narrow)));
|
||||
}
|
||||
if (node->op().op == at::aten::permute) {
|
||||
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
|
||||
node, torch::lazy::OpKind(at::aten::permute)));
|
||||
}
|
||||
if (node->op().op == at::aten::select) {
|
||||
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
|
||||
node, torch::lazy::OpKind(at::aten::select)));
|
||||
}
|
||||
if (node->op().op == at::aten::squeeze) {
|
||||
return LowerSqueeze(torch::lazy::NodeCast<torch::lazy::Squeeze>(
|
||||
node, torch::lazy::OpKind(at::aten::squeeze)));
|
||||
}
|
||||
if (node->op().op == at::aten::unsqueeze) {
|
||||
return LowerUnsqueeze(torch::lazy::NodeCast<torch::lazy::Unsqueeze>(
|
||||
node, torch::lazy::OpKind(at::aten::unsqueeze)));
|
||||
}
|
||||
if (node->op().op == at::aten::view) {
|
||||
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
|
||||
node, torch::lazy::OpKind(at::aten::view)));
|
||||
}
|
||||
if (node->op() == *torch::lazy::ltc_device_data) {
|
||||
const torch::lazy::DeviceData* device_data_node =
|
||||
torch::lazy::NodeCast<torch::lazy::DeviceData>(
|
||||
node, *torch::lazy::ltc_device_data);
|
||||
auto infoptr = device_data_node->data()->info();
|
||||
auto deviceDataInfoPtr =
|
||||
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
|
||||
if (GRAPH_DUMP_ENABLED) {
|
||||
LOG(ERROR) << "Lowering device data node, tensor id "
|
||||
<< deviceDataInfoPtr->tensor_id << std::endl;
|
||||
}
|
||||
return {loctx()->GetParameter(device_data_node->data())};
|
||||
}
|
||||
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
for (const torch::lazy::Output& output : node->operands()) {
|
||||
arguments.emplace_back(loctx()->GetOutputOp(output));
|
||||
}
|
||||
return LowerBuiltin(node, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
const torch::lazy::Node* node,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(
|
||||
function_, node->op().op, node->shapes(), arguments, kwarguments);
|
||||
}
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(
|
||||
function_, sym, result_shapes, arguments, kwarguments);
|
||||
}
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
c10::Symbol sym, const std::vector<c10::TypePtr> types,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.emplace_back(node->size);
|
||||
arguments.emplace_back(node->stride);
|
||||
arguments.emplace_back(node->storage_offset);
|
||||
TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
return {GenerateClone(as_strided_out.front())};
|
||||
}
|
||||
|
||||
TorchMlirOpVector
|
||||
LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) {
|
||||
torch::jit::Value* destination =
|
||||
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
|
||||
const torch::lazy::Output& input_op = node->operand(1);
|
||||
const torch::lazy::Shape& input_shape = input_op.shape();
|
||||
const auto input_dimensions = input_shape.sizes();
|
||||
std::vector<torch::jit::NamedValue> dest_arguments;
|
||||
dest_arguments.emplace_back(destination);
|
||||
dest_arguments.emplace_back(
|
||||
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
|
||||
dest_arguments.emplace_back(node->stride);
|
||||
dest_arguments.emplace_back(node->storage_offset);
|
||||
TorchMlirOpVector as_strided_out =
|
||||
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
torch::jit::Value* as_strided = as_strided_out.front();
|
||||
GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
|
||||
return {destination};
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.emplace_back(node->dtype);
|
||||
return LowerBuiltin(at::aten::to, node->shapes(), arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.emplace_back(node->size);
|
||||
auto expand_out = LowerBuiltin(node, arguments);
|
||||
if (node->is_scalar_expand) {
|
||||
// The aten::expand operations sets all strides to 0 when the original
|
||||
// of rank 0. This leads to false positives when checking for internal
|
||||
// memory overlap, because at::has_internal_overlap returns
|
||||
// MemOverlap::YES when a stride is set to 0.
|
||||
CHECK_EQ(expand_out.size(), 1);
|
||||
return {GenerateClone(expand_out.front())};
|
||||
}
|
||||
return expand_out;
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
|
||||
const torch::lazy::Output& input = node->operand(0);
|
||||
torch::jit::Value* base = loctx()->GetOutputOp(input);
|
||||
const auto& base_indices = node->base_indices;
|
||||
const auto& sizes = node->sizes;
|
||||
const torch::lazy::Shape& input_shape = input.shape();
|
||||
CHECK_EQ(sizes.size(), base_indices.size());
|
||||
CHECK_EQ(input_shape.dim(), base_indices.size());
|
||||
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
|
||||
int64_t start = base_indices[dim];
|
||||
base = GenerateSlice(
|
||||
/*base=*/base, /*dim=*/dim, /*start=*/start,
|
||||
/*end=*/start + sizes[dim], /*step=*/1);
|
||||
}
|
||||
return {base};
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.push_back(node->dims);
|
||||
return LowerBuiltin(node, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
|
||||
const at::Scalar& value = node->value;
|
||||
const torch::lazy::Shape& shape = node->shape();
|
||||
auto options =
|
||||
at::TensorOptions()
|
||||
.device(torch::lazy::getBackend()->EagerFallbackDeviceType())
|
||||
.dtype(shape.scalar_type());
|
||||
return {
|
||||
loctx()->graph()->insertConstant(at::scalar_tensor(value, options))};
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
|
||||
int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
|
||||
torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
|
||||
return {GenerateSlice(
|
||||
/*base=*/base, /*dim=*/node->dim,
|
||||
/*start=*/node->start, /*end=*/node->end,
|
||||
/*step=*/step)};
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
if (node->dim != -1) {
|
||||
arguments.push_back(node->dim);
|
||||
}
|
||||
return LowerBuiltin(node, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector
|
||||
LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
|
||||
torch::jit::Value* dest =
|
||||
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
|
||||
int64_t step = torch::lazy::GetStride(node->start, node->end, node->stride);
|
||||
torch::jit::Value* selected = GenerateSlice(
|
||||
/*base=*/dest, /*dim=*/node->dim, /*start=*/node->start,
|
||||
/*end=*/node->end, /*step=*/step);
|
||||
GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
|
||||
return {dest};
|
||||
}
|
||||
|
||||
TorchMlirOpVector
|
||||
LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
|
||||
torch::jit::Value* dest =
|
||||
GenerateClone(loctx()->GetOutputOp(node->operand(0)));
|
||||
const auto& base_indices = node->base_indices;
|
||||
const torch::lazy::Output& source_argument = node->operand(1);
|
||||
const torch::lazy::Shape& source_shape = source_argument.shape();
|
||||
CHECK_EQ(source_shape.dim(), base_indices.size());
|
||||
torch::jit::Value* base = dest;
|
||||
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
|
||||
int64_t start = base_indices[dim];
|
||||
base = GenerateSlice(
|
||||
/*base=*/base, /*dim=*/dim, /*start=*/start,
|
||||
/*end=*/start + source_shape.size(dim),
|
||||
/*step=*/1);
|
||||
}
|
||||
GenerateCopy(base, loctx()->GetOutputOp(source_argument));
|
||||
return {dest};
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.push_back(node->dim);
|
||||
return LowerBuiltin(node, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerView(const torch::lazy::View* node) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.push_back(node->output_size);
|
||||
return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
|
||||
}
|
||||
|
||||
torch::jit::Value* GenerateClone(torch::jit::Value* val) {
|
||||
torch::jit::Value*
|
||||
GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
|
||||
std::vector<torch::jit::NamedValue> clone_arguments;
|
||||
clone_arguments.emplace_back(val);
|
||||
|
||||
// Type of cloned value should be identical to the original one.
|
||||
TorchMlirOpVector cloned =
|
||||
LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
|
||||
LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments);
|
||||
CHECK_EQ(cloned.size(), 1);
|
||||
return cloned.front();
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) {
|
||||
|
||||
void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source, TorchMlirFunction function) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(destination);
|
||||
arguments.emplace_back(source);
|
||||
LowerBuiltin(
|
||||
at::aten::copy_,
|
||||
c10::ArrayRef<Shape>(compute_shape_copy(source->type())), arguments);
|
||||
}
|
||||
c10::ArrayRef<Shape>(compute_shape_copy(source->type())), function, arguments);
|
||||
}
|
||||
|
||||
torch::jit::Value* GenerateSlice(
|
||||
|
||||
torch::jit::Value* GenerateSlice(
|
||||
torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
|
||||
int64_t step) {
|
||||
int64_t step, TorchMlirFunction function) {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(base);
|
||||
arguments.emplace_back(dim);
|
||||
|
@ -513,19 +234,247 @@ public:
|
|||
at::aten::slice,
|
||||
c10::ArrayRef<Shape>(
|
||||
compute_shape_slice(base->type(), dim, start, end, step)),
|
||||
function,
|
||||
arguments);
|
||||
CHECK_EQ(selected.size(), 1);
|
||||
return selected.front();
|
||||
}
|
||||
torch::lazy::TorchMlirLoweringContext* loctx_;
|
||||
std::shared_ptr<torch::jit::GraphFunction> function_;
|
||||
};
|
||||
|
||||
std::unique_ptr<TorchMlirNodeLoweringInterface>
|
||||
TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
|
||||
return std::make_unique<TorchMlirNodeLowering>(
|
||||
"TorchMlirNodeLowering",
|
||||
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
|
||||
}
|
||||
|
||||
// Node Lowerings
|
||||
|
||||
// Default Node Lowering
|
||||
TorchMlirOpVector TorchMlirNode::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
for (const torch::lazy::Output& output : operands()) {
|
||||
arguments.emplace_back(loctx->GetOutputOp(output));
|
||||
}
|
||||
return LowerBuiltin(this, function, arguments);
|
||||
}
|
||||
|
||||
// TorchMlir specific nodes
|
||||
|
||||
// Non-native nodes
|
||||
|
||||
TorchMlirOpVector
|
||||
Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(dtype);
|
||||
return LowerBuiltin(at::aten::to, shapes(), function, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector DeviceData::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
auto infoptr = data_->info();
|
||||
auto deviceDataInfoPtr =
|
||||
(torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
|
||||
if (GRAPH_DUMP_ENABLED) {
|
||||
LOG(ERROR) << "Lowering device data node, tensor id "
|
||||
<< deviceDataInfoPtr->tensor_id << std::endl;
|
||||
}
|
||||
return {loctx->GetParameter(data_)};
|
||||
}
|
||||
|
||||
TorchMlirOpVector Expand::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(size);
|
||||
auto expand_out = LowerBuiltin(this, function, arguments);
|
||||
if (is_scalar_expand) {
|
||||
// The aten::expand operations sets all strides to 0 when the original is
|
||||
// of rank 0. This leads to false positives when checking for internal
|
||||
// memory overlap, because at::has_internal_overlap returns
|
||||
// MemOverlap::YES when a stride is set to 0.
|
||||
CHECK_EQ(expand_out.size(), 1);
|
||||
return {GenerateClone(expand_out.front(), function)};
|
||||
}
|
||||
return expand_out;
|
||||
}
|
||||
|
||||
TorchMlirOpVector Scalar::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
auto options =
|
||||
at::TensorOptions()
|
||||
.device(torch::lazy::getBackend()->EagerFallbackDeviceType())
|
||||
.dtype(shape().scalar_type());
|
||||
return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
|
||||
}
|
||||
|
||||
// View Ops
|
||||
|
||||
TorchMlirOpVector AsStrided::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(size);
|
||||
arguments.emplace_back(stride);
|
||||
arguments.emplace_back(storage_offset);
|
||||
TorchMlirOpVector as_strided_out = LowerBuiltin(this, function, arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
return {GenerateClone(as_strided_out.front(), function)};
|
||||
}
|
||||
|
||||
TorchMlirOpVector AsStridedViewUpdate::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
|
||||
torch::jit::Value* destination =
|
||||
GenerateClone(loctx->GetOutputOp(operand(0)), function);
|
||||
const torch::lazy::Output& input_op = operand(1);
|
||||
const torch::lazy::Shape& input_shape = input_op.shape();
|
||||
const auto input_dimensions = input_shape.sizes();
|
||||
std::vector<torch::jit::NamedValue> dest_arguments;
|
||||
dest_arguments.emplace_back(destination);
|
||||
dest_arguments.emplace_back(
|
||||
std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
|
||||
dest_arguments.emplace_back(stride);
|
||||
dest_arguments.emplace_back(storage_offset);
|
||||
TorchMlirOpVector as_strided_out =
|
||||
LowerBuiltin(at::aten::as_strided, shapes(), function, dest_arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
torch::jit::Value* as_strided = as_strided_out.front();
|
||||
GenerateCopy(as_strided, loctx->GetOutputOp(input_op), function);
|
||||
return {destination};
|
||||
}
|
||||
|
||||
TorchMlirOpVector Diagonal::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(offset);
|
||||
arguments.emplace_back(dim1);
|
||||
arguments.emplace_back(dim2);
|
||||
return LowerBuiltin(this, function, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector DiagonalViewUpdate::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
// Since we promise the backends that we never generate any aliased
|
||||
// inplace update IR, therefore we clone the target first and then
|
||||
// update the clone inplace instead. Since the clone is transient,
|
||||
// it will never be aliased, and therefore it's safe.
|
||||
torch::jit::Value* destination =
|
||||
GenerateClone(loctx->GetOutputOp(operand(0)), function);
|
||||
|
||||
// Replay the diagonal.
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(destination);
|
||||
arguments.emplace_back(offset);
|
||||
arguments.emplace_back(dim1);
|
||||
arguments.emplace_back(dim2);
|
||||
auto diag = LowerBuiltin(at::aten::diagonal, shapes(), function, arguments);
|
||||
|
||||
// Update the replayed diagonal view with the input.
|
||||
GenerateCopy(diag.front(), loctx->GetOutputOp(operand(1)), function);
|
||||
|
||||
// Destination's diag view should be updated.
|
||||
return {destination};
|
||||
}
|
||||
|
||||
TorchMlirOpVector Narrow::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
const torch::lazy::Output& input = operand(0);
|
||||
torch::jit::Value* base = loctx->GetOutputOp(input);
|
||||
const torch::lazy::Shape& input_shape = input.shape();
|
||||
CHECK_EQ(sizes.size(), base_indices.size());
|
||||
CHECK_EQ(input_shape.dim(), base_indices.size());
|
||||
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
|
||||
int64_t start = base_indices[dim];
|
||||
base = GenerateSlice(
|
||||
/*base=*/base, /*dim=*/dim, /*start=*/start,
|
||||
/*end=*/start + sizes[dim], /*step=*/1,
|
||||
/*function=*/function);
|
||||
}
|
||||
return {base};
|
||||
}
|
||||
|
||||
TorchMlirOpVector NarrowViewUpdate::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
torch::jit::Value* dest =
|
||||
GenerateClone(loctx->GetOutputOp(operand(0)), function);
|
||||
const torch::lazy::Output& source_argument = operand(1);
|
||||
const torch::lazy::Shape& source_shape = source_argument.shape();
|
||||
CHECK_EQ(source_shape.dim(), base_indices.size());
|
||||
torch::jit::Value* base = dest;
|
||||
for (size_t dim = 0; dim < base_indices.size(); ++dim) {
|
||||
int64_t start = base_indices[dim];
|
||||
base = GenerateSlice(
|
||||
/*base=*/base, /*dim=*/dim, /*start=*/start,
|
||||
/*end=*/start + source_shape.size(dim), /*step=*/1,
|
||||
/*function=*/function);
|
||||
}
|
||||
GenerateCopy(base, loctx->GetOutputOp(source_argument), function);
|
||||
return {dest};
|
||||
}
|
||||
|
||||
TorchMlirOpVector Permute::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(dims);
|
||||
return LowerBuiltin(this, function, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector Resize::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
for (const torch::lazy::Output& output : operands()) {
|
||||
arguments.emplace_back(loctx->GetOutputOp(output));
|
||||
}
|
||||
return LowerBuiltin(this, function, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector Select::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
int64_t step = torch::lazy::GetStride(start, end, stride);
|
||||
torch::jit::Value* base = loctx->GetOutputOp(operand(0));
|
||||
return {GenerateSlice(
|
||||
/*base=*/base, /*dim=*/dim,
|
||||
/*start=*/start, /*end=*/end,
|
||||
/*step=*/step, /*function=*/function)};
|
||||
}
|
||||
|
||||
TorchMlirOpVector SelectViewUpdate::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
torch::jit::Value* dest =
|
||||
GenerateClone(loctx->GetOutputOp(operand(0)), function);
|
||||
int64_t step = torch::lazy::GetStride(start, end, stride);
|
||||
torch::jit::Value* selected = GenerateSlice(
|
||||
/*base=*/dest, /*dim=*/dim, /*start=*/start,
|
||||
/*end=*/end, /*step=*/step, /*function=*/function);
|
||||
GenerateCopy(selected, loctx->GetOutputOp(operand(1)), function);
|
||||
return {dest};
|
||||
}
|
||||
|
||||
TorchMlirOpVector Squeeze::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
if (dim != -1) {
|
||||
arguments.emplace_back(dim);
|
||||
}
|
||||
return LowerBuiltin(this, function, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector Unsqueeze::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(dim);
|
||||
return LowerBuiltin(this, function, arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector
|
||||
View::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
|
||||
arguments.emplace_back(output_size);
|
||||
return LowerBuiltin(at::aten::reshape, shapes(), function, arguments);
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "../mlir_lowering_context.h"
|
||||
#include "../mlir_node.h"
|
||||
|
||||
#include <torch/csrc/lazy/backend/backend_data.h>
|
||||
|
@ -34,6 +35,8 @@ class TORCH_API DeviceData : public TorchMlirNode {
|
|||
data_ = data;
|
||||
}
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
static const DeviceData* Cast(const Node* node);
|
||||
|
||||
// To reuse IR nodes, use this method to create DeviceData nodes
|
||||
|
|
Loading…
Reference in New Issue