mirror of https://github.com/llvm/torch-mlir
[LTC] Tensor[]? support operands type support using partial codegen (#2410)
* Tensor[]? support operands type support using partial codegen * aten.index.Tensor support via partial codegen * Add torch.index_put tracing support * Added optional tensor list type support for LTC/TorchMLIR lowering * Added comments Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>pull/2426/head snapshot-20230830.946
parent
17d02811d5
commit
6b02e9a926
|
@ -1,11 +1,7 @@
|
|||
blacklist:
|
||||
# List of unsupported ops in LTC autogen because of some error
|
||||
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
|
||||
- _index_put_impl # Error: TODO not sure if there are other valid types to handle here
|
||||
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
|
||||
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
|
||||
- index_put # Error: TODO not sure if there are other valid types to handle here
|
||||
- index_put_ # Error: TODO not sure if there are other valid types to handle here
|
||||
# Disabled in favour of `aten::index_put` which supports optional indices via `hacked_twin` JIT hack.
|
||||
# It also doesn't have confusing `unsafe` argument.
|
||||
- _index_put_impl
|
||||
|
||||
# Ops with list of tensors output
|
||||
- split.Tensor
|
||||
|
@ -61,6 +57,8 @@ supported:
|
|||
- unbind_copy.int
|
||||
- split_copy.Tensor
|
||||
- split_with_sizes_copy
|
||||
- index.Tensor
|
||||
- index_put
|
||||
|
||||
# ops required for functionalization
|
||||
- lift
|
||||
|
|
|
@ -34,6 +34,7 @@ from .xfail_sets import (
|
|||
STABLEHLO_CRASHING_SET,
|
||||
TOSA_PASS_SET,
|
||||
LTC_XFAIL_SET,
|
||||
LTC_CRASHING_SET,
|
||||
TORCHDYNAMO_XFAIL_SET,
|
||||
TORCHDYNAMO_CRASHING_SET
|
||||
)
|
||||
|
@ -114,7 +115,7 @@ def main():
|
|||
elif args.config == "lazy_tensor_core":
|
||||
config = LazyTensorCoreTestConfig()
|
||||
xfail_set = LTC_XFAIL_SET
|
||||
crashing_set = set()
|
||||
crashing_set = LTC_CRASHING_SET
|
||||
elif args.config == "torchdynamo":
|
||||
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||
xfail_set = TORCHDYNAMO_XFAIL_SET
|
||||
|
|
|
@ -1263,6 +1263,12 @@ if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
|||
"ReshapeCollapseModule_basic",
|
||||
}
|
||||
|
||||
LTC_CRASHING_SET = {
|
||||
# TODO: update test to move all inputs to the lazy device. Otherwise test fails with:
|
||||
# Check failed: lazy_tensor Input tensor is not a lazy tensor: CPUBoolType.
|
||||
"HBC_basic",
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
"_Convolution2DAllFalseModule_basic",
|
||||
"_Convolution2DBenchmarkModule_basic",
|
||||
|
@ -1295,31 +1301,6 @@ LTC_XFAIL_SET = {
|
|||
"GeIntModule_basic",
|
||||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"HBC_basic",
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
"IndexPut1DIntAccumulateModule_basic",
|
||||
"IndexPut1DIntNonAccumulateModule_basic",
|
||||
"IndexPut2DFloatAccumulateModule_basic",
|
||||
"IndexPut2DFloatNonAccumulateModule_basic",
|
||||
"IndexPut2DIntAccumulateModule_basic",
|
||||
"IndexPut2DIntNonAccumulateModule_basic",
|
||||
"IndexPut3DFloatAccumulateModule_basic",
|
||||
"IndexPut3DFloatNonAccumulateModule_basic",
|
||||
"IndexPut3DIntAccumulateModule_basic",
|
||||
"IndexPut3DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
||||
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl1DIntAccumulateModule_basic",
|
||||
|
@ -1331,27 +1312,6 @@ LTC_XFAIL_SET = {
|
|||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImplIndexWithNoneModule_basic",
|
||||
"IndexTensorModule3dInput_basic",
|
||||
"IndexTensorModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"IndexTensorMultiInputContiguousCenter_basic",
|
||||
"IndexTensorMultiInputNonContiguous_basic",
|
||||
"IndexTensorMultiInputOneDim_basic",
|
||||
"IndexTensorMultiInputThreeIndexers_basic",
|
||||
"IndexTensorMultiInput_basic",
|
||||
"IndexTensorSelectDimModule_basic",
|
||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"IndexTensorHackedTwinModule_basic",
|
||||
"IndexTensorHackedTwinModule3dInput_basic",
|
||||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"Matmul_dot",
|
||||
"MulIntModule_basic",
|
||||
|
@ -1421,7 +1381,5 @@ LTC_XFAIL_SET = {
|
|||
"AtenComplexViewModule_basic",
|
||||
"ScatterValueFloatModule_basic",
|
||||
"ScatterValueIntModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"UniformStaticShapeModule_basic",
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
}
|
||||
|
|
|
@ -71,6 +71,8 @@ add_library(torch_mlir_ltc_backend SHARED
|
|||
mlir_node.cpp
|
||||
ops/device_data.cpp
|
||||
ops/generic.cpp
|
||||
ops/index.cpp
|
||||
ops/ivalue.cpp
|
||||
ops/split.cpp
|
||||
ops/unbind_int.cpp
|
||||
utils/jit_utils.cpp
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
|
||||
#include <ATen/CompositeExplicitAutogradFunctions.h>
|
||||
#include <ATen/FunctionalTensorWrapper.h>
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/MetaFunctions.h>
|
||||
|
@ -34,6 +35,8 @@
|
|||
#include "ops/to_copy.h"
|
||||
#include "ops/unbind_int.h"
|
||||
#include "ops/split.h"
|
||||
#include "ops/index.h"
|
||||
#include "ops/ivalue.h"
|
||||
#include "utils/exception.h"
|
||||
#include "utils/sys_utils.h"
|
||||
|
||||
|
@ -71,6 +74,15 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
|
|||
}
|
||||
return outs;
|
||||
}
|
||||
|
||||
c10::List<c10::optional<at::Tensor>> to_meta(const c10::List<c10::optional<at::Tensor>>& t_list) {
|
||||
c10::List<c10::optional<at::Tensor>> outs;
|
||||
outs.reserve(t_list.size());
|
||||
for (const auto& tensor : t_list) {
|
||||
outs.push_back(to_meta(tensor));
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace torch {
|
||||
|
@ -517,6 +529,89 @@ std::vector<at::Tensor> LazyNativeFunctions::split_copy_symint(const at::Tensor
|
|||
return result;
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||
TORCH_INTERNAL_ASSERT(common_device);
|
||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||
|
||||
std::vector<torch::lazy::Value> values;
|
||||
for (const auto & it : indices) {
|
||||
c10::optional<at::Tensor> tensor = it;
|
||||
LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
|
||||
values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
|
||||
}
|
||||
|
||||
auto list = MakeNode<TorchMlirOptionalTensorList>(values);
|
||||
|
||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexTensor>(lazy_self->GetIrValue(), list);
|
||||
|
||||
if (!node) {
|
||||
auto self_meta = to_meta(self);
|
||||
auto indices_meta = to_meta(indices);
|
||||
auto out_meta = at::meta::index(self_meta, indices_meta);
|
||||
|
||||
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
|
||||
if(torch::lazy::symbolicShapeEnabled()) {
|
||||
std::vector<torch::jit::IValue> inputs = { self, indices };
|
||||
const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor";
|
||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||
}
|
||||
|
||||
node = torch::lazy::MakeNode<IndexTensor>(lazy_self->GetIrValue(), list, std::move(shapes));
|
||||
CacheNode(node);
|
||||
}
|
||||
|
||||
auto result = torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::LazyTensor::Create(std::move(node), *common_device));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
|
||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||
TORCH_INTERNAL_ASSERT(common_device);
|
||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||
LazyTensorPtr lazy_valeus = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device);
|
||||
|
||||
std::vector<torch::lazy::Value> indices_vector;
|
||||
for (const auto & it : indices) {
|
||||
c10::optional<at::Tensor> tensor = it;
|
||||
LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
|
||||
indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
|
||||
}
|
||||
|
||||
auto indices_list = MakeNode<TorchMlirOptionalTensorList>(indices_vector);
|
||||
|
||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate);
|
||||
|
||||
if (!node) {
|
||||
auto self_meta = to_meta(self);
|
||||
auto indices_meta = to_meta(indices);
|
||||
auto values_meta = to_meta(values);
|
||||
|
||||
auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate);
|
||||
|
||||
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
|
||||
if(torch::lazy::symbolicShapeEnabled()) {
|
||||
std::vector<torch::jit::IValue> inputs = { self, indices, values };
|
||||
const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor";
|
||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||
}
|
||||
|
||||
node = torch::lazy::MakeNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate, std::move(shapes));
|
||||
CacheNode(node);
|
||||
}
|
||||
|
||||
auto result = torch::lazy::CreateAtenFromLtcTensor(
|
||||
torch::lazy::LazyTensor::Create(std::move(node), *common_device));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// This is needed by the torch.tensor constructor.
|
||||
// LazyTensor always opts into functionalization.
|
||||
// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.
|
||||
|
|
|
@ -120,5 +120,38 @@ torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
|
|||
return {listnode->output()};
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// TorchMlirOptionalTensorList
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
OpKind TorchMlirOptionalTensorList::ClassOpKind() {
|
||||
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
|
||||
// import otherwise
|
||||
static const OpKind tensor_list_opkind =
|
||||
OpKind::Get("lazy_tensors::optional_tensor_list");
|
||||
return tensor_list_opkind;
|
||||
}
|
||||
|
||||
TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values)
|
||||
: TorchMlirNode(
|
||||
/*op=*/TorchMlirOptionalTensorList::ClassOpKind(),
|
||||
/*operands=*/values,
|
||||
/*shapes=*/std::vector<Shape>(),
|
||||
/*num_outputs=*/1,
|
||||
/*hash_seed=*/kHashSeed) {}
|
||||
|
||||
torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower(
|
||||
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
|
||||
std::vector<torch::jit::Value*> tensor_list;
|
||||
CHECK(!operands().empty());
|
||||
for (const torch::lazy::Output& operand : operands()) {
|
||||
tensor_list.emplace_back(loctx->GetOutputOp(operand));
|
||||
}
|
||||
auto graph = function->graph();
|
||||
auto listnode =
|
||||
graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list));
|
||||
return {listnode->output()};
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -91,5 +91,18 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
|
|||
TorchMlirLoweringContext* loctx) const override;
|
||||
};
|
||||
|
||||
// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent
|
||||
// optional tensors, so the output type for this op is !torch.list<optional<vtensor>>.
|
||||
struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode {
|
||||
static OpKind ClassOpKind();
|
||||
|
||||
TorchMlirOptionalTensorList() = delete;
|
||||
TorchMlirOptionalTensorList(OpList values);
|
||||
|
||||
torch::lazy::TorchMlirOpVector Lower(
|
||||
TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -43,9 +43,14 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
|
|||
for (auto arg : arguments) {
|
||||
torch::jit::Value* value = arg.value(dummy_graph);
|
||||
if (value->type()->kind() == c10::TypeKind::ListType) {
|
||||
auto list_element_type = value->type()->cast<c10::ListType>()->getElementType();
|
||||
if (list_element_type->cast<c10::OptionalType>()) {
|
||||
value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get())));
|
||||
} else {
|
||||
value->setType(c10::ListType::create(c10::TensorType::get()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto builtin =
|
||||
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
//===- index.cpp ----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "index.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
IndexTensor::IndexTensor(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& indices,
|
||||
std::vector<torch::lazy::Shape>&& shapes)
|
||||
: torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(),
|
||||
OpList{self, indices}, std::move(shapes),
|
||||
/* num_outputs */ 1, torch::lazy::MHash()) {}
|
||||
|
||||
std::string IndexTensor::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << torch::lazy::TorchMlirNode::ToString();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool IndexTensor::CanBeReused(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& indices) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const {
|
||||
PRINT_FUNCTION();
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
std::vector<torch::jit::NamedValue> kwarguments;
|
||||
arguments.reserve(2);
|
||||
kwarguments.reserve(0);
|
||||
|
||||
size_t i = 0;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
|
||||
torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin(
|
||||
function, op().op, shapes(), arguments, kwarguments);
|
||||
TORCH_CHECK_EQ(index_out.size(), 1);
|
||||
|
||||
return index_out;
|
||||
}
|
||||
|
||||
IndexPut::IndexPut(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& indices,
|
||||
const torch::lazy::Value& values, bool accumulate,
|
||||
std::vector<torch::lazy::Shape>&& shapes)
|
||||
: torch::lazy::TorchMlirNode(
|
||||
IndexPut::ClassOpKind(), OpList{self, indices, values},
|
||||
std::move(shapes),
|
||||
/* num_outputs */ 1, torch::lazy::MHash(accumulate)),
|
||||
accumulate(accumulate) {}
|
||||
|
||||
std::string IndexPut::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << torch::lazy::TorchMlirNode::ToString();
|
||||
ss << ", accumulate=" << accumulate;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool IndexPut::CanBeReused(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& indices,
|
||||
const torch::lazy::Value& values,
|
||||
bool accumulate) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const {
|
||||
PRINT_FUNCTION();
|
||||
std::vector<torch::jit::NamedValue> arguments;
|
||||
std::vector<torch::jit::NamedValue> kwarguments;
|
||||
arguments.reserve(4);
|
||||
kwarguments.reserve(0);
|
||||
|
||||
size_t i = 0;
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
|
||||
arguments.emplace_back("accumulate", accumulate);
|
||||
|
||||
torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin(
|
||||
function, op().op, shapes(), arguments, kwarguments);
|
||||
|
||||
TORCH_CHECK_EQ(index_out.size(), 1);
|
||||
|
||||
return index_out;
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,58 @@
|
|||
//===- index.h ------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../mlir_node.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
class IndexTensor : public torch::lazy::TorchMlirNode {
|
||||
public:
|
||||
static torch::lazy::OpKind ClassOpKind() {
|
||||
return torch::lazy::OpKind(at::aten::index);
|
||||
}
|
||||
|
||||
IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices,
|
||||
std::vector<torch::lazy::Shape>&& shapes);
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
bool CanBeReused(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& indices) const;
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
};
|
||||
|
||||
class IndexPut : public torch::lazy::TorchMlirNode {
|
||||
public:
|
||||
static torch::lazy::OpKind ClassOpKind() {
|
||||
return torch::lazy::OpKind(at::aten::index_put);
|
||||
}
|
||||
|
||||
IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices,
|
||||
const torch::lazy::Value& values, bool accumulate,
|
||||
std::vector<torch::lazy::Shape>&& shapes);
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
bool CanBeReused(const torch::lazy::Value& self,
|
||||
const torch::lazy::Value& indices,
|
||||
const torch::lazy::Value& values, bool accumulate) const;
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
bool accumulate;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,36 @@
|
|||
//===- ivalue.cpp
|
||||
//----------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "ivalue.h"
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
IValueConstant::IValueConstant(const c10::IValue& value)
|
||||
: torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{},
|
||||
std::vector<Shape>{},
|
||||
/* num_outputs */ 1, torch::lazy::MHash()),
|
||||
value(value) {}
|
||||
|
||||
std::string IValueConstant::ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << torch::lazy::TorchMlirNode::ToString();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const {
|
||||
return {loctx->graph()->insertConstant(value)};
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
|
@ -0,0 +1,37 @@
|
|||
//===- index.h ------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../mlir_node.h"
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
// IValueConstant IR Node represents a `prim::Constant` constructed with IValue
|
||||
// parameter which is helpfull in different usecases when we need custom
|
||||
// native ops lowering to torch-mlir IR nodes.
|
||||
class IValueConstant : public torch::lazy::TorchMlirNode {
|
||||
public:
|
||||
static torch::lazy::OpKind ClassOpKind() {
|
||||
return torch::lazy::OpKind(at::prim::Constant);
|
||||
}
|
||||
|
||||
IValueConstant(const c10::IValue& value);
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||
TorchMlirLoweringContext* loctx) const override;
|
||||
|
||||
c10::IValue value;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
Loading…
Reference in New Issue