[LINALG] Broadcast `values` to shape of slice in `index_put` (#3487)
The `index_put` operation, `input[indices] = values`, allows `values` to be any shape that is broadcastable to the slice `input[indices]`. This commit adds broadcasting support to the Linalg lowering of `IndexPutHackedTwinOp`. Fixes: #3465
parent d2bc70f188
commit e29191bd08
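For reference, the semantics being lowered, shown in plain eager PyTorch (a minimal sketch, independent of this patch):

import torch

x = torch.zeros(3, 4)
idx = torch.tensor([0, 2])
# The slice x[idx] has shape (2, 4); the (1,)-shaped values tensor
# broadcasts across it, which is the behavior this commit adds to the lowering.
x[idx] = torch.tensor([7.0])
assert (x[0] == 7).all() and (x[2] == 7).all()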
@@ -541,19 +541,9 @@ public:
 namespace {
-Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
-                        OpBuilder b) {
-  llvm::SmallVector<Value> indices(indicesRef);
-  // Declare commonly used constants up front:
-  Value torchCstZero =
-      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
-  Value torchCstOne =
-      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
-  Value torchCstNegOne =
-      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(-1));
-
-  // Determine the broadcast sizes and materialize missing implicit end
-  // dimensions:
+// Determine the common broadcast shape of all the index tensors.
+std::pair<llvm::SmallVector<Value>, llvm::SmallVector<int64_t>>
+getBroadcastShape(Location loc, llvm::ArrayRef<Value> indices, OpBuilder b) {
   int64_t indicesRank = 0;
   for (auto index : indices) {
     auto indexTy = cast<Torch::ValueTensorType>(index.getType());
@@ -567,6 +557,8 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
     return std::max(dim0, dim1);
   };
 
+  Value torchCstOne =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
   llvm::SmallVector<Value> broadcastSizes(indicesRank, torchCstOne);
   llvm::SmallVector<int64_t> broadcastShape(indicesRank, 0);
   for (auto index : indices) {
@@ -585,6 +577,21 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
       broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
     }
   }
+  return std::make_pair(broadcastSizes, broadcastShape);
+}
+
+Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
+                        OpBuilder b) {
+  llvm::SmallVector<Value> indices(indicesRef);
+  // Declare commonly used constants up front:
+  Value torchCstZero =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(0));
+  Value torchCstOne =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(1));
+  Value torchCstNegOne =
+      b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(-1));
+
+  auto [broadcastSizes, broadcastShape] = getBroadcastShape(loc, indicesRef, b);
+
   auto mulDim = [](int64_t dim0, int64_t dim1) {
     if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
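Although `getBroadcastShape` builds the result out of Torch-level IR values (ConstantIntOp, AtenSizeIntOp), the computation itself is ordinary right-aligned broadcasting with a pairwise max, matching the maxDim lambda above. A minimal Python sketch for intuition (the helper name is illustrative, not part of the patch):

def broadcast_shape(shapes):
    # Right-align the shapes, padding with 1s; a size-1 dim (or a padded 1)
    # broadcasts against any other size, so each result dim is the pairwise max.
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + tuple(s) for s in shapes]
    return tuple(max(dims) for dims in zip(*padded))

assert broadcast_shape([(6, 1), (7,)]) == (6, 7)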
@@ -733,6 +740,34 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
   return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
 }
 
+// Broadcast the `values` tensor to the slice size created by the list of index
+// tensors.
+static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
+                                        llvm::ArrayRef<Value> indices,
+                                        OpBuilder b) {
+  auto inputType = cast<ValueTensorType>(input.getType());
+  ArrayRef<int64_t> inputStaticShape = inputType.getSizes();
+  auto valuesType = cast<ValueTensorType>(values.getType());
+
+  // In the case where the input rank is greater than the number of index
+  // tensors, the remaining dimensions of the input are indexed in their
+  // entirety. Thus, we need to append the remaining dimensions to get the shape
+  // of the indexed slice.
+  auto [resultShape, resultStaticShape] = getBroadcastShape(loc, indices, b);
+  for (size_t i = indices.size(); i < inputStaticShape.size(); i++) {
+    Value dim = b.create<Torch::ConstantIntOp>(loc, b.getI64IntegerAttr(i));
+    resultShape.push_back(b.create<AtenSizeIntOp>(loc, input, dim));
+    resultStaticShape.push_back(inputStaticShape[i]);
+  }
+
+  auto resultType = b.getType<Torch::ValueTensorType>(
+      resultStaticShape, valuesType.getOptionalDtype());
+  Value broadcastShapeList = b.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(b.getType<Torch::IntType>()), resultShape);
+  return b.create<AtenBroadcastToOp>(loc, resultType, values,
+                                     broadcastShapeList);
+}
+
 class ConvertAtenIndexPutHackedTwinOp
     : public OpConversionPattern<AtenIndexPutHackedTwinOp> {
 public:
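The slice shape that `values` is broadcast to can be sanity-checked in eager PyTorch with the shapes used by the new e2e test further down (a sketch, not part of the commit):

import torch

input = torch.rand(2, 3, 4, 5)
index1 = torch.randint(4, (6, 1))
index2 = torch.randint(5, (7,))

# index1 and index2 broadcast to (6, 7); the two leading dims are indexed in
# their entirety, so the indexed slice has shape (2, 3, 6, 7).
assert input[:, :, index1, index2].shape == (2, 3, 6, 7)

# index_put therefore accepts any values broadcastable to (2, 3, 6, 7).
values = torch.rand(1, 6, 7)
out = torch.ops.aten.index_put(input, (None, None, index1, index2), values, True)
assert out.shape == input.shape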
@@ -780,6 +815,8 @@ public:
     if (optionalIndicesCount == 0)
       return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");
 
+    values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList,
+                                        rewriter);
     // Filter to available indices and get the indicesMap:
     SmallVector<Value> indicesList;
     SmallVector<int64_t> indicesMap;
@@ -1494,7 +1494,7 @@ STABLEHLO_PASS_SET = {
     "RenormModuleFloat32_basic",
 }
 
-STABLEHLO_CRASHING_SET = set()
+STABLEHLO_CRASHING_SET = {"IndexPutWithNoneAndBroadcastModule_basic"}
 
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
@@ -2427,6 +2427,7 @@ ONNX_XFAIL_SET = {
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
+    "IndexPutWithNoneAndBroadcastModule_basic",
     "IntFloatModule_basic",
     "IntImplicitModule_basic",
     "IouOfModule_basic",
@@ -1269,3 +1269,36 @@ def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils):
         tu.randint(7, high=5),
         tu.rand(2, 3, 6, 7),
     )
+
+
+# ==============================================================================
+
+
+class IndexPutWithNoneAndBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 4, 5], torch.float32, True),
+            ([6, 1], torch.int64, True),
+            ([7], torch.int64, True),
+            ([1, 6, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, input, index1, index2, value):
+        return torch.ops.aten.index_put(
+            input, (None, None, index1, index2), value, accumulate=True
+        )
+
+
+@register_test_case(module_factory=lambda: IndexPutWithNoneAndBroadcastModule())
+def IndexPutWithNoneAndBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 3, 4, 5),
+        tu.randint(6, 1, high=4),
+        tu.randint(7, high=5),
+        tu.rand(1, 6, 7),  # broadcasted to (2, 3, 6, 7)
+    )