mirror of https://github.com/llvm/torch-mlir

[LTC] Tensor[]? operand type support using partial codegen (#2410)

* Tensor[]? operand type support using partial codegen
* aten.index.Tensor support via partial codegen
* Add torch.index_put tracing support
* Added optional tensor list type support for LTC/TorchMLIR lowering
* Added comments

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>

Refs: pull/2426/head, snapshot-20230830.946
parent 17d02811d5
commit 6b02e9a926
autogen_ltc_backend.yaml
@@ -1,11 +1,7 @@
 blacklist:
-# List of unsupported ops in LTC autogen because of some error
-- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
-- _index_put_impl # Error: TODO not sure if there are other valid types to handle here
-- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
-- index.Tensor # Error: TODO not sure if there are other valid types to handle here
-- index_put # Error: TODO not sure if there are other valid types to handle here
-- index_put_ # Error: TODO not sure if there are other valid types to handle here
+# Disabled in favour of `aten::index_put`, which supports optional indices via the `hacked_twin` JIT hack.
+# It also doesn't have the confusing `unsafe` argument.
+- _index_put_impl

 # Ops with list of tensors output
 - split.Tensor
@@ -61,6 +57,8 @@ supported:
 - unbind_copy.int
 - split_copy.Tensor
 - split_with_sizes_copy
+- index.Tensor
+- index_put

 # ops required for functionalization
 - lift
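For context, a quick eager-mode sketch (illustrative, not part of the patch) of why `aten::index_put` is preferred here: its `Tensor?[] indices` argument accepts `None` entries, while the `hacked_twin` JIT overload expresses the same access with plain `Tensor[]` indices spelled out in full.

import torch

x = torch.zeros(4, 3)
rows = torch.tensor([1, 3])

# Tensor?[] indices: a None entry keeps that dimension whole, so this
# writes complete rows.
x.index_put_((rows, None), torch.ones(2, 3))

# The hacked_twin overload takes plain Tensor[] (no None entries); the
# equivalent call spells every index out via broadcasting.
cols = torch.arange(3)
x.index_put_((rows.unsqueeze(1), cols.unsqueeze(0)), torch.full((2, 3), 2.0))
print(x)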
main.py
@@ -34,6 +34,7 @@ from .xfail_sets import (
     STABLEHLO_CRASHING_SET,
     TOSA_PASS_SET,
     LTC_XFAIL_SET,
+    LTC_CRASHING_SET,
     TORCHDYNAMO_XFAIL_SET,
     TORCHDYNAMO_CRASHING_SET
 )
@@ -114,7 +115,7 @@ def main():
     elif args.config == "lazy_tensor_core":
         config = LazyTensorCoreTestConfig()
         xfail_set = LTC_XFAIL_SET
-        crashing_set = set()
+        crashing_set = LTC_CRASHING_SET
     elif args.config == "torchdynamo":
         config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
         xfail_set = TORCHDYNAMO_XFAIL_SET
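Schematically (a simplified sketch, not the real driver; test names are only examples), the difference between the two sets is that crashing tests are filtered out before the run, since they would take the whole process down, while xfailed tests still run and are expected to fail:

ALL_TESTS = ["HBC_basic", "IndexTensorModule_basic", "ElementwiseAddModule_basic"]
LTC_CRASHING_SET = {"HBC_basic"}
LTC_XFAIL_SET = {"IndexTensorModule_basic"}

# Crashing tests never run; xfailed tests run but must fail to pass CI.
runnable = [t for t in ALL_TESTS if t not in LTC_CRASHING_SET]
for name in runnable:
    expected_failure = name in LTC_XFAIL_SET
    print(f"{name}: {'XFAIL' if expected_failure else 'RUN'}")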
xfail_sets.py
@@ -1263,6 +1263,12 @@ if torch_version_for_comparison() < version.parse("2.1.0.dev"):
         "ReshapeCollapseModule_basic",
     }

+LTC_CRASHING_SET = {
+    # TODO: update test to move all inputs to the lazy device. Otherwise test fails with:
+    # Check failed: lazy_tensor Input tensor is not a lazy tensor: CPUBoolType.
+    "HBC_basic",
+}
+
 LTC_XFAIL_SET = {
     "_Convolution2DAllFalseModule_basic",
     "_Convolution2DBenchmarkModule_basic",
@@ -1295,31 +1301,6 @@ LTC_XFAIL_SET = {
     "GeIntModule_basic",
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
-    "HBC_basic",
-    "IndexPut1DFloatAccumulateModule_basic",
-    "IndexPut1DFloatNonAccumulateModule_basic",
-    "IndexPut1DIntAccumulateModule_basic",
-    "IndexPut1DIntNonAccumulateModule_basic",
-    "IndexPut2DFloatAccumulateModule_basic",
-    "IndexPut2DFloatNonAccumulateModule_basic",
-    "IndexPut2DIntAccumulateModule_basic",
-    "IndexPut2DIntNonAccumulateModule_basic",
-    "IndexPut3DFloatAccumulateModule_basic",
-    "IndexPut3DFloatNonAccumulateModule_basic",
-    "IndexPut3DIntAccumulateModule_basic",
-    "IndexPut3DIntNonAccumulateModule_basic",
-    "IndexPutHackedTwin1DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
-    "IndexPutHackedTwin1DIntAccumulateModule_basic",
-    "IndexPutHackedTwin1DIntNonAccumulateModule_basic",
-    "IndexPutHackedTwin2DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
-    "IndexPutHackedTwin2DIntAccumulateModule_basic",
-    "IndexPutHackedTwin2DIntNonAccumulateModule_basic",
-    "IndexPutHackedTwin3DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
-    "IndexPutHackedTwin3DIntAccumulateModule_basic",
-    "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
     "IndexPutImpl1DFloatAccumulateModule_basic",
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntAccumulateModule_basic",
@@ -1331,27 +1312,6 @@ LTC_XFAIL_SET = {
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
-    "IndexTensorModule3dInput_basic",
-    "IndexTensorModule_basic",
-    "IndexTensorStaticModule_basic",
-    "IndexTensorMultiIndexStaticModule_basic",
-    "IndexTensorMultiInputContiguousCenter_basic",
-    "IndexTensorMultiInputNonContiguous_basic",
-    "IndexTensorMultiInputOneDim_basic",
-    "IndexTensorMultiInputThreeIndexers_basic",
-    "IndexTensorMultiInput_basic",
-    "IndexTensorSelectDimModule_basic",
-    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
-    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
-    "IndexTensorMultiInputNonContiguousDynamic_basic",
-    "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
-    "IndexTensorStaticContiguousWithNoneModule_basic",
-    "IndexTensorStaticNonContiguousWithNoneModule_basic",
-    "IndexTensorHackedTwinModule_basic",
-    "IndexTensorHackedTwinModule3dInput_basic",
-    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
-    "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
-    "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
     "LiftFreshCopyModule_basic",
     "Matmul_dot",
     "MulIntModule_basic",
@@ -1421,7 +1381,5 @@ LTC_XFAIL_SET = {
     "AtenComplexViewModule_basic",
     "ScatterValueFloatModule_basic",
     "ScatterValueIntModule_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "UniformStaticShapeModule_basic",
-    "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
 }
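The TODO above also hints at the eventual fix: move every input, including bool tensors, to the lazy device before tracing. A minimal sketch, assuming a PyTorch build that ships the TorchScript lazy backend (i.e. `torch._lazy` is importable):

import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()

x = torch.rand(2, 3).to(device="lazy")
# Moving the bool input too avoids the "Input tensor is not a lazy tensor:
# CPUBoolType" check failure described in the comment.
mask = torch.tensor([True, False]).to(device="lazy")
out = x[mask]            # all operands are lazy tensors, so tracing succeeds
torch._lazy.mark_step()  # materialize the traced graph
print(out.cpu())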
CMakeLists.txt
@@ -71,6 +71,8 @@ add_library(torch_mlir_ltc_backend SHARED
   mlir_node.cpp
   ops/device_data.cpp
   ops/generic.cpp
+  ops/index.cpp
+  ops/ivalue.cpp
   ops/split.cpp
   ops/unbind_int.cpp
   utils/jit_utils.cpp
mlir_native_functions.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//

 #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
+#include <ATen/CompositeExplicitAutogradFunctions.h>
 #include <ATen/FunctionalTensorWrapper.h>
 #include <ATen/InferSize.h>
 #include <ATen/MetaFunctions.h>
@@ -34,6 +35,8 @@
 #include "ops/to_copy.h"
 #include "ops/unbind_int.h"
 #include "ops/split.h"
+#include "ops/index.h"
+#include "ops/ivalue.h"
 #include "utils/exception.h"
 #include "utils/sys_utils.h"

@@ -71,6 +74,15 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
   }
   return outs;
 }

+c10::List<c10::optional<at::Tensor>> to_meta(const c10::List<c10::optional<at::Tensor>>& t_list) {
+  c10::List<c10::optional<at::Tensor>> outs;
+  outs.reserve(t_list.size());
+  for (const auto& tensor : t_list) {
+    outs.push_back(to_meta(tensor));
+  }
+  return outs;
+}
+
 } // namespace

 namespace torch {
@@ -517,6 +529,89 @@ std::vector<at::Tensor> LazyNativeFunctions::split_copy_symint(const at::Tensor
   return result;
 }

+at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
+  TORCH_LAZY_FN_COUNTER("lazy::");
+  auto common_device = torch::lazy::GetBackendDevice(self);
+  TORCH_INTERNAL_ASSERT(common_device);
+  LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
+
+  std::vector<torch::lazy::Value> values;
+  for (const auto & it : indices) {
+    c10::optional<at::Tensor> tensor = it;
+    LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
+    values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
+  }
+
+  auto list = MakeNode<TorchMlirOptionalTensorList>(values);
+
+  torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexTensor>(lazy_self->GetIrValue(), list);
+
+  if (!node) {
+    auto self_meta = to_meta(self);
+    auto indices_meta = to_meta(indices);
+    auto out_meta = at::meta::index(self_meta, indices_meta);
+
+    std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
+    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
+    if (torch::lazy::symbolicShapeEnabled()) {
+      std::vector<torch::jit::IValue> inputs = { self, indices };
+      const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor";
+      applySymbolicShapesOnLT(schema_str, inputs, shapes);
+    }
+
+    node = torch::lazy::MakeNode<IndexTensor>(lazy_self->GetIrValue(), list, std::move(shapes));
+    CacheNode(node);
+  }
+
+  auto result = torch::lazy::CreateAtenFromLtcTensor(
+      torch::lazy::LazyTensor::Create(std::move(node), *common_device));
+
+  return result;
+}
+
+at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
+  TORCH_LAZY_FN_COUNTER("lazy::");
+  auto common_device = torch::lazy::GetBackendDevice(self);
+  TORCH_INTERNAL_ASSERT(common_device);
+  LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
+  LazyTensorPtr lazy_values = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device);
+
+  std::vector<torch::lazy::Value> indices_vector;
+  for (const auto & it : indices) {
+    c10::optional<at::Tensor> tensor = it;
+    LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
+    indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
+  }
+
+  auto indices_list = MakeNode<TorchMlirOptionalTensorList>(indices_vector);
+
+  torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_values->GetIrValue(), accumulate);
+
+  if (!node) {
+    auto self_meta = to_meta(self);
+    auto indices_meta = to_meta(indices);
+    auto values_meta = to_meta(values);
+
+    auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate);
+
+    std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
+    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
+    if (torch::lazy::symbolicShapeEnabled()) {
+      std::vector<torch::jit::IValue> inputs = { self, indices, values };
+      const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor";
+      applySymbolicShapesOnLT(schema_str, inputs, shapes);
+    }
+
+    node = torch::lazy::MakeNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_values->GetIrValue(), accumulate, std::move(shapes));
+    CacheNode(node);
+  }
+
+  auto result = torch::lazy::CreateAtenFromLtcTensor(
+      torch::lazy::LazyTensor::Create(std::move(node), *common_device));
+
+  return result;
+}
+
 // This is needed by the torch.tensor constructor.
 // LazyTensor always opts into functionalization.
 // "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.
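An eager-mode illustration (not part of the patch) of the two ATen entry points implemented above, called through `torch.ops` with the schemas quoted in the diff:

import torch

x = torch.zeros(4, 4)
idx = torch.tensor([0, 2])

# aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
gathered = torch.ops.aten.index.Tensor(x, [idx, None])

# aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
updated = torch.ops.aten.index_put(x, [idx, None], torch.ones(2, 4), True)

print(gathered.shape)  # torch.Size([2, 4])
print(updated[0])      # row 0 is now all ones

On the lazy device these same calls route through LazyNativeFunctions::index / index_put; on CPU they run eagerly, which keeps the snippet runnable anywhere.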
mlir_node.cpp
@@ -120,5 +120,38 @@ torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
   return {listnode->output()};
 }

+///////////////////////////////////////////////////////////////////////////////
+// TorchMlirOptionalTensorList
+///////////////////////////////////////////////////////////////////////////////
+
+OpKind TorchMlirOptionalTensorList::ClassOpKind() {
+  // Note: this OpKind is separate from ltc_ops.h since it would be a circular
+  // import otherwise
+  static const OpKind tensor_list_opkind =
+      OpKind::Get("lazy_tensors::optional_tensor_list");
+  return tensor_list_opkind;
+}
+
+TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values)
+    : TorchMlirNode(
+          /*op=*/TorchMlirOptionalTensorList::ClassOpKind(),
+          /*operands=*/values,
+          /*shapes=*/std::vector<Shape>(),
+          /*num_outputs=*/1,
+          /*hash_seed=*/kHashSeed) {}
+
+torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+  std::vector<torch::jit::Value*> tensor_list;
+  CHECK(!operands().empty());
+  for (const torch::lazy::Output& operand : operands()) {
+    tensor_list.emplace_back(loctx->GetOutputOp(operand));
+  }
+  auto graph = function->graph();
+  auto listnode = graph->insertNode(graph->createList(
+      c10::OptionalType::create(c10::TensorType::get()), tensor_list));
+  return {listnode->output()};
+}
+
 } // namespace lazy
 } // namespace torch
mlir_node.h
@@ -91,5 +91,18 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
       TorchMlirLoweringContext* loctx) const override;
 };

+// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent
+// optional tensors, so the output type for this op is !torch.list<optional<vtensor>>.
+struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode {
+  static OpKind ClassOpKind();
+
+  TorchMlirOptionalTensorList() = delete;
+  TorchMlirOptionalTensorList(OpList values);
+
+  torch::lazy::TorchMlirOpVector Lower(
+      TorchMlirFunction function,
+      TorchMlirLoweringContext* loctx) const override;
+};
+
 } // namespace lazy
 } // namespace torch
mlir_node_lowering.cpp
@@ -43,7 +43,12 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
   for (auto arg : arguments) {
     torch::jit::Value* value = arg.value(dummy_graph);
     if (value->type()->kind() == c10::TypeKind::ListType) {
-      value->setType(c10::ListType::create(c10::TensorType::get()));
+      auto list_element_type = value->type()->cast<c10::ListType>()->getElementType();
+      if (list_element_type->cast<c10::OptionalType>()) {
+        value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get())));
+      } else {
+        value->setType(c10::ListType::create(c10::TensorType::get()));
+      }
     }
   }

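A small scripted example (illustrative only) of the graph construct this lowering path now preserves: a prim::ListConstruct typed `Tensor?[]` rather than being flattened to `Tensor[]`, which is exactly what TorchMlirOptionalTensorList lowers to:

from typing import List, Optional

import torch

@torch.jit.script
def put(x: torch.Tensor, idx: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    indices: List[Optional[torch.Tensor]] = [idx, None]
    return x.index_put(indices, v)

for node in put.graph.nodes():
    if node.kind() == "prim::ListConstruct":
        # Prints "Tensor?[]": a list of optional tensors, matching the
        # OptionalType branch in the C++ above.
        print(node.output().type())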
ops/index.cpp (new file)
@@ -0,0 +1,99 @@
+//===- index.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "index.h"
+
+namespace torch {
+namespace lazy {
+
+IndexTensor::IndexTensor(const torch::lazy::Value& self,
+                         const torch::lazy::Value& indices,
+                         std::vector<torch::lazy::Shape>&& shapes)
+    : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(),
+                                 OpList{self, indices}, std::move(shapes),
+                                 /* num_outputs */ 1, torch::lazy::MHash()) {}
+
+std::string IndexTensor::ToString() const {
+  std::stringstream ss;
+  ss << torch::lazy::TorchMlirNode::ToString();
+  return ss.str();
+}
+
+bool IndexTensor::CanBeReused(const torch::lazy::Value& self,
+                              const torch::lazy::Value& indices) const {
+  return false;
+}
+
+TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function,
+                                     TorchMlirLoweringContext* loctx) const {
+  PRINT_FUNCTION();
+  std::vector<torch::jit::NamedValue> arguments;
+  std::vector<torch::jit::NamedValue> kwarguments;
+  arguments.reserve(2);
+  kwarguments.reserve(0);
+
+  size_t i = 0;
+  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
+  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
+
+  torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin(
+      function, op().op, shapes(), arguments, kwarguments);
+  TORCH_CHECK_EQ(index_out.size(), 1);
+
+  return index_out;
+}
+
+IndexPut::IndexPut(const torch::lazy::Value& self,
+                   const torch::lazy::Value& indices,
+                   const torch::lazy::Value& values, bool accumulate,
+                   std::vector<torch::lazy::Shape>&& shapes)
+    : torch::lazy::TorchMlirNode(
+          IndexPut::ClassOpKind(), OpList{self, indices, values},
+          std::move(shapes),
+          /* num_outputs */ 1, torch::lazy::MHash(accumulate)),
+      accumulate(accumulate) {}
+
+std::string IndexPut::ToString() const {
+  std::stringstream ss;
+  ss << torch::lazy::TorchMlirNode::ToString();
+  ss << ", accumulate=" << accumulate;
+  return ss.str();
+}
+
+bool IndexPut::CanBeReused(const torch::lazy::Value& self,
+                           const torch::lazy::Value& indices,
+                           const torch::lazy::Value& values,
+                           bool accumulate) const {
+  return false;
+}
+
+TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function,
+                                  TorchMlirLoweringContext* loctx) const {
+  PRINT_FUNCTION();
+  std::vector<torch::jit::NamedValue> arguments;
+  std::vector<torch::jit::NamedValue> kwarguments;
+  arguments.reserve(4);
+  kwarguments.reserve(0);
+
+  size_t i = 0;
+  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
+  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
+  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
+  arguments.emplace_back("accumulate", accumulate);
+
+  torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin(
+      function, op().op, shapes(), arguments, kwarguments);
+
+  TORCH_CHECK_EQ(index_out.size(), 1);
+
+  return index_out;
+}
+
+} // namespace lazy
+} // namespace torch
ops/index.h (new file)
@@ -0,0 +1,58 @@
+//===- index.h ------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "../mlir_node.h"
+
+namespace torch {
+namespace lazy {
+
+class IndexTensor : public torch::lazy::TorchMlirNode {
+ public:
+  static torch::lazy::OpKind ClassOpKind() {
+    return torch::lazy::OpKind(at::aten::index);
+  }
+
+  IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices,
+              std::vector<torch::lazy::Shape>&& shapes);
+
+  std::string ToString() const override;
+
+  bool CanBeReused(const torch::lazy::Value& self,
+                   const torch::lazy::Value& indices) const;
+
+  TorchMlirOpVector Lower(TorchMlirFunction function,
+                          TorchMlirLoweringContext* loctx) const override;
+};
+
+class IndexPut : public torch::lazy::TorchMlirNode {
+ public:
+  static torch::lazy::OpKind ClassOpKind() {
+    return torch::lazy::OpKind(at::aten::index_put);
+  }
+
+  IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices,
+           const torch::lazy::Value& values, bool accumulate,
+           std::vector<torch::lazy::Shape>&& shapes);
+
+  std::string ToString() const override;
+
+  bool CanBeReused(const torch::lazy::Value& self,
+                   const torch::lazy::Value& indices,
+                   const torch::lazy::Value& values, bool accumulate) const;
+
+  TorchMlirOpVector Lower(TorchMlirFunction function,
+                          TorchMlirLoweringContext* loctx) const override;
+
+  bool accumulate;
+};
+
+} // namespace lazy
+} // namespace torch
ops/ivalue.cpp (new file)
@@ -0,0 +1,36 @@
+//===- ivalue.cpp ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ivalue.h"
+
+#include <ATen/core/ivalue.h>
+
+namespace torch {
+namespace lazy {
+
+IValueConstant::IValueConstant(const c10::IValue& value)
+    : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{},
+                                 std::vector<Shape>{},
+                                 /* num_outputs */ 1, torch::lazy::MHash()),
+      value(value) {}
+
+std::string IValueConstant::ToString() const {
+  std::stringstream ss;
+  ss << torch::lazy::TorchMlirNode::ToString();
+  return ss.str();
+}
+
+TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function,
+                                        TorchMlirLoweringContext* loctx) const {
+  return {loctx->graph()->insertConstant(value)};
+}
+
+} // namespace lazy
+} // namespace torch
ops/ivalue.h (new file)
@@ -0,0 +1,37 @@
+//===- ivalue.h -----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "../mlir_node.h"
+
+namespace torch {
+namespace lazy {
+
+// The IValueConstant IR node represents a `prim::Constant` built from an
+// IValue parameter, which is helpful whenever a custom native op lowering
+// needs to emit a constant into the torch-mlir IR.
+class IValueConstant : public torch::lazy::TorchMlirNode {
+ public:
+  static torch::lazy::OpKind ClassOpKind() {
+    return torch::lazy::OpKind(at::prim::Constant);
+  }
+
+  IValueConstant(const c10::IValue& value);
+
+  std::string ToString() const override;
+
+  TorchMlirOpVector Lower(TorchMlirFunction function,
+                          TorchMlirLoweringContext* loctx) const override;
+
+  c10::IValue value;
+};
+
+} // namespace lazy
+} // namespace torch
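For reference, constants of this kind are ordinary in TorchScript graphs; a tiny scripted function (illustrative) shows the prim::Constant node that IValueConstant's Lower emits via insertConstant, here carrying a None IValue as used for absent optional indices:

import torch

@torch.jit.script
def none_value() -> None:
    return None

# The printed graph body contains: %0 : NoneType = prim::Constant()
print(none_value.graph)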