//===- mlir_node.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.cpp
//===----------------------------------------------------------------------===//

#include "mlir_node.h"

#include "utils/exception.h"

namespace torch {
namespace lazy {

namespace {
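
// Folds the hashes of `operands` and `shapes` into a single value, starting
// from `seed`. When `bakeInSizes` is true, concrete sizes are baked into the
// hash (as needed for the shape hash); otherwise size-agnostic hashes are
// used (as needed for the DAG hash when dynamic shapes are enabled).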
hash_t OperandHashes(const OpList &operands, const c10::ArrayRef<Shape> &shapes,
                     const hash_t &seed, bool bakeInSizes) {
  hash_t hash = seed;
  for (auto &operand : operands) {
    if (!operand) {
      hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
      continue;
    }
    auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
    hash = HashCombine(hash, operand_hash);
  }
  for (auto &shape : shapes) {
    hash = HashCombine(hash, shape.hash(bakeInSizes));
  }
  return hash;
}

} // namespace

// Adds a static hook that is run after every single TorchMlirNode is
// initialized
static std::vector<std::function<void(TorchMlirNode *)>> constructor_hooks;
void TorchMlirNode::addConstructorHook(std::function<void(TorchMlirNode *)> f) {
  constructor_hooks.emplace_back(f);
}
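
// A minimal sketch of registering a constructor hook (caller-side code, not
// part of this file; the lambda body is only an illustrative assumption):
//
//   TorchMlirNode::addConstructorHook([](TorchMlirNode *node) {
//     // e.g. inspect or record every node as it is constructed
//     std::cout << node->op().ToString() << std::endl;
//   });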

TorchMlirNode::TorchMlirNode(OpKind op, OpList operands,
                             std::vector<Shape> &&shapes, size_t num_outputs,
                             hash_t hash_seed)
    : Node(op, operands, std::move(shapes), num_outputs) {
  hash_seed = HashCombine(op.hash(), hash_seed);
  // The shape hash bakes concrete sizes into the hash; the DAG hash only does
  // so when dynamic shapes are disabled.
  shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
  dag_hash_ = (enableDynamicShape()
                   ? OperandHashes(operands, this->shapes(), hash_seed, false)
                   : shape_hash_);

  for (std::function<void(TorchMlirNode *)> &f : constructor_hooks) {
    f(this);
  }
}

TorchMlirNode::TorchMlirNode(OpKind op, OpList operands,
                             const std::function<Shape()> &shape_fn,
                             size_t num_outputs, hash_t hash_seed)
    : TorchMlirNode(op, operands, std::vector<Shape>{}, num_outputs,
                    hash_seed) {
  addComputedShape(shape_fn);
}

TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, size_t num_outputs,
                             hash_t hash_seed)
    : TorchMlirNode(op, operands, std::vector<Shape>{}, num_outputs,
                    hash_seed) {}

TorchMlirNode::TorchMlirNode(OpKind op, Shape shape, size_t num_outputs,
                             hash_t hash_seed)
    : TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}

hash_t TorchMlirNode::hash() const { return dag_hash_; }

hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }

// Returns operand `index` downcast to a TorchMlirNode, or nullptr if that
// operand is not a TorchMlirNode.
TorchMlirNode *TorchMlirNode::mlir_node(int index) const {
  return dynamic_cast<TorchMlirNode *>(operands_.at(index).get());
}

///////////////////////////////////////////////////////////////////////////////
// TorchMlirTensorList
///////////////////////////////////////////////////////////////////////////////
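//
// A TorchMlirTensorList bundles several tensor-valued operands into a single
// IR value with one output; Lower() below materializes it as one TorchScript
// list built from the lowered operands.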

OpKind TorchMlirTensorList::ClassOpKind() {
  // Note: this OpKind is separate from ltc_ops.h since it would be a circular
  // import otherwise
  static const OpKind tensor_list_opkind =
      OpKind::Get("lazy_tensors::tensor_list");
  return tensor_list_opkind;
}

TorchMlirTensorList::TorchMlirTensorList(OpList values)
    : TorchMlirNode(
          /*op=*/TorchMlirTensorList::ClassOpKind(),
          /*operands=*/values,
          /*shapes=*/std::vector<Shape>(),
          /*num_outputs=*/1,
          /*hash_seed=*/kHashSeed) {}
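
// Lowers the list node by collecting the already-lowered operand values and
// inserting a single TorchScript list node that aggregates them into the
// graph owned by `function`.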
torch::lazy::TorchMlirOpVector
TorchMlirTensorList::Lower(TorchMlirFunction function,
                           TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::Value *> tensor_list;
  CHECK(!operands().empty());
  for (const torch::lazy::Output &operand : operands()) {
    tensor_list.emplace_back(loctx->GetOutputOp(operand));
  }
  auto graph = function->graph();
  auto listnode =
      graph->insertNode(graph->createList(c10::TensorType::get(), tensor_list));
  return {listnode->output()};
}

///////////////////////////////////////////////////////////////////////////////
// TorchMlirOptionalTensorList
///////////////////////////////////////////////////////////////////////////////

OpKind TorchMlirOptionalTensorList::ClassOpKind() {
  // Note: this OpKind is separate from ltc_ops.h since it would be a circular
  // import otherwise
  static const OpKind tensor_list_opkind =
      OpKind::Get("lazy_tensors::optional_tensor_list");
  return tensor_list_opkind;
}

TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values)
    : TorchMlirNode(
          /*op=*/TorchMlirOptionalTensorList::ClassOpKind(),
          /*operands=*/values,
          /*shapes=*/std::vector<Shape>(),
          /*num_outputs=*/1,
          /*hash_seed=*/kHashSeed) {}
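
// Same as TorchMlirTensorList::Lower, except that the emitted TorchScript
// list holds optional tensors rather than tensors.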
torch::lazy::TorchMlirOpVector
TorchMlirOptionalTensorList::Lower(TorchMlirFunction function,
                                   TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::Value *> tensor_list;
  CHECK(!operands().empty());
  for (const torch::lazy::Output &operand : operands()) {
    tensor_list.emplace_back(loctx->GetOutputOp(operand));
  }
  auto graph = function->graph();
  auto listnode = graph->insertNode(graph->createList(
      c10::OptionalType::create(c10::TensorType::get()), tensor_list));
  return {listnode->output()};
}

} // namespace lazy
} // namespace torch