[MLIR][TORCH] Add value tensor variant to aten::_index_put_impl_

This commit adds the op `ValsemVariantAtenIndexPutImplOp`, a variant of
`Aten_IndexPutImpl_Op` without the trailing underscore. It is needed so
that the `ReduceOpVariants` pass can turn the in-place op into an op that
takes value tensors as inputs; without it, the `MaximizeValueSemantics`
pass is unable to add value semantics correctly.
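
For intuition, the value-semantic form corresponds to the out-of-place `Tensor.index_put` in eager PyTorch; a minimal sketch (illustrative only, not part of this patch):

import torch

x = torch.zeros(5)
idx = torch.tensor([1, 3])
vals = torch.tensor([7.0, 9.0])

# In-place form: mutates x and returns it.
torch.ops.aten._index_put_impl_(x, (idx,), vals, accumulate=False, unsafe=False)

# Value-semantic form: returns a fresh tensor; the input is left untouched.
y = torch.zeros(5).index_put((idx,), vals, accumulate=False)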

This commit also adds the lowering of the `ValsemVariantAtenIndexPutImplOp` op to `tm_tensor.scatter`.

This commit also updates the `torch.bincount` op test cases.
pull/679/head
Vivek Khandelwal 2022-03-09 18:35:13 +05:30
parent 8a4388ea7b
commit 3d95c3d6c9
12 changed files with 320 additions and 5 deletions


@@ -1417,7 +1417,7 @@ class BincountModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BincountModule())
def BincountModule_basic(module, tu: TestUtils):
-module.forward(torch.randint(100, (10,)))
+module.forward(torch.randint(10, (1000,)))
class BincountStaticSizeModule(torch.nn.Module):
@@ -1427,14 +1427,14 @@ class BincountStaticSizeModule(torch.nn.Module):
@export
@annotate_args([
None,
-([20], torch.int64, True),
+([200], torch.int64, True),
])
def forward(self, x):
return torch.ops.aten.bincount(x)
@register_test_case(module_factory=lambda: BincountStaticSizeModule())
def BincountStaticSizeModule_basic(module, tu: TestUtils):
-module.forward(torch.randint(1000, (20,)))
+module.forward(torch.randint(100, (200,)))
class BincountMinlengthModule(torch.nn.Module):
@@ -1451,4 +1451,4 @@ class BincountMinlengthModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BincountMinlengthModule())
def BincountMinlengthModule_basic(module, tu: TestUtils):
-module.forward(torch.randint(500, (20,)))
+module.forward(torch.randint(5, (20,)))
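
As a reminder of the semantics these tests exercise: `torch.bincount` returns a 1-D tensor of length `max(input) + 1`, where entry `i` counts the occurrences of `i` in the input. For example:

import torch

x = torch.tensor([1, 3, 3, 0])
torch.bincount(x)  # tensor([1, 1, 0, 2]): one 0, one 1, no 2s, two 3s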


@@ -0,0 +1,112 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index,), value,
accumulate=False,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule())
def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(100), torch.randint(100, (250,)),
tu.rand(250))
class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index,), value,
accumulate=False,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule())
def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
torch.randint(10000, (300,)))
class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.float32, True),
])
def forward(self, input, index, value):
# Since the input is updated in-place, we pass input.clone() instead of
# the input itself to avoid wrong results.
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
accumulate=True,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule())
def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1000), torch.randint(10, (500,)),
tu.rand(500))
class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, input, index, value):
# Since the input is updated in-place, we pass input.clone() instead of
# the input itself to avoid wrong results.
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
accumulate=True,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule())
def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
torch.randint(1000, (10,)))
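
The accumulate tests above draw indices with replacement, so duplicates are likely; with accumulate=True, duplicate writes add up instead of overwriting each other. A small illustration in eager PyTorch (not part of this patch):

import torch

x = torch.zeros(4)
idx = torch.tensor([2, 2])
vals = torch.tensor([1.0, 5.0])

# accumulate=False: with duplicate indices, one write wins (order unspecified).
torch.ops.aten._index_put_impl_(x.clone(), (idx,), vals, accumulate=False, unsafe=False)
# e.g. tensor([0., 0., 5., 0.])

# accumulate=True: contributions at duplicate indices are summed.
torch.ops.aten._index_put_impl_(x.clone(), (idx,), vals, accumulate=True, unsafe=False)
# tensor([0., 0., 6., 0.])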


@@ -54,6 +54,7 @@ from . import histogram_binning_calibration
from . import table_batch_embedding
from . import rng
from . import cast
from . import index_put
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']


@@ -2921,6 +2921,23 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
let assemblyFormat = "$self `,` $dim `,` $index attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($index)) `->` qualified(type($result))";
}
def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorListType:$indices,
AnyTorchTensorType:$values,
Torch_BoolType:$accumulate,
Torch_BoolType:$unsafe
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))";
}
def Torch_AtenItemOp : Torch_Op<"aten.item", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -1018,6 +1018,27 @@ def Torch_ValsemVariantAtenFillScalarOp: Torch_Op<"valsem.aten.fill.Scalar", [
let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))";
}
// The variant of `torch.aten._index_put_impl_` without the trailing underscore
// doesn't exist in the PyTorch ops registry. Add it here.
def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorListType:$indices,
AnyTorchTensorType:$values,
Torch_BoolType:$accumulate,
Torch_BoolType:$unsafe
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $indices `,` $values `,` $accumulate `,` $unsafe attr-dict `:` qualified(type($self)) `,` qualified(type($indices)) `,` qualified(type($values)) `,` qualified(type($accumulate)) `,` qualified(type($unsafe)) `->` qualified(type($result))";
}
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
// But TS compiler introduces control flow for `torch._assert` operation. The
// `torch._assert` would introduce control flow like:


@@ -176,6 +176,132 @@ public:
};
} // namespace
namespace {
class ConvertValsemVariantAtenIndexPutImplOp
: public OpConversionPattern<ValsemVariantAtenIndexPutImplOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ValsemVariantAtenIndexPutImplOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.self();
Value values = adaptor.values();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>()
.getElementType();
// TODO: Add support for the input with rank other than one.
if (inputType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: input rank other than one is not supported");
// `unsafe` should be either `False` or `None`.
if (!op.unsafe().getType().isa<Torch::NoneType>()) {
bool unsafe;
if (!matchPattern(op.unsafe(), m_TorchConstantBool(&unsafe)))
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe must be a constant");
else if (unsafe)
return rewriter.notifyMatchFailure(
op, "unimplemented: unsafe is expected to be false");
}
// `accumulate` should be a torch constant of boolean type.
bool accumulate;
if (!matchPattern(op.accumulate(), m_TorchConstantBool(&accumulate)))
return rewriter.notifyMatchFailure(
op, "Expected accumulate to be constant bool.");
// The element type of `input` and `values` should be the same.
if (inputType.getElementType() != valuesType.getElementType())
return rewriter.notifyMatchFailure(
op, "Input element type should be same as the values element type.");
SmallVector<Value> indicesList;
getListConstructElements(adaptor.indices(), indicesList);
// The size of the list of the index tensors should not be greater than the
// input rank.
if ((int64_t)indicesList.size() > inputType.getRank())
return rewriter.notifyMatchFailure(
op, "Indices list size should not be greater than the input rank.");
// TODO: Add support for cases with indices list size smaller than the input
// rank.
if ((int64_t)indicesList.size() < inputType.getRank())
return rewriter.notifyMatchFailure(
op, "Unimplemented, Indices list size smaller than input rank");
if (indicesList[0].getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op,
"Indices tensor must not be none.");
// TODO: Add support for the index with rank other than one.
int64_t indexRank = typeConverter->convertType(indicesList[0].getType())
.cast<RankedTensorType>()
.getRank();
if (indexRank != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: index rank other than one is not supported");
// Creating a tm_tensor.scatter op with the following mapping:
// 1.) Index tensor from the `indicesList` maps to the indices in scatter
// op. Index tensor is expanded from 1-d to 2-d, and its element type is set
// to i32 as required for the scatter op.
// 2.) `values` is mapped to `updates` in scatter op.
// 3.) `input` is mapped to `original` in scatter op.
ValueTensorType indexType =
indicesList[0].getType().cast<ValueTensorType>();
SmallVector<int64_t> expandedIndexSizes{indexType.getSizes()[0], 1};
ValueTensorType expandedIndexType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
loc, expandedIndexType, indicesList[0], torchCstOne);
// Converting the index element type to i32.
Value indices = convertTensorToDtype(
rewriter, loc, expandedIndexTensor,
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
loc, input.getType(), ValueRange{values, indices}, ValueRange{input},
/*unique_indices=*/false);
Region &scatterOpRegion = scatterOp.region();
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
{loc, loc});
auto blockArgs = scatterOpBlock.getArguments();
OpBuilder regionBuilder(scatterOpRegion);
Value update = blockArgs[0];
Value original = blockArgs[1];
Value yieldValue = update;
// Create an add instruction inside the scatter op region to increment the
// `original` value with the value from `updates` if the accumulate flag is
// true.
if (accumulate) {
if (inputType.getElementType().isa<mlir::IntegerType>())
yieldValue = regionBuilder.create<arith::AddIOp>(loc, original, update);
else if (inputType.getElementType().isa<mlir::FloatType>())
yieldValue = regionBuilder.create<arith::AddFOp>(loc, original, update);
}
regionBuilder.create<TMTensor::YieldOp>(loc, yieldValue);
rewriter.replaceOp(op, scatterOp->getResult(0));
return success();
}
};
} // namespace
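
In eager terms, the generated tm_tensor.scatter computes the following over the 1-d input; a minimal Python sketch of the semantics (names are illustrative, not torch-mlir APIs):

def scatter_1d(original, indices, updates, accumulate):
    # original: (N,) tensor; indices and updates: (M,), with updates matching
    # original's element type.
    result = original.clone()
    for i in range(len(indices)):
        j = int(indices[i])
        # The scatter op's region yields `update`, or `original + update`
        # when the accumulate flag is set.
        result[j] = result[j] + updates[i] if accumulate else updates[i]
    return result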
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -207,6 +333,9 @@ public:
RewritePatternSet patterns(context);
target.addIllegalOp<AtenBincountOp>();
patterns.add<ConvertAtenBincountOp>(typeConverter, context);
target.addIllegalOp<ValsemVariantAtenIndexPutImplOp>();
patterns.add<ConvertValsemVariantAtenIndexPutImplOp>(typeConverter,
context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))


@@ -158,6 +158,9 @@ public:
} else if (isa<AtenFill_ScalarOp>(op)) {
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<Aten_IndexPutImpl_Op>(op)) {
newOp = rewriter.create<ValsemVariantAtenIndexPutImplOp>(
loc, op->getResultTypes(), op->getOperands());
} else {
return failure();
}
@@ -237,6 +240,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
target.addIllegalOp<AtenBernoulli_FloatOp>();
target.addIllegalOp<AtenBernoulli_TensorOp>();
target.addIllegalOp<AtenFill_ScalarOp>();
target.addIllegalOp<Aten_IndexPutImpl_Op>();
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
auto hasValueSemantics = [](Type t) {


@@ -516,7 +516,7 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
-AtenIndexTensorOp>(op)) {
+AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = operands[0]->getValue().dtype;


@@ -1801,6 +1801,10 @@ module {
func @"__torch_mlir_shape_fn.aten.bernoulli.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.any) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.index_put_impl"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}


@@ -606,6 +606,10 @@ def aten〇bernoulli〇float(self: List[int], p: float = 0.5, generator: Any = N
def aten〇bernoulli〇Tensor(self: List[int], p: List[int], generator: Any = None) -> List[int]:
return self
@not_present_in_registry
def aten〇index_put_impl(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False, unsafe: bool = False) -> List[int]:
return upstream_shape_helpers.unary(self)
def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
return self
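
For reference, the unary shape helper used above simply propagates the input's shape to the result; roughly (a sketch of the upstream helper's behavior, not its exact source):

from typing import List

def unary(self: List[int]) -> List[int]:
    # The result shape equals the input shape; indices/values don't affect it.
    return [dim for dim in self]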


@@ -401,6 +401,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)")
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)")


@@ -176,3 +176,25 @@ func @torch.aten.fill_.Scalar(%t: !torch.tensor) -> !torch.tensor {
%ret = torch.aten.fill_.Scalar %t, %value : !torch.tensor, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten._index_put_impl_(
// CHECK-SAME: %[[SELF:.*]]: !torch.tensor, %[[INDEX:.*]]: !torch.tensor, %[[VALUES:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INDICES_OPTIONAL_LIST:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.tensor) -> !torch.list<optional<tensor>>
// CHECK: %[[SELF_VTENSOR:.*]] = torch.copy.to_vtensor %[[SELF]] : !torch.vtensor
// CHECK: %[[INDEX_VTENSOR:.*]] = torch.copy.to_vtensor %[[INDEX]] : !torch.vtensor
// CHECK: %[[INDICES_LIST:.*]] = torch.prim.ListConstruct %[[INDEX_VTENSOR]] : (!torch.vtensor) -> !torch.list<vtensor>
// CHECK: %[[VALUES_VTENSOR:.*]] = torch.copy.to_vtensor %[[VALUES]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.valsem.aten.index_put_impl %[[SELF_VTENSOR]], %[[INDICES_LIST]], %[[VALUES_VTENSOR]], %[[TRUE]], %[[FALSE]] : !torch.vtensor, !torch.list<vtensor>, !torch.vtensor, !torch.bool, !torch.bool -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor.contents %[[COPY_VTENSOR]] overwrites %[[SELF]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[SELF]] : !torch.tensor
func @torch.aten._index_put_impl_(%self: !torch.tensor, %index: !torch.tensor, %values: !torch.tensor) -> !torch.tensor {
%true = torch.constant.bool true
%false = torch.constant.bool false
%indicesList = torch.prim.ListConstruct %index : (!torch.tensor) -> !torch.list<optional<tensor>>
%ret = torch.aten._index_put_impl_ %self, %indicesList, %values, %true, %false : !torch.tensor, !torch.list<optional<tensor>>, !torch.tensor, !torch.bool, !torch.bool -> !torch.tensor
return %ret : !torch.tensor
}