[LTC] Tensor[]? support operands type support using partial codegen (#2410)

* Tensor[]? support operands type support using partial codegen

* aten.index.Tensor support via partial codegen

* Add torch.index_put tracing support

* Added optional tensor list type support for LTC/TorchMLIR lowering

* Added comments

Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
pull/2426/head snapshot-20230830.946
Gleb Kazantaev 2023-08-30 06:29:39 -04:00 committed by GitHub
parent 17d02811d5
commit 6b02e9a926
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 392 additions and 57 deletions

View File

@ -1,11 +1,7 @@
blacklist: blacklist:
# List of unsupported ops in LTC autogen because of some error # Disabled in favour of `aten::index_put` which supports optional indices via `hacked_twin` JIT hack.
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here # It also doesn't have confusing `unsafe` argument.
- _index_put_impl # Error: TODO not sure if there are other valid types to handle here - _index_put_impl
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here
# Ops with list of tensors output # Ops with list of tensors output
- split.Tensor - split.Tensor
@ -61,6 +57,8 @@ supported:
- unbind_copy.int - unbind_copy.int
- split_copy.Tensor - split_copy.Tensor
- split_with_sizes_copy - split_with_sizes_copy
- index.Tensor
- index_put
# ops required for functionalization # ops required for functionalization
- lift - lift

View File

@ -34,6 +34,7 @@ from .xfail_sets import (
STABLEHLO_CRASHING_SET, STABLEHLO_CRASHING_SET,
TOSA_PASS_SET, TOSA_PASS_SET,
LTC_XFAIL_SET, LTC_XFAIL_SET,
LTC_CRASHING_SET,
TORCHDYNAMO_XFAIL_SET, TORCHDYNAMO_XFAIL_SET,
TORCHDYNAMO_CRASHING_SET TORCHDYNAMO_CRASHING_SET
) )
@ -114,7 +115,7 @@ def main():
elif args.config == "lazy_tensor_core": elif args.config == "lazy_tensor_core":
config = LazyTensorCoreTestConfig() config = LazyTensorCoreTestConfig()
xfail_set = LTC_XFAIL_SET xfail_set = LTC_XFAIL_SET
crashing_set = set() crashing_set = LTC_CRASHING_SET
elif args.config == "torchdynamo": elif args.config == "torchdynamo":
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = TORCHDYNAMO_XFAIL_SET xfail_set = TORCHDYNAMO_XFAIL_SET

View File

