2023-08-21 04:32:11 +08:00
|
|
|
//===- split.cpp ----------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "split.h"
|
|
|
|
|
|
|
|
namespace torch {
|
|
|
|
namespace lazy {
|
|
|
|
|
|
|
|
// Constructs the split-with-sizes-copy IR node.
//
// `self` is the only IR operand. `split_sizes` and `dim` are stored as node
// attributes and both are folded into the hash handed to TorchMlirNode. The
// node declares one output per entry in `split_sizes`; `shapes` is moved into
// the base node.
SplitWithSizesCopy::SplitWithSizesCopy(
    const torch::lazy::Value &self, const ::std::vector<int64_t> &split_sizes,
    const int64_t &dim, std::vector<torch::lazy::Shape> &&shapes)
    : torch::lazy::TorchMlirNode(
          SplitWithSizesCopy::ClassOpKind(),
          OpList{self},
          std::move(shapes),
          split_sizes.size(),                    // one output per chunk
          torch::lazy::MHash(split_sizes, dim)), // attributes in the hash
      split_sizes(split_sizes),
      dim(dim) {}
|
|
|
|
|
|
|
|
// Debug description: the base TorchMlirNode text, then this node's
// `split_sizes` and `dim` attributes.
std::string SplitWithSizesCopy::ToString() const {
  std::ostringstream description;
  description << torch::lazy::TorchMlirNode::ToString()
              << ", split_sizes=" << split_sizes
              << ", dim=" << dim;
  return description.str();
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
// Reports whether an existing node can stand in for one built from these
// arguments. This op never opts into reuse, so the arguments are not
// inspected and the answer is always `false`.
bool SplitWithSizesCopy::CanBeReused(
    const torch::lazy::Value &self, const ::std::vector<int64_t> &split_sizes,
    const int64_t &dim) const {
  // Intentionally unused: reuse is disabled for this node.
  static_cast<void>(self);
  static_cast<void>(split_sizes);
  static_cast<void>(dim);
  return false;
}
|
|
|
|
|
|
|
|
// Lowers this node through LowerTorchMlirBuiltin using the node's own op,
// shapes, and three arguments in this order:
//   1. the lowered `self` operand (operand 0),
//   2. `split_sizes` (named "split_sizes"),
//   3. `dim` (named "dim").
// No keyword arguments are passed.
TorchMlirOpVector
SplitWithSizesCopy::Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();

  std::vector<torch::jit::NamedValue> arguments;
  arguments.reserve(3);
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back("split_sizes", split_sizes);
  arguments.emplace_back("dim", dim);

  // Left empty on purpose: this op has no keyword arguments.
  std::vector<torch::jit::NamedValue> kwarguments;

  return torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(),
                                            arguments, kwarguments);
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
// Constructs the split-copy IR node whose split size is an IR value.
//
// Both `self` and `split_size` are IR operands. `dim` is stored as a node
// attribute. `shapes` is moved into the base node. `num_outputs` is supplied
// by the caller and becomes the node's declared output count.
SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value &self,
                                 const torch::lazy::Value &split_size,
                                 const int64_t &dim,
                                 std::vector<torch::lazy::Shape> &&shapes,
                                 const size_t num_outputs)
    : torch::lazy::TorchMlirNode(
          SplitCopyTensor::ClassOpKind(), OpList{self, split_size},
          std::move(shapes), num_outputs,
          // Fold num_outputs into the hash alongside dim. Nothing else hashed
          // here determines the output arity. This differs from
          // SplitWithSizesCopy, where the arity is split_sizes.size() and
          // split_sizes is hashed. Without num_outputs, nodes with the same
          // operands and dim but a different number of outputs could hash
          // identically.
          // NOTE(review): this assumes the split_size operand's hash need not
          // reflect its runtime value (e.g. data hashed by shape) — confirm.
          torch::lazy::MHash(dim, static_cast<int64_t>(num_outputs))),
      dim(dim) {}
|
|
|
|
|
|
|
|
// Debug description: the base TorchMlirNode text followed by this node's
// `dim` attribute.
std::string SplitCopyTensor::ToString() const {
  std::ostringstream description;
  description << torch::lazy::TorchMlirNode::ToString() << ", dim=" << dim;
  return description.str();
}
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
// Reports whether an existing node can stand in for one built from these
// arguments. Reuse is never enabled for this op, so the result is always
// `false` regardless of the arguments.
bool SplitCopyTensor::CanBeReused(const torch::lazy::Value &self,
                                  const torch::lazy::Value &split_size,
                                  const int64_t &dim) const {
  // Intentionally unused: reuse is disabled for this node.
  static_cast<void>(self);
  static_cast<void>(split_size);
  static_cast<void>(dim);
  return false;
}
|
|
|
|
|
|
|
|
// Lowers this node through LowerTorchMlirBuiltin using the node's own op,
// shapes, and three arguments in this order:
//   1. the lowered `self` operand (operand 0),
//   2. the lowered `split_size` operand (operand 1),
//   3. `dim` (named "dim").
// No keyword arguments are passed.
TorchMlirOpVector
SplitCopyTensor::Lower(TorchMlirFunction function,
                       TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();

  std::vector<torch::jit::NamedValue> arguments;
  arguments.reserve(3);
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(loctx->GetOutputOp(operand(1)));
  arguments.emplace_back("dim", dim);

  // Left empty on purpose: this op has no keyword arguments.
  std::vector<torch::jit::NamedValue> kwarguments;

  return torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(),
                                            arguments, kwarguments);
}
|
|
|
|
|
|
|
|
} // namespace lazy
|
|
|
|
} // namespace torch
|