From 3d95c3d6c90b892fa8064f5a7244d4f600f5ae60 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 9 Mar 2022 18:35:13 +0530 Subject: [PATCH] [MLIR][TORCH] Add value tensor variant to aten::_index_put_impl_ This commit adds the op `ValsemVariantAtenIndexPutImplOp` that represents `Aten_IndexPutImpl_Op` without the underscore. This is needed to make sure that the `ReduceOpVariants` pass turns the in-place op into an op that takes value tensors as inputs, otherwise the `MaximizeValueSemantics` pass will not be able to add value semantics correctly. This commit also adds the lowering of `ValsemVariantAtenIndexPutImplOp` op. This commit also updates the `torch.bincount` op test cases. --- e2e_testing/torchscript/basic.py | 8 +- e2e_testing/torchscript/index_put.py | 112 +++++++++++++++ e2e_testing/torchscript/main.py | 1 + .../Dialect/Torch/IR/GeneratedAtenOps.td | 17 +++ .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 21 +++ .../TorchToTMTensor/TorchToTMTensor.cpp | 129 ++++++++++++++++++ .../Torch/Transforms/ReduceOpVariants.cpp | 4 + lib/Dialect/Torch/Transforms/RefineTypes.cpp | 2 +- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 4 + .../jit_ir/build_tools/shape_lib_gen.py | 4 + .../jit_ir/build_tools/torch_ods_gen.py | 1 + test/Dialect/Torch/reduce-op-variants.mlir | 22 +++ 12 files changed, 320 insertions(+), 5 deletions(-) create mode 100644 e2e_testing/torchscript/index_put.py diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index 931203563..2d166ffbc 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -1417,7 +1417,7 @@ class BincountModule(torch.nn.Module): @register_test_case(module_factory=lambda: BincountModule()) def BincountModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10,))) + module.forward(torch.randint(10, (1000,))) class BincountStaticSizeModule(torch.nn.Module): @@ -1427,14 +1427,14 @@ class BincountStaticSizeModule(torch.nn.Module): @export @annotate_args([ None, - ([20], torch.int64, True), + ([200], torch.int64, True), ]) def forward(self, x): return torch.ops.aten.bincount(x) @register_test_case(module_factory=lambda: BincountStaticSizeModule()) def BincountStaticSizeModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (20,))) + module.forward(torch.randint(100, (200,))) class BincountMinlengthModule(torch.nn.Module): @@ -1451,4 +1451,4 @@ class BincountMinlengthModule(torch.nn.Module): @register_test_case(module_factory=lambda: BincountMinlengthModule()) def BincountMinlengthModule_basic(module, tu: TestUtils): - module.forward(torch.randint(500, (20,))) + module.forward(torch.randint(5, (20,))) diff --git a/e2e_testing/torchscript/index_put.py b/e2e_testing/torchscript/index_put.py new file mode 100644 index 000000000..f92a13f14 --- /dev/null +++ b/e2e_testing/torchscript/index_put.py @@ -0,0 +1,112 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.torchscript.framework import TestUtils +from torch_mlir_e2e_test.torchscript.registry import register_test_case +from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export + +# ============================================================================== + + +class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ]) + def forward(self, input, index, value): + return torch.ops.aten._index_put_impl_(input, (index,), value, + accumulate=False, + unsafe=False) + + +@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule()) +def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils): + module.forward(tu.rand(100), torch.randint(100, (250,)), + tu.rand(250)) + + +class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, input, index, value): + return torch.ops.aten._index_put_impl_(input, (index,), value, + accumulate=False, + unsafe=False) + + +@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule()) +def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils): + module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)), + torch.randint(10000, (300,))) + + +class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ]) + def forward(self, input, index, value): + # Since the input is updated in-place, we pass input.clone() in place + # of input to avoid wrong results. + return torch.ops.aten._index_put_impl_(input.clone(), (index,), value, + accumulate=True, + unsafe=False) + + +@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule()) +def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1000), torch.randint(10, (500,)), + tu.rand(500)) + + +class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, input, index, value): + # Since the input is updated in-place, we pass input.clone() in place + # of input to avoid wrong results. + return torch.ops.aten._index_put_impl_(input.clone(), (index,), value, + accumulate=True, + unsafe=False) + + +@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule()) +def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils): + module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)), + torch.randint(1000, (10,))) diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 0e3d73f8e..3baca1d23 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -54,6 +54,7 @@ from . import histogram_binning_calibration from . import table_batch_embedding from . import rng from . import cast +from . import index_put def _get_argparse(): config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index b0b5c6ebe..9aed8afa5 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -2921,6 +2921,23 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [ let assemblyFormat = "$self `,` $dim `,` $index attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($index)) `->` qualified(type($result))"; } +def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalTensorListType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate, + Torch_BoolType:$unsafe + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))"; +} + def Torch_AtenItemOp : Torch_Op<"aten.item", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index ed47fd5b3..2ec682434 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -1018,6 +1018,27 @@ def Torch_ValsemVariantAtenFillScalarOp: Torch_Op<"valsem.aten.fill.Scalar", [ let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))"; } +// The corresponding without underscore variant for `torch.aten._index_put_impl_` +// doesn't exist in the pytorch ops registry. Add it here. +def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { +let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalTensorListType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate, + Torch_BoolType:$unsafe + ); +let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))"; +} + // To handle runtime assertions, torchscript provides us `torch._assert` operation. // But TS compiler introduces control flow for `torch._assert` operation. The // `torch._assert` would introduce control flow like: diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index c31a1aea9..b33d314e7 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -176,6 +176,132 @@ public: }; } // namespace +namespace { +class ConvertValsemVariantAtenIndexPutImplOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ValsemVariantAtenIndexPutImplOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + Location loc = op.getLoc(); + MLIRContext *context = op->getContext(); + Value input = adaptor.self(); + Value values = adaptor.values(); + RankedTensorType inputType = input.getType().cast(); + RankedTensorType valuesType = values.getType().cast(); + Type resultElemType = typeConverter->convertType(op->getResult(0).getType()) + .cast() + .getElementType(); + + // TODO: Add support for the input with rank other than one. + if (inputType.getRank() != 1) + return rewriter.notifyMatchFailure( + op, "unimplemented: input rank other than one is not supported"); + + // The unsafe should be either `False` or `none`. + if (!op.unsafe().getType().isa()) { + bool unsafe; + if (!matchPattern(op.unsafe(), m_TorchConstantBool(&unsafe))) + return rewriter.notifyMatchFailure( + op, "unimplemented: unsafe must be a constant"); + else if (unsafe) + return rewriter.notifyMatchFailure( + op, "unimplemented: unsafe is expected to be false"); + } + + // The accumulate should be a torch constant of boolean type. + bool accumulate; + if (!matchPattern(op.accumulate(), m_TorchConstantBool(&accumulate))) + return rewriter.notifyMatchFailure( + op, "Expected accumulate to be constant bool."); + + // The element type of the `input` and `values` should be same. + if (inputType.getElementType() != valuesType.getElementType()) + return rewriter.notifyMatchFailure( + op, "Input element type should be same as the values element type."); + + SmallVector indicesList; + getListConstructElements(adaptor.indices(), indicesList); + // The size of the list of the index tensors should not be greater than the + // input rank. + if ((int64_t)indicesList.size() > inputType.getRank()) + return rewriter.notifyMatchFailure( + op, "Indices list size should not be greater than the input rank."); + + // TODO: Add support for cases with indices list size smaller than the input + // rank. + if ((int64_t)indicesList.size() < inputType.getRank()) + return rewriter.notifyMatchFailure( + op, "Unimplemented, Indices list size smaller than input rank"); + + if (indicesList[0].getType().isa()) + return rewriter.notifyMatchFailure(op, + "Indices tensor must not be none."); + + // TODO: Add support for the index with rank other than one. + int64_t indexRank = typeConverter->convertType(indicesList[0].getType()) + .cast() + .getRank(); + if (indexRank != 1) + return rewriter.notifyMatchFailure( + op, "unimplemented: index rank other than one is not supported"); + + // Creating a tm_tensor.scatter op with the following mapping: + // 1.) Index tensor from the `indicesList` maps to the indices in scatter + // op. Index tensor is expanded from 1-d to 2-d, and its element type is set + // to i32 as required for the scatter op. + // 2.) `values` is mapped to `updates` in scatter op. + // 3.) `input` is mapped to `original` in scatter op. + ValueTensorType indexType = + indicesList[0].getType().cast(); + SmallVector expandedIndexSizes{indexType.getSizes()[0], 1}; + ValueTensorType expandedIndexType = ValueTensorType::get( + context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype()); + Value torchCstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value expandedIndexTensor = rewriter.create( + loc, expandedIndexType, indicesList[0], torchCstOne); + + // Converting the index element type to i32. + Value indices = convertTensorToDtype( + rewriter, loc, expandedIndexTensor, + mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); + indices = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(indices.getType()), indices); + + auto scatterOp = rewriter.create( + loc, input.getType(), ValueRange{values, indices}, ValueRange{input}, + /*unique_indices=*/false); + + Region &scatterOpRegion = scatterOp.region(); + auto &scatterOpBlock = scatterOpRegion.emplaceBlock(); + scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType}, + {loc, loc}); + auto blockArgs = scatterOpBlock.getArguments(); + + OpBuilder regionBuilder(scatterOpRegion); + Value update = blockArgs[0]; + Value original = blockArgs[1]; + Value yieldValue = update; + // Create an add instruction inside the scatter op region to increment the + // `original` value with the value from `updates` if the accumulate flag is + // true. + if (accumulate) { + if (inputType.getElementType().isa()) + yieldValue = regionBuilder.create(loc, original, update); + else if (inputType.getElementType().isa()) + yieldValue = regionBuilder.create(loc, original, update); + } + regionBuilder.create(loc, yieldValue); + rewriter.replaceOp(op, scatterOp->getResult(0)); + return success(); + } +}; +} // namespace + // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- @@ -207,6 +333,9 @@ public: RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, + context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 73d3f00bf..5f99cc954 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -158,6 +158,9 @@ public: } else if (isa(op)) { newOp = rewriter.create( loc, op->getResultTypes(), op->getOperands()); + } else if (isa(op)) { + newOp = rewriter.create( + loc, op->getResultTypes(), op->getOperands()); } else { return failure(); } @@ -237,6 +240,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *op) { if (op->hasTrait()) { auto hasValueSemantics = [](Type t) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 8305f54bc..2f0771b49 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -516,7 +516,7 @@ ChangeResult TypeAnalyzer::visitOperation( AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, - AtenIndexTensorOp>(op)) { + AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp>(op)) { ValueKnowledge knowledge = ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); knowledge.dtype = operands[0]->getValue().dtype; diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 09be27407..40dd85f96 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -1801,6 +1801,10 @@ module { func @"__torch_mlir_shape_fn.aten.bernoulli.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.any) -> !torch.list { return %arg0 : !torch.list } + func @"__torch_mlir_shape_fn.aten.index_put_impl"(%arg0: !torch.list, %arg1: !torch.list>>, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list { + %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list, %arg1: !torch.any) -> !torch.list { return %arg0 : !torch.list } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 71a8c5e41..8634f18cc 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -606,6 +606,10 @@ def aten〇bernoulli〇float(self: List[int], p: float = 0.5, generator: Any = N def aten〇bernoulli〇Tensor(self: List[int], p: List[int], generator: Any = None) -> List[int]: return self +@not_present_in_registry +def aten〇index_put_impl(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False, unsafe: bool = False) -> List[int]: + return upstream_shape_helpers.unary(self) + def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]: return self diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index c76635339..d6909e3bb 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -401,6 +401,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)") emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") + emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)") emit("aten::item : (Tensor) -> (Scalar)") emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") emit("aten::numel : (Tensor) -> (int)") diff --git a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir index 30de404f1..febfdd209 100644 --- a/test/Dialect/Torch/reduce-op-variants.mlir +++ b/test/Dialect/Torch/reduce-op-variants.mlir @@ -176,3 +176,25 @@ func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor { %ret = torch.aten.fill_.Scalar %t, %value : !torch.tensor, !torch.int -> !torch.tensor return %ret : !torch.tensor } + +// CHECK-LABEL: func @torch.aten._index_put_impl_( +// CHECK-SAME: %[[SELF:.*]]: !torch.tensor, %[[INDEX:.*]]: !torch.tensor, %[[VALUES:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.tensor) -> !torch.list> +// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor +// CHECK: %[[INDEX_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDEX]] : !torch.vtensor +// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list +// CHECK: %[[VALUES_VTENSOR:.*]] = torch.copy.to_vtensor %[[VALUES]] : !torch.vtensor +// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor +// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor +// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor +// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor +// CHECK: return %[[SELF:.*]] : !torch.tensor +func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor { + %true = torch.constant.bool true + %false = torch.constant.bool false + %indicesList = torch.prim.ListConstruct %index : (!torch.tensor) -> !torch.list> + %ret = torch.aten._index_put_impl_ %self, %indicesList, %values, %true, %false : !torch.tensor, !torch.list>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor + return %ret : !torch.tensor +}