@ -1263,6 +1263,12 @@ if torch_version_for_comparison() < version.parse("2.1.0.dev"):
"ReshapeCollapseModule_basic", "ReshapeCollapseModule_basic",
} }
LTC_CRASHING_SET = {
# TODO: update test to move all inputs to the lazy device. Otherwise test fails with:
# Check failed: lazy_tensor Input tensor is not a lazy tensor: CPUBoolType.
"HBC_basic",
}
LTC_XFAIL_SET = { LTC_XFAIL_SET = {
"_Convolution2DAllFalseModule_basic", "_Convolution2DAllFalseModule_basic",
"_Convolution2DBenchmarkModule_basic", "_Convolution2DBenchmarkModule_basic",
@ -1295,31 +1301,6 @@ LTC_XFAIL_SET = {
"GeIntModule_basic", "GeIntModule_basic",
"GtFloatIntModule_basic", "GtFloatIntModule_basic",
"GtIntModule_basic", "GtIntModule_basic",
"HBC_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
"IndexPut1DIntNonAccumulateModule_basic",
"IndexPut2DFloatAccumulateModule_basic",
"IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
"IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic", "IndexPutImpl1DIntAccumulateModule_basic",
@ -1331,27 +1312,6 @@ LTC_XFAIL_SET = {
"IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic", "IndexPutImplIndexWithNoneModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorSelectDimModule_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinModule_basic",
"IndexTensorHackedTwinModule3dInput_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"Matmul_dot", "Matmul_dot",
"MulIntModule_basic", "MulIntModule_basic",
@ -1421,7 +1381,5 @@ LTC_XFAIL_SET = {
"AtenComplexViewModule_basic", "AtenComplexViewModule_basic",
"ScatterValueFloatModule_basic", "ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic", "ScatterValueIntModule_basic",
"IndexTensorNegativeIndexModule_basic",
"UniformStaticShapeModule_basic", "UniformStaticShapeModule_basic",
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
} }

View File

@ -71,6 +71,8 @@ add_library(torch_mlir_ltc_backend SHARED
mlir_node.cpp mlir_node.cpp
ops/device_data.cpp ops/device_data.cpp
ops/generic.cpp ops/generic.cpp
ops/index.cpp
ops/ivalue.cpp
ops/split.cpp ops/split.cpp
ops/unbind_int.cpp ops/unbind_int.cpp
utils/jit_utils.cpp utils/jit_utils.cpp

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h> #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/FunctionalTensorWrapper.h> #include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h> #include <ATen/InferSize.h>
#include <ATen/MetaFunctions.h> #include <ATen/MetaFunctions.h>
@ -34,6 +35,8 @@
#include "ops/to_copy.h" #include "ops/to_copy.h"
#include "ops/unbind_int.h" #include "ops/unbind_int.h"
#include "ops/split.h" #include "ops/split.h"
#include "ops/index.h"
#include "ops/ivalue.h"
#include "utils/exception.h" #include "utils/exception.h"
#include "utils/sys_utils.h" #include "utils/sys_utils.h"
@ -71,6 +74,15 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
} }
return outs; return outs;
} }
c10::List<c10::optional<at::Tensor>> to_meta(const c10::List<c10::optional<at::Tensor>>& t_list) {
c10::List<c10::optional<at::Tensor>> outs;
outs.reserve(t_list.size());
for (const auto& tensor : t_list) {
outs.push_back(to_meta(tensor));
}
return outs;
}
} // namespace } // namespace
namespace torch { namespace torch {
@ -517,6 +529,89 @@ std::vector<at::Tensor> LazyNativeFunctions::split_copy_symint(const at::Tensor
return result; return result;
} }
at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
std::vector<torch::lazy::Value> values;
for (const auto & it : indices) {
c10::optional<at::Tensor> tensor = it;
LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
}
auto list = MakeNode<TorchMlirOptionalTensorList>(values);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexTensor>(lazy_self->GetIrValue(), list);
if (!node) {
auto self_meta = to_meta(self);
auto indices_meta = to_meta(indices);
auto out_meta = at::meta::index(self_meta, indices_meta);
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
if(torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = { self, indices };
const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<IndexTensor>(lazy_self->GetIrValue(), list, std::move(shapes));
CacheNode(node);
}
auto result = torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(std::move(node), *common_device));
return result;
}
at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
LazyTensorPtr lazy_valeus = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device);
std::vector<torch::lazy::Value> indices_vector;
for (const auto & it : indices) {
c10::optional<at::Tensor> tensor = it;
LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
}
auto indices_list = MakeNode<TorchMlirOptionalTensorList>(indices_vector);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate);
if (!node) {
auto self_meta = to_meta(self);
auto indices_meta = to_meta(indices);
auto values_meta = to_meta(values);
auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate);
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
if(torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = { self, indices, values };
const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate, std::move(shapes));
CacheNode(node);
}
auto result = torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(std::move(node), *common_device));
return result;
}
// This is needed by the torch.tensor constructor. // This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization. // LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object. // "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.

View File

@ -120,5 +120,38 @@ torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
return {listnode->output()}; return {listnode->output()};
} }
///////////////////////////////////////////////////////////////////////////////
// TorchMlirOptionalTensorList
///////////////////////////////////////////////////////////////////////////////
OpKind TorchMlirOptionalTensorList::ClassOpKind() {
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise
static const OpKind tensor_list_opkind =
OpKind::Get("lazy_tensors::optional_tensor_list");
return tensor_list_opkind;
}
TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values)
: TorchMlirNode(
/*op=*/TorchMlirOptionalTensorList::ClassOpKind(),
/*operands=*/values,
/*shapes=*/std::vector<Shape>(),
/*num_outputs=*/1,
/*hash_seed=*/kHashSeed) {}
torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
std::vector<torch::jit::Value*> tensor_list;
CHECK(!operands().empty());
for (const torch::lazy::Output& operand : operands()) {
tensor_list.emplace_back(loctx->GetOutputOp(operand));
}
auto graph = function->graph();
auto listnode =
graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list));
return {listnode->output()};
}
} // namespace lazy } // namespace lazy
} // namespace torch } // namespace torch

View File

@ -91,5 +91,18 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
TorchMlirLoweringContext* loctx) const override; TorchMlirLoweringContext* loctx) const override;
}; };
// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent
// optional tensors, so the output type for this op is !torch.list<optional<vtensor>>.
struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode {
static OpKind ClassOpKind();
TorchMlirOptionalTensorList() = delete;
TorchMlirOptionalTensorList(OpList values);
torch::lazy::TorchMlirOpVector Lower(
TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const override;
};
} // namespace lazy } // namespace lazy
} // namespace torch } // namespace torch

View File

@ -43,7 +43,12 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
for (auto arg : arguments) { for (auto arg : arguments) {
torch::jit::Value* value = arg.value(dummy_graph); torch::jit::Value* value = arg.value(dummy_graph);
if (value->type()->kind() == c10::TypeKind::ListType) { if (value->type()->kind() == c10::TypeKind::ListType) {
value->setType(c10::ListType::create(c10::TensorType::get())); auto list_element_type = value->type()->cast<c10::ListType>()->getElementType();
if (list_element_type->cast<c10::OptionalType>()) {
value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get())));
} else {
value->setType(c10::ListType::create(c10::TensorType::get()));
}
} }
} }

