Add support for multi-dim input to `index_put_impl` (#722)

This commit adds support for multi-dimensional tensors as input to the
`_index_put_impl_` op. The support was to some degree already there,
since `ScatterOp` already supports multi-dimensional tensors. This
commit also adds a bit more error checking to `index_put` and
refactors the code for creating `ScatterOp`s to mimic the way one
would make a `Linalg::GenericOp`.
pull/726/head snapshot-20220331.361
Ramiro Leal-Cavazos 2022-03-31 09:27:21 -07:00 committed by GitHub
parent ccf924d3df
commit 51d4d55f8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 213 additions and 113 deletions

View File

@ -35,6 +35,9 @@ Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
// Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
Type dtype);
// Helper funtion to get rank of `Base tensor type`.
// -1 is returned if the tensorRank can't be determined.
int getTensorRank(Value tensor);
} // namespace Torch
} // namespace torch

View File

@ -48,6 +48,28 @@ using namespace mlir::torch::TMTensor;
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".
static Value createTMTensorScatterOp(
OpBuilder &b, Location loc, Value updates, Value indices, Value original,
bool uniqueIndices,
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
auto originalTensorType = original.getType().cast<RankedTensorType>();
Type originalElementType = originalTensorType.getElementType();
auto scatterOp = b.create<TMTensor::ScatterOp>(
loc, originalTensorType, ValueRange{updates, indices},
ValueRange{original}, uniqueIndices);
Region &scatterOpRegion = scatterOp.region();
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
scatterOpBlock.addArguments({originalElementType, originalElementType},
{loc, loc});
OpBuilder regionBuilder(scatterOpRegion);
auto blockArgs = scatterOpBlock.getArguments();
Value updatesElement = blockArgs[0];
Value originalElement = blockArgs[1];
bodyBuild(regionBuilder, loc, updatesElement, originalElement);
return scatterOp->getResult(0);
}
namespace {
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
// non-negative ints.
@ -88,7 +110,7 @@ public:
// TODO: Incorporate the weight argument.
if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Unimplemented, the weights operand is not incorporated.");
op, "Unimplemented: the weights operand is not incorporated.");
// Finding the maximum value in the input tensor.
SmallVector<int64_t> maxTensorSizes;
@ -129,9 +151,9 @@ public:
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>()
.getElementType();
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type resultElemType = resultType.getElementType();
SmallVector<Value, 1> inputSizeDynamic =
getTensorSizesUntilDim(rewriter, loc, input, 0);
@ -152,25 +174,14 @@ public:
Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
resultElemType, constantZero);
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
loc, bincountTensor.getType(), ValueRange{updatesTensor, indices},
ValueRange{bincountTensor},
/*unique_indices=*/false);
Region &scatterOpRegion = scatterOp.region();
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
{loc, loc});
auto blockArgs = scatterOpBlock.getArguments();
// Creating an add instruction inside the scatter op region to increment the
// frequency counter with one.
OpBuilder regionBuilder(scatterOpRegion);
Value add = regionBuilder.create<arith::AddIOp>(loc,
/*bincount=*/blockArgs[1],
constantOne);
regionBuilder.create<TMTensor::YieldOp>(loc, add);
rewriter.replaceOp(op, scatterOp->getResult(0));
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, updatesTensor, indices, bincountTensor,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
b.create<TMTensor::YieldOp>(loc, add);
});
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
}
};
@ -192,14 +203,8 @@ public:
Value values = adaptor.values();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>()
.getElementType();
// TODO: Add support for the input with rank other than one.
if (inputType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: input rank other than one is not supported");
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
// The unsafe should be either `False` or `none`.
if (!op.unsafe().getType().isa<Torch::NoneType>()) {
@ -231,23 +236,14 @@ public:
return rewriter.notifyMatchFailure(
op, "Indices list size should not be greater than the input rank.");
// TODO: Add support for cases with indices list size smaller than the input
// rank.
if ((int64_t)indicesList.size() < inputType.getRank())
// TODO: Add support for cases with indices list size not equal to 1.
if (indicesList.size() != 1)
return rewriter.notifyMatchFailure(
op, "Unimplemented, Indices list size smaller than input rank");
op, "Unimplemented: Indices list size != 1");
Value indexTensor = indicesList[0];
if (indicesList[0].getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op,
"Indices tensor must not be none.");
// TODO: Add support for the index with rank other than one.
int64_t indexRank = typeConverter->convertType(indicesList[0].getType())
.cast<RankedTensorType>()
.getRank();
if (indexRank != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: index rank other than one is not supported");
if (indexTensor.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");
// Creating a tm_tensor.scatter op with the following mapping:
// 1.) Index tensor from the `indicesList` maps to the indices in scatter
@ -255,48 +251,55 @@ public:
// to i32 as required for the scatter op.
// 2.) `values` is mapped to `updates` in scatter op.
// 3.) `input` is mapped to `original` in scatter op.
ValueTensorType indexType =
indicesList[0].getType().cast<ValueTensorType>();
SmallVector<int64_t> expandedIndexSizes{indexType.getSizes()[0], 1};
ValueTensorType expandedIndexType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype());
if (getTensorRank(indexTensor) != 1)
return rewriter.notifyMatchFailure(
op, "unimplemented: index tensor with rank != 1 is not supported");
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
int64_t indexTensorSize = indexTensorType.getSizes()[0];
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
ValueTensorType expandedIndexTensorType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIndexTensorSizes),
indexTensorType.getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
loc, expandedIndexType, indicesList[0], torchCstOne);
loc, expandedIndexTensorType, indexTensor, torchCstOne);
// Converting the index element type to i32.
// `TMTensor::ScatterOp` expects indices of element type i32.
Value indices = convertTensorToDtype(
rewriter, loc, expandedIndexTensor,
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
indices = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
loc, input.getType(), ValueRange{values, indices}, ValueRange{input},
/*unique_indices=*/false);
bool invalidInputTypeFound = false;
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, values, indices, input, /*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value valuesElement,
Value inputElement) {
Value yieldValue = valuesElement;
if (accumulate) {
if (inputElement.getType().isa<mlir::IntegerType>()) {
yieldValue =
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
} else if (inputElement.getType().isa<mlir::FloatType>()) {
yieldValue =
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
} else {
invalidInputTypeFound = true;
return;
}
}
b.create<TMTensor::YieldOp>(loc, yieldValue);
});
Region &scatterOpRegion = scatterOp.region();
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
{loc, loc});
auto blockArgs = scatterOpBlock.getArguments();
OpBuilder regionBuilder(scatterOpRegion);
Value update = blockArgs[0];
Value original = blockArgs[1];
Value yieldValue = update;
// Create an add instruction inside the scatter op region to increment the
// `original` value with the value from `updates` if the accumulate flag is
// true.
if (accumulate) {
if (inputType.getElementType().isa<mlir::IntegerType>())
yieldValue = regionBuilder.create<arith::AddIOp>(loc, original, update);
else if (inputType.getElementType().isa<mlir::FloatType>())
yieldValue = regionBuilder.create<arith::AddFOp>(loc, original, update);
if (invalidInputTypeFound) {
return rewriter.notifyMatchFailure(
op,
"unimplemented: input tensor must be of integer type or float type");
}
regionBuilder.create<TMTensor::YieldOp>(loc, yieldValue);
rewriter.replaceOp(op, scatterOp->getResult(0));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
}
};

View File

@ -21,19 +21,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Helper funtion to get rank of `Base tensor type`.
// -1 is returned if the tensorRank can't be determined.
static int getTensorRank(Value tensor) {
int tensorRank = -1;
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (tensorType.hasSizes()) {
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
tensorRank = tensorShape.size();
}
return tensorRank;
}
// Helper function to compute the return type of the reduction function.
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
// the input tensor.

View File

@ -95,3 +95,14 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
return converted;
}
int Torch::getTensorRank(Value tensor) {
int tensorRank = -1;
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (tensorType.hasSizes()) {
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
tensorRank = tensorShape.size();
}
return tensorRank;
}

View File

@ -11,7 +11,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -29,13 +29,61 @@ class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule())
def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule())
def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(100), torch.randint(100, (250,)),
tu.rand(250))
class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1, -1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index,), value,
accumulate=False,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule())
def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
tu.rand(5, 8))
class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index,), value,
accumulate=False,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
tu.rand(5, 8, 6))
class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -53,13 +101,13 @@ class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule())
def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule())
def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
torch.randint(10000, (300,)))
class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -79,13 +127,61 @@ class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule())
def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPutImpl1DFloatAccumulateModule())
def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1000), torch.randint(10, (500,)),
tu.rand(500))
class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1, -1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
accumulate=True,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImpl2DFloatAccumulateModule())
def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
tu.rand(5, 8))
class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
accumulate=True,
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
tu.rand(5, 8, 6))
class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -105,14 +201,14 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
unsafe=False)
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule())
def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
torch.randint(1000, (10,)))
# ==============================================================================
class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
class IndexPut1DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -128,13 +224,13 @@ class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPut1DFloatNonAccumulateModule())
def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(100), torch.randint(100, (250,)),
tu.rand(250))
class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
class IndexPut1DIntNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -150,13 +246,13 @@ class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule())
def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
torch.randint(10000, (300,)))
class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
class IndexPut1DFloatAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -172,13 +268,13 @@ class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule())
def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1000), torch.randint(10, (500,)),
tu.rand(500))
class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
class IndexPut1DIntAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -194,7 +290,7 @@ class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule())
def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
torch.randint(1000, (10,)))