2023-08-30 18:29:39 +08:00
|
|
|
//===- index.cpp ----------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "index.h"
|
|
|
|
|
|
|
|
namespace torch {
|
|
|
|
namespace lazy {
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
// Constructs the IndexTensor lazy IR node.
//
// The node has two operands, `self` and `indices` (stored in that order), and
// produces exactly one output whose shape(s) are taken from `shapes`. The node
// has no scalar attributes, so an empty MHash() is used as its hash seed.
IndexTensor::IndexTensor(const torch::lazy::Value &self,
                         const torch::lazy::Value &indices,
                         std::vector<torch::lazy::Shape> &&shapes)
    : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(),
                                 OpList{self, indices},
                                 std::move(shapes),
                                 /* num_outputs */ 1,
                                 torch::lazy::MHash()) {}
|
|
|
|
|
|
|
|
// Returns the textual form of this node used in IR dumps.
//
// IndexTensor carries no attributes of its own, so its text is exactly the
// base TorchMlirNode rendering.
std::string IndexTensor::ToString() const {
  return torch::lazy::TorchMlirNode::ToString();
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
// Node-reuse predicate for IndexTensor.
//
// This implementation never reports a match, whatever the candidate operands
// are, so node reuse is disabled for this op.
// NOTE(review): presumably this forces a new node to be built on every call;
// confirm against the lazy-tensor reuse machinery.
bool IndexTensor::CanBeReused(const torch::lazy::Value &self,
                              const torch::lazy::Value &indices) const {
  // The operands are intentionally ignored; reuse is unconditionally off.
  static_cast<void>(self);
  static_cast<void>(indices);
  return false;
}
|
|
|
|
|
|
|
|
// Lowers this node into `function` as a builtin call for op().op.
//
// Both operands are passed positionally, in construction order
// (0 = self, 1 = indices), with no keyword arguments. The lowering must
// produce exactly one result; otherwise TORCH_CHECK_EQ fails.
TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function,
                                     TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();

  constexpr size_t kNumOperands = 2;
  std::vector<torch::jit::NamedValue> arguments;
  arguments.reserve(kNumOperands);
  for (size_t idx = 0; idx < kNumOperands; ++idx) {
    arguments.emplace_back(loctx->GetOutputOp(operand(idx)));
  }

  // This op has no keyword arguments.
  std::vector<torch::jit::NamedValue> kwarguments;

  torch::lazy::TorchMlirOpVector lowered = torch::lazy::LowerTorchMlirBuiltin(
      function, op().op, shapes(), arguments, kwarguments);
  TORCH_CHECK_EQ(lowered.size(), 1);
  return lowered;
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
IndexPut::IndexPut(const torch::lazy::Value &self,
|
|
|
|
const torch::lazy::Value &indices,
|
|
|
|
const torch::lazy::Value &values, bool accumulate,
|
|
|
|
std::vector<torch::lazy::Shape> &&shapes)
|
2023-08-30 18:29:39 +08:00
|
|
|
: torch::lazy::TorchMlirNode(
|
|
|
|
IndexPut::ClassOpKind(), OpList{self, indices, values},
|
|
|
|
std::move(shapes),
|
|
|
|
/* num_outputs */ 1, torch::lazy::MHash(accumulate)),
|
|
|
|
accumulate(accumulate) {}
|
|
|
|
|
|
|
|
// Returns the textual form of this node used in IR dumps.
//
// The text is the base TorchMlirNode rendering followed by the
// `accumulate` flag. The flag is streamed as a plain bool, so it prints as
// 0 or 1 rather than false/true.
std::string IndexPut::ToString() const {
  std::ostringstream text;
  text << torch::lazy::TorchMlirNode::ToString() << ", accumulate="
       << accumulate;
  return text.str();
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
// Node-reuse predicate for IndexPut.
//
// This implementation never reports a match, whatever the operands or the
// `accumulate` flag are, so node reuse is disabled for this op.
// NOTE(review): presumably this forces a new node to be built on every call;
// confirm against the lazy-tensor reuse machinery.
bool IndexPut::CanBeReused(const torch::lazy::Value &self,
                           const torch::lazy::Value &indices,
                           const torch::lazy::Value &values,
                           bool accumulate) const {
  // All inputs are intentionally ignored; reuse is unconditionally off.
  static_cast<void>(self);
  static_cast<void>(indices);
  static_cast<void>(values);
  static_cast<void>(accumulate);
  return false;
}
|
|
|
|
|
|
|
|
// Lowers this node into `function` as a builtin call for op().op.
//
// The three tensor operands are passed positionally, in construction order
// (0 = self, 1 = indices, 2 = values). They are followed by the
// `accumulate` flag, which is appended to the positional list as a
// NamedValue called "accumulate". No keyword arguments are passed. The
// lowering must produce exactly one result; otherwise TORCH_CHECK_EQ fails.
TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function,
                                  TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();

  constexpr size_t kNumTensorOperands = 3;
  std::vector<torch::jit::NamedValue> arguments;
  // Room for the tensor operands plus the trailing `accumulate` flag.
  arguments.reserve(kNumTensorOperands + 1);
  for (size_t idx = 0; idx < kNumTensorOperands; ++idx) {
    arguments.emplace_back(loctx->GetOutputOp(operand(idx)));
  }
  arguments.emplace_back("accumulate", accumulate);

  // This op has no keyword arguments.
  std::vector<torch::jit::NamedValue> kwarguments;

  torch::lazy::TorchMlirOpVector lowered = torch::lazy::LowerTorchMlirBuiltin(
      function, op().op, shapes(), arguments, kwarguments);
  TORCH_CHECK_EQ(lowered.size(), 1);
  return lowered;
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
} // namespace lazy
|
|
|
|
} // namespace torch
|