View File

@ -0,0 +1,99 @@
//===- index.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "index.h"
namespace torch {
namespace lazy {
IndexTensor::IndexTensor(const torch::lazy::Value& self,
const torch::lazy::Value& indices,
std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(),
OpList{self, indices}, std::move(shapes),
/* num_outputs */ 1, torch::lazy::MHash()) {}
std::string IndexTensor::ToString() const {
std::stringstream ss;
ss << torch::lazy::TorchMlirNode::ToString();
return ss.str();
}
bool IndexTensor::CanBeReused(const torch::lazy::Value& self,
const torch::lazy::Value& indices) const {
return false;
}
TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const {
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve(2);
kwarguments.reserve(0);
size_t i = 0;
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin(
function, op().op, shapes(), arguments, kwarguments);
TORCH_CHECK_EQ(index_out.size(), 1);
return index_out;
}
IndexPut::IndexPut(const torch::lazy::Value& self,
const torch::lazy::Value& indices,
const torch::lazy::Value& values, bool accumulate,
std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TorchMlirNode(
IndexPut::ClassOpKind(), OpList{self, indices, values},
std::move(shapes),
/* num_outputs */ 1, torch::lazy::MHash(accumulate)),
accumulate(accumulate) {}
std::string IndexPut::ToString() const {
std::stringstream ss;
ss << torch::lazy::TorchMlirNode::ToString();
ss << ", accumulate=" << accumulate;
return ss.str();
}
bool IndexPut::CanBeReused(const torch::lazy::Value& self,
const torch::lazy::Value& indices,
const torch::lazy::Value& values,
bool accumulate) const {
return false;
}
TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const {
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve(4);
kwarguments.reserve(0);
size_t i = 0;
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
arguments.emplace_back("accumulate", accumulate);
torch::lazy::TorchMlirOpVector index_out = torch::lazy::LowerTorchMlirBuiltin(
function, op().op, shapes(), arguments, kwarguments);
TORCH_CHECK_EQ(index_out.size(), 1);
return index_out;
}
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,58 @@
//===- index.h ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include "../mlir_node.h"
namespace torch {
namespace lazy {
class IndexTensor : public torch::lazy::TorchMlirNode {
public:
static torch::lazy::OpKind ClassOpKind() {
return torch::lazy::OpKind(at::aten::index);
}
IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices,
std::vector<torch::lazy::Shape>&& shapes);
std::string ToString() const override;
bool CanBeReused(const torch::lazy::Value& self,
const torch::lazy::Value& indices) const;
TorchMlirOpVector Lower(TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const override;
};
class IndexPut : public torch::lazy::TorchMlirNode {
public:
static torch::lazy::OpKind ClassOpKind() {
return torch::lazy::OpKind(at::aten::index_put);
}
IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices,
const torch::lazy::Value& values, bool accumulate,
std::vector<torch::lazy::Shape>&& shapes);
std::string ToString() const override;
bool CanBeReused(const torch::lazy::Value& self,
const torch::lazy::Value& indices,
const torch::lazy::Value& values, bool accumulate) const;
TorchMlirOpVector Lower(TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const override;
bool accumulate;
};
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,36 @@
//===- ivalue.cpp
//----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "ivalue.h"
#include <ATen/core/ivalue.h>
namespace torch {
namespace lazy {
IValueConstant::IValueConstant(const c10::IValue& value)
: torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{},
std::vector<Shape>{},
/* num_outputs */ 1, torch::lazy::MHash()),
value(value) {}
std::string IValueConstant::ToString() const {
std::stringstream ss;
ss << torch::lazy::TorchMlirNode::ToString();
return ss.str();
}
TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const {
return {loctx->graph()->insertConstant(value)};
}
} // namespace lazy
} // namespace torch

View File

@ -0,0 +1,37 @@
//===- index.h ------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#pragma once
#include "../mlir_node.h"
namespace torch {
namespace lazy {
// IValueConstant IR Node represents a `prim::Constant` constructed with IValue
// parameter which is helpfull in different usecases when we need custom
// native ops lowering to torch-mlir IR nodes.
class IValueConstant : public torch::lazy::TorchMlirNode {
public:
static torch::lazy::OpKind ClassOpKind() {
return torch::lazy::OpKind(at::prim::Constant);
}
IValueConstant(const c10::IValue& value);
std::string ToString() const override;
TorchMlirOpVector Lower(TorchMlirFunction function,
TorchMlirLoweringContext* loctx) const override;
c10::IValue value;
};
} // namespace lazy
} // namespace torch