[LINALG] Broadcast `values` to shape of slice in `index_put` (#3487)

The `index_put` operation, `input[indices] = values`, allows `values` to have
any shape broadcastable to the shape of the slice `input[indices]`. This
commit adds broadcasting support to the Linalg lowering of
`IndexPutHackedTwinOp`.

Fixes: #3465
Ramiro Leal-Cavazos 2024-06-26 09:59:49 +01:00 committed by GitHub
parent d2bc70f188
commit e29191bd08
3 changed files with 85 additions and 14 deletions
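
For illustration only (not part of this commit), the eager PyTorch semantics being matched:

import torch

x = torch.zeros(2, 3, 4)
idx = torch.tensor([0, 1])
# The slice x[idx] has shape (2, 3, 4); values of any broadcastable
# shape, e.g. (1, 4), are expanded to fill it.
x[idx] = torch.ones(1, 4)
assert torch.equal(x, torch.ones(2, 3, 4))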


@@ -541,19 +541,9 @@ public:
 namespace {
-Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
-                        OpBuilder b) {
-  llvm::SmallVector<Value> indices(indicesRef);
-  // Declare commonly used constants up front:
-  Value torchCstZero =
-      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
-  Value torchCstOne =
-      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
-  Value torchCstNegOne =
-      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(-1));
-
-  // Determine the broadcast sizes and materialize missing implicit end
-  // dimensions:
+// Determine the common broadcast shape of all the index tensors.
+std::pair<llvm::SmallVector<Value>, llvm::SmallVector<int64_t>>
+getBroadcastShape(Location loc, llvm::ArrayRef<Value> indices, OpBuilder b) {
   int64_t indicesRank = 0;
   for (auto index : indices) {
     auto indexTy = cast<Torch::ValueTensorType>(index.getType());
@@ -567,6 +557,8 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
     return std::max(dim0, dim1);
   };
+  Value torchCstOne =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
   llvm::SmallVector<Value> broadcastSizes(indicesRank, torchCstOne);
   llvm::SmallVector<int64_t> broadcastShape(indicesRank, 0);
   for (auto index : indices) {
@@ -585,6 +577,21 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
       broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
     }
   }
+  return std::make_pair(broadcastSizes, broadcastShape);
+}
+
+Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
+                        OpBuilder b) {
+  llvm::SmallVector<Value> indices(indicesRef);
+  // Declare commonly used constants up front:
+  Value torchCstZero =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
+  Value torchCstOne =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
+  Value torchCstNegOne =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(-1));
+
+  auto [broadcastSizes, broadcastShape] = getBroadcastShape(loc, indicesRef, b);
 
   auto mulDim = [](int64_t dim0, int64_t dim1) {
     if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
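
A rough Python analog of the getBroadcastShape helper factored out above (a sketch only; the real code additionally threads Torch::kUnknownSize through maxDim to handle dynamic dimensions):

def get_broadcast_shape(shapes):
    # Index tensor shapes broadcast right-aligned, as in NumPy/PyTorch.
    rank = max(len(s) for s in shapes)
    out = [1] * rank
    for s in shapes:
        for i, size in enumerate(s):
            j = rank - len(s) + i
            out[j] = max(out[j], size)
    return out

assert get_broadcast_shape([(6, 1), (7,)]) == [6, 7]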
@@ -733,6 +740,34 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
   return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
 }
 
+// Broadcast the `values` tensor to the slice size created by the list of index
+// tensors.
+static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
+                                        llvm::ArrayRef<Value> indices,
+                                        OpBuilder b) {
+  auto inputType = cast<ValueTensorType>(input.getType());
+  ArrayRef<int64_t> inputStaticShape = inputType.getSizes();
+  auto valuesType = cast<ValueTensorType>(values.getType());
+
+  // In the case where the input rank is greater than the number of index
+  // tensors, the remaining dimensions of the input are indexed in their
+  // entirety. Thus, we need to append the remaining dimensions to get the shape
+  // of the indexed slice.
+  auto [resultShape, resultStaticShape] = getBroadcastShape(loc, indices, b);
+  for (size_t i = indices.size(); i < inputStaticShape.size(); i++) {
+    Value dim = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
+    resultShape.push_back(b.create<AtenSizeIntOp>(loc, input, dim));
+    resultStaticShape.push_back(inputStaticShape[i]);
+  }
+
+  auto resultType = b.getType<Torch::ValueTensorType>(
+      resultStaticShape, valuesType.getOptionalDtype());
+  Value broadcastShapeList = b.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(b.getType<Torch::IntType>()), resultShape);
+  return b.create<AtenBroadcastToOp>(loc, resultType, values,
+                                     broadcastShapeList);
+}
+
 class ConvertAtenIndexPutHackedTwinOp
     : public OpConversionPattern<AtenIndexPutHackedTwinOp> {
 public:
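
In eager terms, broadcastValuesToSliceSize computes the shape of the indexed slice and expands `values` to it. A minimal sketch, assuming the index tensors apply to the leading input dimensions:

import torch

def broadcast_values_to_slice_size(input, values, indices):
    # Broadcast shape of the index tensors, followed by the trailing
    # input dimensions not covered by an index tensor (indexed in full).
    batch = torch.broadcast_shapes(*(i.shape for i in indices))
    return values.broadcast_to(tuple(batch) + tuple(input.shape[len(indices):]))

input = torch.zeros(2, 3, 4, 5)
indices = [torch.zeros(6, 1, dtype=torch.int64),
           torch.zeros(7, dtype=torch.int64)]
values = torch.ones(7, 4, 5)
assert broadcast_values_to_slice_size(input, values, indices).shape == (6, 7, 4, 5)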
@@ -780,6 +815,8 @@ public:
     if (optionalIndicesCount == 0)
       return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");
 
+    values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList,
+                                        rewriter);
     // Filter to available indices and get the indicesMap:
     SmallVector<Value> indicesList;
     SmallVector<int64_t> indicesMap;


@@ -1494,7 +1494,7 @@ STABLEHLO_PASS_SET = {
     "RenormModuleFloat32_basic",
 }
 
-STABLEHLO_CRASHING_SET = set()
+STABLEHLO_CRASHING_SET = {"IndexPutWithNoneAndBroadcastModule_basic"}
 
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
@@ -2427,6 +2427,7 @@ ONNX_XFAIL_SET = {
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
+    "IndexPutWithNoneAndBroadcastModule_basic",
     "IntFloatModule_basic",
     "IntImplicitModule_basic",
     "IouOfModule_basic",


@@ -1269,3 +1269,36 @@ def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils):
         tu.randint(7, high=5),
         tu.rand(2, 3, 6, 7),
     )
+
+
+# ==============================================================================
+
+
+class IndexPutWithNoneAndBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 4, 5], torch.float32, True),
+            ([6, 1], torch.int64, True),
+            ([7], torch.int64, True),
+            ([1, 6, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, input, index1, index2, value):
+        return torch.ops.aten.index_put(
+            input, (None, None, index1, index2), value, accumulate=True
+        )
+
+
+@register_test_case(module_factory=lambda: IndexPutWithNoneAndBroadcastModule())
+def IndexPutWithNoneAndBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 3, 4, 5),
+        tu.randint(6, 1, high=4),
+        tu.randint(7, high=5),
+        tu.rand(1, 6, 7),  # broadcasted to (2, 3, 6, 7)
+    )
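
As a plain-PyTorch sanity check of what this test expects (illustrative, not part of the commit): broadcasting `value` should match pre-expanding it to the slice shape (2, 3, 6, 7).

import torch

input = torch.rand(2, 3, 4, 5)
index1 = torch.randint(4, (6, 1))
index2 = torch.randint(5, (7,))
value = torch.rand(1, 6, 7)
out = torch.ops.aten.index_put(
    input, (None, None, index1, index2), value, accumulate=True
)
ref = torch.ops.aten.index_put(
    input, (None, None, index1, index2), value.expand(2, 3, 6, 7), accumulate=True
)
assert torch.allclose(out, ref)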