mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add value tensor variant to aten::_index_put_impl_
This commit adds the op `ValsemVariantAtenIndexPutImplOp` that represents `Aten_IndexPutImpl_Op` without the underscore. This is needed to make sure that the `ReduceOpVariants` pass turns the in-place op into an op that takes value tensors as inputs, otherwise the `MaximizeValueSemantics` pass will not be able to add value semantics correctly. This commit also adds the lowering of `ValsemVariantAtenIndexPutImplOp` op. This commit also updates the `torch.bincount` op test cases.pull/679/head
parent
8a4388ea7b
commit
3d95c3d6c9
|
@ -1417,7 +1417,7 @@ class BincountModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BincountModule())
|
||||
def BincountModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)))
|
||||
module.forward(torch.randint(10, (1000,)))
|
||||
|
||||
|
||||
class BincountStaticSizeModule(torch.nn.Module):
|
||||
|
@ -1427,14 +1427,14 @@ class BincountStaticSizeModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([20], torch.int64, True),
|
||||
([200], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.bincount(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: BincountStaticSizeModule())
|
||||
def BincountStaticSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (20,)))
|
||||
module.forward(torch.randint(100, (200,)))
|
||||
|
||||
|
||||
class BincountMinlengthModule(torch.nn.Module):
|
||||
|
@ -1451,4 +1451,4 @@ class BincountMinlengthModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BincountMinlengthModule())
|
||||
def BincountMinlengthModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(500, (20,)))
|
||||
module.forward(torch.randint(5, (20,)))
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index,), value,
|
||||
accumulate=False,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule())
|
||||
def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250,)),
|
||||
tu.rand(250))
|
||||
|
||||
|
||||
class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index,), value,
|
||||
accumulate=False,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule())
|
||||
def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
|
||||
torch.randint(10000, (300,)))
|
||||
|
||||
|
||||
class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
# Since the input is updated in-place, we pass input.clone() in place
|
||||
# of input to avoid wrong results.
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
|
||||
accumulate=True,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule())
|
||||
def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500,)),
|
||||
tu.rand(500))
|
||||
|
||||
|
||||
class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
# Since the input is updated in-place, we pass input.clone() in place
|
||||
# of input to avoid wrong results.
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
|
||||
accumulate=True,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule())
|
||||
def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
|
||||
torch.randint(1000, (10,)))
|
|
@ -54,6 +54,7 @@ from . import histogram_binning_calibration
|
|||
from . import table_batch_embedding
|
||||
from . import rng
|
||||
from . import cast
|
||||
from . import index_put
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
|
|
|
@ -2921,6 +2921,23 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
|
|||
let assemblyFormat = "$self `,` $dim `,` $index attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($index)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate,
|
||||
Torch_BoolType:$unsafe
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenItemOp : Torch_Op<"aten.item", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -1018,6 +1018,27 @@ def Torch_ValsemVariantAtenFillScalarOp: Torch_Op<"valsem.aten.fill.Scalar", [
|
|||
let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
// The corresponding without underscore variant for `torch.aten._index_put_impl_`
|
||||
// doesn't exist in the pytorch ops registry. Add it here.
|
||||
def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate,
|
||||
Torch_BoolType:$unsafe
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
||||
// But TS compiler introduces control flow for `torch._assert` operation. The
|
||||
// `torch._assert` would introduce control flow like:
|
||||
|
|
|
@ -176,6 +176,132 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertValsemVariantAtenIndexPutImplOp
|
||||
: public OpConversionPattern<ValsemVariantAtenIndexPutImplOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ValsemVariantAtenIndexPutImplOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value input = adaptor.self();
|
||||
Value values = adaptor.values();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
|
||||
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
// TODO: Add support for the input with rank other than one.
|
||||
if (inputType.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input rank other than one is not supported");
|
||||
|
||||
// The unsafe should be either `False` or `none`.
|
||||
if (!op.unsafe().getType().isa<Torch::NoneType>()) {
|
||||
bool unsafe;
|
||||
if (!matchPattern(op.unsafe(), m_TorchConstantBool(&unsafe)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: unsafe must be a constant");
|
||||
else if (unsafe)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: unsafe is expected to be false");
|
||||
}
|
||||
|
||||
// The accumulate should be a torch constant of boolean type.
|
||||
bool accumulate;
|
||||
if (!matchPattern(op.accumulate(), m_TorchConstantBool(&accumulate)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected accumulate to be constant bool.");
|
||||
|
||||
// The element type of the `input` and `values` should be same.
|
||||
if (inputType.getElementType() != valuesType.getElementType())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input element type should be same as the values element type.");
|
||||
|
||||
SmallVector<Value> indicesList;
|
||||
getListConstructElements(adaptor.indices(), indicesList);
|
||||
// The size of the list of the index tensors should not be greater than the
|
||||
// input rank.
|
||||
if ((int64_t)indicesList.size() > inputType.getRank())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Indices list size should not be greater than the input rank.");
|
||||
|
||||
// TODO: Add support for cases with indices list size smaller than the input
|
||||
// rank.
|
||||
if ((int64_t)indicesList.size() < inputType.getRank())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented, Indices list size smaller than input rank");
|
||||
|
||||
if (indicesList[0].getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Indices tensor must not be none.");
|
||||
|
||||
// TODO: Add support for the index with rank other than one.
|
||||
int64_t indexRank = typeConverter->convertType(indicesList[0].getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getRank();
|
||||
if (indexRank != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: index rank other than one is not supported");
|
||||
|
||||
// Creating a tm_tensor.scatter op with the following mapping:
|
||||
// 1.) Index tensor from the `indicesList` maps to the indices in scatter
|
||||
// op. Index tensor is expanded from 1-d to 2-d, and its element type is set
|
||||
// to i32 as required for the scatter op.
|
||||
// 2.) `values` is mapped to `updates` in scatter op.
|
||||
// 3.) `input` is mapped to `original` in scatter op.
|
||||
ValueTensorType indexType =
|
||||
indicesList[0].getType().cast<ValueTensorType>();
|
||||
SmallVector<int64_t> expandedIndexSizes{indexType.getSizes()[0], 1};
|
||||
ValueTensorType expandedIndexType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
loc, expandedIndexType, indicesList[0], torchCstOne);
|
||||
|
||||
// Converting the index element type to i32.
|
||||
Value indices = convertTensorToDtype(
|
||||
rewriter, loc, expandedIndexTensor,
|
||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
|
||||
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
|
||||
loc, input.getType(), ValueRange{values, indices}, ValueRange{input},
|
||||
/*unique_indices=*/false);
|
||||
|
||||
Region &scatterOpRegion = scatterOp.region();
|
||||
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
|
||||
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
|
||||
{loc, loc});
|
||||
auto blockArgs = scatterOpBlock.getArguments();
|
||||
|
||||
OpBuilder regionBuilder(scatterOpRegion);
|
||||
Value update = blockArgs[0];
|
||||
Value original = blockArgs[1];
|
||||
Value yieldValue = update;
|
||||
// Create an add instruction inside the scatter op region to increment the
|
||||
// `original` value with the value from `updates` if the accumulate flag is
|
||||
// true.
|
||||
if (accumulate) {
|
||||
if (inputType.getElementType().isa<mlir::IntegerType>())
|
||||
yieldValue = regionBuilder.create<arith::AddIOp>(loc, original, update);
|
||||
else if (inputType.getElementType().isa<mlir::FloatType>())
|
||||
yieldValue = regionBuilder.create<arith::AddFOp>(loc, original, update);
|
||||
}
|
||||
regionBuilder.create<TMTensor::YieldOp>(loc, yieldValue);
|
||||
rewriter.replaceOp(op, scatterOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -207,6 +333,9 @@ public:
|
|||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<AtenBincountOp>();
|
||||
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
|
||||
target.addIllegalOp<ValsemVariantAtenIndexPutImplOp>();
|
||||
patterns.add<ConvertValsemVariantAtenIndexPutImplOp>(typeConverter,
|
||||
context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -158,6 +158,9 @@ public:
|
|||
} else if (isa<AtenFill_ScalarOp>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else if (isa<Aten_IndexPutImpl_Op>(op)) {
|
||||
newOp = rewriter.create<ValsemVariantAtenIndexPutImplOp>(
|
||||
loc, op->getResultTypes(), op->getOperands());
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
@ -237,6 +240,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
target.addIllegalOp<AtenBernoulli_FloatOp>();
|
||||
target.addIllegalOp<AtenBernoulli_TensorOp>();
|
||||
target.addIllegalOp<AtenFill_ScalarOp>();
|
||||
target.addIllegalOp<Aten_IndexPutImpl_Op>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||
auto hasValueSemantics = [](Type t) {
|
||||
|
|
|
@ -516,7 +516,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
|
||||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
|
||||
AtenIndexTensorOp>(op)) {
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
|
|
|
@ -1801,6 +1801,10 @@ module {
|
|||
func @"__torch_mlir_shape_fn.aten.bernoulli.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.any) -> !torch.list<int> {
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.index_put_impl"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
|
|
|
@ -606,6 +606,10 @@ def aten〇bernoulli〇float(self: List[int], p: float = 0.5, generator: Any = N
|
|||
def aten〇bernoulli〇Tensor(self: List[int], p: List[int], generator: Any = None) -> List[int]:
|
||||
return self
|
||||
|
||||
@not_present_in_registry
|
||||
def aten〇index_put_impl(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False, unsafe: bool = False) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
|
||||
return self
|
||||
|
||||
|
|
|
@ -401,6 +401,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
||||
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::item : (Tensor) -> (Scalar)")
|
||||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::numel : (Tensor) -> (int)")
|
||||
|
|
|
@ -176,3 +176,25 @@ func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
|
|||
%ret = torch.aten.fill_.Scalar %t, %value : !torch.tensor, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten._index_put_impl_(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor, %[[INDEX:.*]]: !torch.tensor, %[[VALUES:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.tensor) -> !torch.list<optional<tensor>>
|
||||
// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor
|
||||
// CHECK: %[[INDEX_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDEX]] : !torch.vtensor
|
||||
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VALUES_VTENSOR:.*]] = torch.copy.to_vtensor %[[VALUES]] : !torch.vtensor
|
||||
// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list<vtensor>, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor
|
||||
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
|
||||
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
|
||||
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor
|
||||
// CHECK: return %[[SELF:.*]] : !torch.tensor
|
||||
func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor {
|
||||
%true = torch.constant.bool true
|
||||
%false = torch.constant.bool false
|
||||
%indicesList = torch.prim.ListConstruct %index : (!torch.tensor) -> !torch.list<optional<tensor>>
|
||||
%ret = torch.aten._index_put_impl_ %self, %indicesList, %values, %true, %false : !torch.tensor, !torch.list<optional<tensor>>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue