[MLIR][ONNX] Fix onnx.gather_nd implementation (#3070)

The indices should be expanded before the torch.gather operation.

Signed-off-by: Gaurav Shukla <gaurav@amd.com>
pull/3098/head
Gaurav Shukla 2024-04-01 20:17:09 +05:30 committed by GitHub
parent da88efad89
commit 129a79417a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 68 deletions

View File

@ -671,7 +671,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (!indicesTy || !indicesTy.hasSizes())
return failure();
// step1. Get shapes and ranks of data and indices. The last dimension
// step 1. Get shapes and ranks of data and indices. The last dimension
// of indices is expected to be static.
ArrayRef<int64_t> dataShape = dataTy.getSizes();
int64_t dataRank = dataShape.size();
@ -693,25 +693,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(
binder.op, "expected last dimension of indices to be static");
// step2. Get dimension list of data and indices.
// step 2. Get dimension list of data.
SmallVector<int64_t> batchShape;
SmallVector<Value> batchDims;
SmallVector<Value> dataDims;
for (int64_t i = 0; i < dataRank; ++i) {
Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
Value dataDim = rewriter.create<Torch::AtenSizeIntOp>(loc, data, k);
dataDims.push_back(dataDim);
if (i < batchDimCount)
if (i < batchDimCount) {
batchShape.push_back(dataShape[i]);
batchDims.push_back(dataDim);
}
}
// step 3. Get dimension list of indices.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<Value> indicesDimsMinusOne;
SmallVector<Value> unflattenIndicesDims;
Value indicesFlattenDim = constOne;
for (int64_t i = 0; i < indicesRank - 1; ++i) {
Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
Value indicesDim =
rewriter.create<Torch::AtenSizeIntOp>(loc, indices, k);
indicesDimsMinusOne.push_back(indicesDim);
if (i >= batchDimCount)
if (i >= batchDimCount) {
unflattenIndicesDims.push_back(indicesDim);
indicesFlattenDim = rewriter.create<Torch::AtenMulIntOp>(
loc, indicesFlattenDim, indicesDim);
}
}
ArrayRef<int64_t> indicesShapeMinusOne = indicesShape.drop_back();
@ -719,26 +732,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// the ranks of data(`r`) and indices(`q`) to be same. So we will
// perform collapse and reshape operations to match the ranks of data
// and indices(making sure the semantics of the onnx.gather_nd are
// preserved), perform torch.gather operation, later expand the gather
// result to match onnx.gather_nd output. For example, assuming indices
// is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and
// preserved), perform torch.gather operation, later unflatten the
// gather result to match onnx.gather_nd output. For example, assuming
// indices is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and
// batch_dims(`b`)=1. Firstly, modify indices to 1-D indexing as the
// torch.gather op supports only single-dimensional indexing. (This
// algorithm would be simpler if there were a torch op that
// supports indexing at multiple dimensions simultaneously.) 1-D indexed
// indices will be of shape (4, 5, 3, 1), now materialize it to
// `r-b-indices_shape[-1]` dimension of data i.e. reshaping it to the
// shape (4, 5, 3, 1, 1). Next step is to flatten the indices and data
// to (4, 15, 1, 1) and (4, 110, 7, 4) shapes respectively and then
// perform the torch.gather operation. Post the gather operation,
// unflatten the indices dimensions of result to (4, 5, 3, 1, 1) and
// then expand it to get the final output of shape (4, 5, 3, 7, 4).
// shape (4, 5, 3, 1, 1). Next step is to flatten+expand the indices and
// flatten the data to (4, 15, 7, 4) and (4, 110, 7, 4) shapes
// respectively and then perform the torch.gather operation. Post the
// gather operation, unflatten the indices dimensions of result to (4,
// 5, 3, 7, 4) which is our required result.
// step3. Convert indices_shape[-1] dimensional indexing to 1D indexing.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
// step 4. Convert indices_shape[-1] dimensional indexing to 1D
// indexing.
Value sliceDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(indicesRank - 1));
SmallVector<int64_t> indicesSliceShape(indicesShapeMinusOne);
@ -774,7 +784,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
dataDims[batchDimCount + i]);
}
// step4. Compute all the required result types here.
// step 5. Compute all the required result types here.
SmallVector<int64_t> reshapeIndicesShape(indicesShapeMinusOne);
SmallVector<Value> reshapeIndicesDims(indicesDimsMinusOne);
// Determine the collapsed dim size of indices(index_shape[-1] is not
@ -801,19 +811,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
dataCt *= sz;
}
flattenDataShape.push_back(dataCt);
// Compute result size of the final expand op.
SmallVector<Value> expandResultDims(indicesDimsMinusOne);
// Compute the shape of expand op.
SmallVector<Value> expandIndicesDims(batchDims);
expandIndicesDims.push_back(indicesFlattenDim);
SmallVector<int64_t> expandIndicesShape(batchShape);
expandIndicesShape.push_back(indicesCt);
// Append `r-b-indices_shape[-1]` unit or data dims appropriately to all
// result types.
for (int64_t i = batchDimCount + indicesLastDim; i < dataRank; ++i) {
reshapeIndicesShape.push_back(1);
flattenIndicesShape.push_back(1);
flattenDataShape.push_back(dataShape[i]);
expandIndicesShape.push_back(dataShape[i]);
reshapeIndicesDims.push_back(constOne);
expandResultDims.push_back(dataDims[i]);
expandIndicesDims.push_back(dataDims[i]);
}
// step5. Reshape 1-D indexed indices to match the rank of flattened
// step 6. Reshape 1-D indexed indices to match the rank of flattened
// data by inserting unit dimensions.
auto intListTy = rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>());
@ -825,7 +839,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value reshapedIndices = rewriter.create<Torch::AtenViewOp>(
loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList);
// step6. Flatten `q-b-1` dimensions of the indices.
// step 7. Flatten `q-b-1` dimensions of the indices.
auto flattenIndicesTy = rewriter.getType<Torch::ValueTensorType>(
flattenIndicesShape, indicesTy.getOptionalDtype());
Value batchDimCountVal = rewriter.create<Torch::ConstantIntOp>(
@ -836,12 +850,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
loc, flattenIndicesTy, reshapedIndices, constZero);
} else if (indicesRank > 1) {
Value endDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(indicesRank - 1));
loc, rewriter.getI64IntegerAttr(indicesRank - 2));
flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
}
// step7. Flatten indices_shape[-1] dimensions of data.
// step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
auto expandIndicesTy = rewriter.getType<Torch::ValueTensorType>(
expandIndicesShape, indicesTy.getOptionalDtype());
Value expandIndicesSizeList =
rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
expandIndicesDims);
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false));
Value expandedIndices = rewriter.create<Torch::AtenExpandOp>(
loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList,
/*implicit=*/constFalse);
// step 9. Flatten indices_shape[-1] dimensions of data.
auto flattenDataTy = rewriter.getType<Torch::ValueTensorType>(
flattenDataShape, dataTy.getOptionalDtype());
Value endDim = rewriter.create<Torch::ConstantIntOp>(
@ -850,39 +877,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenDataTy, data, batchDimCountVal, endDim);
// step8. Now we have flattenedData and flattenedIndices of same rank to
// perform gather operation.
// step 10. Now we have flattenedData and expandedIndices of same rank
// to perform gather operation.
auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
flattenIndicesShape, dataTy.getOptionalDtype());
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false));
expandIndicesShape, dataTy.getOptionalDtype());
Value gather = rewriter.create<Torch::AtenGatherOp>(
loc, gatherTy, flattenedData, batchDimCountVal, flattenedIndices,
loc, gatherTy, flattenedData, batchDimCountVal, expandedIndices,
/*sparseGrad=*/constFalse);
// step9. Unflatten the collapsed indices dims of gather result.
auto unflattenTy = rewriter.getType<Torch::ValueTensorType>(
reshapeIndicesShape, dataTy.getOptionalDtype());
Value unflattenedGather = gather;
// step 11. Unflatten the collapsed indices dims of gather result.
if (indicesRank == 1) {
unflattenedGather = rewriter.create<Torch::AtenSqueezeDimOp>(
loc, unflattenTy, gather, /*dim=*/constZero);
} else if (indicesRank > 1) {
Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, intListTy, unflattenIndicesDims);
unflattenedGather = rewriter.create<Torch::AtenUnflattenIntOp>(
loc, unflattenTy, gather, batchDimCountVal, unflattenSizeList);
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
binder.op, resultType, gather, /*dim=*/constZero);
return success();
}
// step10. Expand `r-b-indices_shape[-1]` dims of unflattenedGather
// result.
Value expandResultSizeList =
rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
expandResultDims);
rewriter.replaceOpWithNewOp<Torch::AtenExpandOp>(
binder.op, resultType, unflattenedGather, expandResultSizeList,
/*implicit=*/constFalse);
Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, intListTy, unflattenIndicesDims);
rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
binder.op, resultType, gather, batchDimCountVal, unflattenSizeList);
return success();
});
patterns.onOp(

View File

@ -96,14 +96,16 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1: !torch.vt
// CHECK: %[[SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[SIZE3:.+]] = torch.aten.size.int %arg0, %[[INT3]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int -> !torch.int
// CHECK: %[[INT0_3:.+]] = torch.constant.int 0
// CHECK: %[[INT1_4:.+]] = torch.constant.int 1
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[INDSIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0_0]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK: %[[INDSIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_1]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1_4]], %[[INDSIZE1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT2_2:.+]] = torch.constant.int 2
// CHECK: %[[INDSIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2_2]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
// CHECK: %[[INT0_3:.+]] = torch.constant.int 0
// CHECK: %[[INT1_4:.+]] = torch.constant.int 1
// CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[INDSIZE2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT3_5:.+]] = torch.constant.int 3
// CHECK: %[[INT1_6:.+]] = torch.constant.int 1
// CHECK: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg1, %[[INT3_5]], %[[INT0_3]], %[[INT1_6]], %[[INT1_4]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,4,3,1],si64>
@ -119,17 +121,17 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1: !torch.vt
// CHECK: %[[LIST0:.+]] = torch.prim.ListConstruct %[[INDSIZE0]], %[[INDSIZE1]], %[[INDSIZE2]], %[[INT1_4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VIEW:.+]] = torch.aten.view %[[NEWIND]], %[[LIST0]] : !torch.vtensor<[2,4,3,1],si64>, !torch.list<int> -> !torch.vtensor<[2,4,3,1],si64>
// CHECK: %[[INT1_8:.+]] = torch.constant.int 1
// CHECK: %[[INT3_9:.+]] = torch.constant.int 3
// CHECK: %[[FLATIND:.+]] = torch.aten.flatten.using_ints %[[VIEW]], %[[INT1_8]], %[[INT3_9]] : !torch.vtensor<[2,4,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,12,1],si64>
// CHECK: %[[INT2_9:.+]] = torch.constant.int 2
// CHECK: %[[FLATIND:.+]] = torch.aten.flatten.using_ints %[[VIEW]], %[[INT1_8]], %[[INT2_9]] : !torch.vtensor<[2,4,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,12,1],si64>
// CHECK: %[[EXLIST:.+]] = torch.prim.ListConstruct %[[SIZE0]], %[[MUL2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[EXPANDIND:.+]] = torch.aten.expand %[[FLATIND]], %[[EXLIST]], %[[FALSE]] : !torch.vtensor<[2,12,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,12,5],si64>
// CHECK: %[[INT2_10:.+]] = torch.constant.int 2
// CHECK: %[[FLATDATA:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT1_8]], %[[INT2_10]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,48,5],f32>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT1_8]], %[[FLATIND]], %[[FALSE]]
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT1_8]], %[[EXPANDIND]], %[[FALSE]]
// CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[INDSIZE1]], %[[INDSIZE2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[UNFLAT:.+]] = torch.aten.unflatten.int %[[GATHER]], %[[INT1_8]], %[[LIST1]] : !torch.vtensor<[2,12,1],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,4,3,1],f32>
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[INDSIZE0]], %[[INDSIZE1]], %[[INDSIZE2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.+]] = torch.aten.expand %[[UNFLAT]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[2,4,3,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,3,5],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[2,4,3,5],f32>
// CHECK: %[[UNFLAT:.+]] = torch.aten.unflatten.int %[[GATHER]], %[[INT1_8]], %[[LIST1]] : !torch.vtensor<[2,12,5],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,4,3,5],f32>
// CHECK: return %[[UNFLAT]] : !torch.vtensor<[2,4,3,5],f32>
%0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,6,8,5],f32>, !torch.vtensor<[2,4,3,2], si64>) -> !torch.vtensor<[2,4,3,5],f32>
return %0 : !torch.vtensor<[2,4,3,5],f32>
}
@ -164,14 +166,14 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
// CHECK: %[[VIEW:.+]] = torch.aten.view %[[NEWIND]], %[[LIST0]] : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[1,1],si64>
// CHECK: %[[INT0_5:.+]] = torch.constant.int 0
// CHECK: %[[FLATIND:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT0_0]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>
// CHECK: %[[EXLIST:.+]] = torch.prim.ListConstruct %[[INT1_1]], %[[SIZE2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[EXPANDIND:.+]] = torch.aten.expand %[[FLATIND]], %[[EXLIST]], %[[FALSE]] : !torch.vtensor<[1,1,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,8,5],si64>
// CHECK: %[[INT1_6:.+]] = torch.constant.int 1
// CHECK: %[[FLATDATA:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT0_5]], %[[INT1_6]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,8,5],f32>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT0_5]], %[[FLATIND]], %[[FALSE]]
// CHECK: %[[SQUEEZE:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT0_0]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[SIZE2]], %[[SIZE3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.+]] = torch.aten.expand %[[SQUEEZE]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[1,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[8,5],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[8,5],f32>
// CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT0_5]], %[[EXPANDIND]], %[[FALSE]]
// CHECK: %[[SQUEEZE:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT0_0]] : !torch.vtensor<[1,8,5],f32>, !torch.int -> !torch.vtensor<[8,5],f32>
// CHECK: return %[[SQUEEZE]] : !torch.vtensor<[8,5],f32>
%0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 0 : si64} : (!torch.vtensor<[2,6,8,5],f32>, !torch.vtensor<[2], si64>) -> !torch.vtensor<[8,5],f32>
return %0 : !torch.vtensor<[8,5],f32>
}