mirror of https://github.com/llvm/torch-mlir
LTC multi-output operations support (#2362)
* LTC/TorchMLIR multi-output operations support * Update torch-mlir jit lowering to support ops with dynamic number of outputs * Added support for aten::split_copy, aten::split_with_sizes_copy * Fix native function for aten::split; cleanup code * Fix TorchMlirTensorList lowering * Remove xfailspull/2408/head
parent
aa007da5ac
commit
5743b6d4ac
|
@ -48,11 +48,6 @@ whitelist:
|
|||
# Enabled for consistency with TS backend
|
||||
- arange.start_out
|
||||
|
||||
# List of ops to autogen even if not supported by Torch-MLIR explicitly
|
||||
#- split_copy.Tensor
|
||||
#- split_with_sizes_copy
|
||||
#- unbind_copy.int
|
||||
|
||||
# List of supported ops that we don't want to do the full codegen for
|
||||
supported:
|
||||
# - bernoulli
|
||||
|
@ -63,6 +58,9 @@ supported:
|
|||
- empty_strided
|
||||
- fill_.Scalar
|
||||
- _unsafe_view
|
||||
- unbind_copy.int
|
||||
- split_copy.Tensor
|
||||
- split_with_sizes_copy
|
||||
|
||||
# ops required for functionalization
|
||||
- lift
|
||||
|
@ -92,11 +90,13 @@ symint:
|
|||
- narrow_copy
|
||||
- slice_backward
|
||||
- slice_copy.Tensor
|
||||
- split_copy.Tensor
|
||||
- slice_scatter
|
||||
- view
|
||||
- view_copy
|
||||
- as_strided_copy
|
||||
- as_strided_scatter
|
||||
- split_with_sizes_copy
|
||||
|
||||
|
||||
additional_ops:
|
||||
|
|
|
@ -1363,8 +1363,6 @@ LTC_XFAIL_SET = {
|
|||
"SqrtIntModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
|
@ -1372,9 +1370,6 @@ LTC_XFAIL_SET = {
|
|||
"TensorToFloat_basic",
|
||||
"TensorToIntZeroRank_basic",
|
||||
"TensorToInt_basic",
|
||||
"TensorsConcatModule_basic",
|
||||
"TensorsConcatStaticModule_basic",
|
||||
"TensorsConcatNegativeDimStaticModule_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformNoCorrelationModule_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
|
@ -1440,16 +1435,6 @@ LTC_XFAIL_SET = {
|
|||
"AtenComplexImagModule_basic",
|
||||
"AtenComplexRealModule_basic",
|
||||
"AtenComplexViewModule_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
"ScatterValueFloatModule_basic",
|
||||
"ScatterValueIntModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
|
|
|
@ -3742,6 +3742,80 @@ def Torch_AtenViewAsComplexOp : Torch_Op<"aten.view_as_complex", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::unbind_copy.int : (Tensor, int) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenUnbindCopyIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenUnbindCopyIntOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSplitCopyTensorOp : Torch_Op<"aten.split_copy.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::split_copy.Tensor : (Tensor, int, int) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$split_size,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSplitCopyTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenSplitCopyTensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSplitWithSizesCopyOp : Torch_Op<"aten.split_with_sizes_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$split_sizes,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSplitWithSizesCopyOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenSplitWithSizesCopyOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUniformOp : Torch_Op<"aten.uniform", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -71,6 +71,8 @@ add_library(torch_mlir_ltc_backend SHARED
|
|||
mlir_node.cpp
|
||||
ops/device_data.cpp
|
||||
ops/generic.cpp
|
||||
ops/split.cpp
|
||||
ops/unbind_int.cpp
|
||||
utils/jit_utils.cpp
|
||||
utils/tensor_utils.cpp
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/MetaFunctions.h>
|
||||
|
@ -31,9 +32,47 @@
|
|||
#include "generated/LazyNativeFunctions.h"
|
||||
#include "generated/shape_inference.h"
|
||||
#include "ops/to_copy.h"
|
||||
#include "ops/unbind_int.h"
|
||||
#include "ops/split.h"
|
||||
#include "utils/exception.h"
|
||||
#include "utils/sys_utils.h"
|
||||
|
||||
namespace {
|
||||
at::Tensor to_meta(const at::Tensor& tensor) {
|
||||
// undefined tensors can't be converted to the meta device, since they don't
|
||||
// have sizes/strides
|
||||
if (!tensor.defined())
|
||||
return tensor;
|
||||
auto out = at::native::empty_strided_meta_symint(
|
||||
tensor.sym_sizes(), tensor.sym_strides(),
|
||||
/*dtype=*/c10::make_optional(tensor.scalar_type()),
|
||||
/*layout=*/c10::make_optional(tensor.layout()),
|
||||
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
|
||||
/*pin_memory=*/c10::nullopt);
|
||||
// needs to handle wrapped numbers, so dtype promotion works properly.
|
||||
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
out.unsafeGetTensorImpl()->set_wrapped_number(true);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
|
||||
if (tensor.has_value()) {
|
||||
return to_meta(*tensor);
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
|
||||
std::vector<at::Tensor> outs;
|
||||
outs.reserve(t_list.size());
|
||||
for (const auto& tensor : t_list) {
|
||||
outs.push_back(to_meta(tensor));
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
|
@ -359,6 +398,125 @@ at::Tensor LazyNativeFunctions::_unsafe_view(
|
|||
return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size));
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||
TORCH_INTERNAL_ASSERT(common_device);
|
||||
|
||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim);
|
||||
if (!node) {
|
||||
auto self_meta = to_meta(self);
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim);
|
||||
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
for (const auto & shape : out_meta) {
|
||||
shapes.push_back(
|
||||
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
|
||||
);
|
||||
}
|
||||
|
||||
if(torch::lazy::symbolicShapeEnabled()){
|
||||
std::vector<torch::jit::IValue> inputs = { self, dim };
|
||||
const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]";
|
||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||
}
|
||||
|
||||
node = torch::lazy::MakeNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim, std::move(shapes));
|
||||
CacheNode(node);
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> result;
|
||||
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
||||
result.push_back(
|
||||
torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||
TORCH_INTERNAL_ASSERT(common_device);
|
||||
|
||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitWithSizesCopy>(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim);
|
||||
if (!node) {
|
||||
auto self_meta = to_meta(self);
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim);
|
||||
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
for (const auto & shape : out_meta) {
|
||||
shapes.push_back(
|
||||
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
|
||||
);
|
||||
}
|
||||
|
||||
if(torch::lazy::symbolicShapeEnabled()){
|
||||
std::vector<torch::jit::IValue> inputs = { self, split_sizes, dim };
|
||||
const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]";
|
||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||
}
|
||||
|
||||
node = torch::lazy::MakeNode<SplitWithSizesCopy>(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes));
|
||||
CacheNode(node);
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> result;
|
||||
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
||||
result.push_back(
|
||||
torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> LazyNativeFunctions::split_copy_symint(const at::Tensor & self, c10::SymInt split_size, int64_t dim) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||
TORCH_INTERNAL_ASSERT(common_device);
|
||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitCopyTensor>(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim);
|
||||
if (!node) {
|
||||
auto self_meta = to_meta(self);
|
||||
auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim);
|
||||
|
||||
std::vector<torch::lazy::Shape> shapes;
|
||||
for (const auto & shape : out_meta) {
|
||||
shapes.push_back(
|
||||
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
|
||||
);
|
||||
}
|
||||
const size_t num_outputs = shapes.size();
|
||||
|
||||
if(torch::lazy::symbolicShapeEnabled()){
|
||||
std::vector<torch::jit::IValue> inputs = { self, split_size, dim };
|
||||
const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]";
|
||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||
}
|
||||
|
||||
node = torch::lazy::MakeNode<SplitCopyTensor>(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs);
|
||||
CacheNode(node);
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> result;
|
||||
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
||||
result.push_back(
|
||||
torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
|
||||
)
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// This is needed by the torch.tensor constructor.
|
||||
// LazyTensor always opts into functionalization.
|
||||
// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.
|
||||
|
|
|
@ -116,7 +116,7 @@ torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
|
|||
}
|
||||
auto graph = function->graph();
|
||||
auto listnode =
|
||||
graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
|
||||
graph->insertNode(graph->createList(c10::TensorType::get(), tensor_list));
|
||||
return {listnode->output()};
|
||||
}
|
||||
|
||||
|
|
|
@ -55,8 +55,17 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
|
|||
CHECK(sv);
|
||||
|
||||
TorchMlirOpVector results;
|
||||
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
|
||||
// Op returns multiple values.
|
||||
if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) {
|
||||
// Unpack dynamic multi-output operations like aten::split with Tensor[] output type.
|
||||
// This is required to have consistent input types for multi-output node consumers.
|
||||
torch::jit::Node * node = function->graph()->createListUnpack(sv->getValue(), tensor_types.size());
|
||||
function->graph()->insertNode(node);
|
||||
for (const auto & output : node->outputs()) {
|
||||
results.push_back(output);
|
||||
}
|
||||
} else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
|
||||
// Op returns multiple values and the number of outputs is static and defined
|
||||
// by the operation schema.
|
||||
const auto tuple_call_result = sv->asTuple({}, *function);
|
||||
for (const auto& tuple_component : tuple_call_result) {
|
||||
auto tuple_component_sv =
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
//===- split.cpp ----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "split.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
SplitWithSizesCopy::SplitWithSizesCopy(
|
||||
const torch::lazy::Value& self, const ::std::vector<int64_t>& split_sizes,
|
||||
const int64_t& dim, std::vector<torch::lazy::Shape>&& shapes)
|
||||
: torch::lazy::TorchMlirNode(SplitWithSizesCopy::ClassOpKind(),
|
||||
OpList{ self }, std::move(shapes),
|
||||
split_sizes.size() /* num_outputs */,
|
||||
torch::lazy::MHash(split_sizes, dim)),
|
||||
split_sizes(split_sizes), dim(dim) {}
|
||||
|
||||
std::string SplitWithSizesCopy::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << torch::lazy::TorchMlirNode::ToString();
|
||||
ss << ", split_sizes=" << split_sizes;
|
||||
ss << ", dim=" << dim;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value& self,
|
||||
const ::std::vector<int64_t>& split_sizes,
|
||||
const int64_t& dim) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
TorchMlirOpVector
|
||||
SplitWithSizesCopy::Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const {
|
||||
PRINT_FUNCTION();
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
std::vector<torch::jit::NamedValue> kwarguments;
|
||||
arguments.reserve(3);
|
||||
kwarguments.reserve(0);
|
||||
size_t i = 0;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back("split_sizes", split_sizes);
|
||||
arguments.emplace_back("dim", dim);
|
||||
|
||||
torch::lazy::TorchMlirOpVector split_with_sizes_copy_out =
|
||||
torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments,
|
||||
kwarguments);
|
||||
|
||||
return split_with_sizes_copy_out;
|
||||
}
|
||||
|
||||
SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& split_size,
|
||||
const int64_t& dim,
|
||||
std::vector<torch::lazy::Shape>&& shapes,
|
||||
const size_t num_outputs)
|
||||
: torch::lazy::TorchMlirNode(SplitCopyTensor::ClassOpKind(),
|
||||
OpList{ self, split_size }, std::move(shapes),
|
||||
num_outputs, torch::lazy::MHash(dim)),
|
||||
dim(dim) {}
|
||||
|
||||
std::string SplitCopyTensor::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << torch::lazy::TorchMlirNode::ToString();
|
||||
ss << ", dim=" << dim;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool SplitCopyTensor::CanBeReused(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& split_size,
|
||||
const int64_t& dim) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
TorchMlirOpVector
|
||||
SplitCopyTensor::Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const {
|
||||
PRINT_FUNCTION();
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
std::vector<torch::jit::NamedValue> kwarguments;
|
||||
arguments.reserve(3);
|
||||
kwarguments.reserve(0);
|
||||
size_t i = 0;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back("dim", dim);
|
||||
|
||||
torch::lazy::TorchMlirOpVector split_copy_out =
|
||||
torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments,
|
||||
kwarguments);
|
||||
return split_copy_out;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,65 @@
|
|||
//===- split.h ------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../mlir_node.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class SplitWithSizesCopy : public torch::lazy::TorchMlirNode {
|
||||
public:
|
||||
static torch::lazy::OpKind ClassOpKind() {
|
||||
return torch::lazy::OpKind(at::aten::split_with_sizes_copy);
|
||||
}
|
||||
|
||||
SplitWithSizesCopy(const torch::lazy::Value& self,
|
||||
const ::std::vector<int64_t>& split_sizes,
|
||||
const int64_t& dim,
|
||||
std::vector<torch::lazy::Shape>&& shapes);
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
bool CanBeReused(const torch::lazy::Value& self,
|
||||
const ::std::vector<int64_t>& split_sizes,
|
||||
const int64_t& dim) const;
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
std::vector<int64_t> split_sizes;
|
||||
int64_t dim;
|
||||
};
|
||||
|
||||
class SplitCopyTensor : public torch::lazy::TorchMlirNode {
|
||||
public:
|
||||
static torch::lazy::OpKind ClassOpKind() {
|
||||
return torch::lazy::OpKind(at::aten::split_copy);
|
||||
}
|
||||
|
||||
SplitCopyTensor(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& split_size, const int64_t& dim,
|
||||
std::vector<torch::lazy::Shape>&& shapes,
|
||||
const size_t num_outputs = 1);
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
bool CanBeReused(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& split_size,
|
||||
const int64_t& dim) const;
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
int64_t dim;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,54 @@
|
|||
//===- unbind_int.cpp -----------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "unbind_int.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim,
|
||||
std::vector<torch::lazy::Shape>&& shapes)
|
||||
: torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{ self },
|
||||
std::move(shapes),
|
||||
self.shape().size(dim), /* num_outputs */
|
||||
torch::lazy::MHash(dim)),
|
||||
dim(dim) {}
|
||||
|
||||
std::string UnbindCopyInt::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << torch::lazy::TorchMlirNode::ToString();
|
||||
ss << ", dim=" << dim;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool UnbindCopyInt::CanBeReused(const torch::lazy::Value& self,
|
||||
const int64_t& dim) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const {
|
||||
PRINT_FUNCTION();
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
std::vector<torch::jit::NamedValue> kwarguments;
|
||||
arguments.reserve(2);
|
||||
kwarguments.reserve(0);
|
||||
size_t i = 0;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back("dim", dim);
|
||||
|
||||
torch::lazy::TorchMlirOpVector unbind_copy_out =
|
||||
torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments,
|
||||
kwarguments);
|
||||
|
||||
return unbind_copy_out;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,37 @@
|
|||
//===- unbind_int.h ------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../mlir_node.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class UnbindCopyInt : public torch::lazy::TorchMlirNode {
|
||||
public:
|
||||
static torch::lazy::OpKind ClassOpKind() {
|
||||
return torch::lazy::OpKind(at::aten::unbind_copy);
|
||||
}
|
||||
|
||||
UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim,
|
||||
std::vector<torch::lazy::Shape>&& shapes);
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim) const;
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
int64_t dim;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -333,6 +333,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::imag : (Tensor) -> (Tensor)")
|
||||
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
||||
|
||||
# Ops with dynamic number of outputs
|
||||
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
|
||||
emit("aten::split_copy.Tensor : (Tensor, int, int) -> (Tensor[])")
|
||||
emit("aten::split_with_sizes_copy : (Tensor, int[], int) -> (Tensor[])")
|
||||
|
||||
# Random number generation
|
||||
emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)")
|
||||
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue