From 129a79417ae3aaa6041317c2af9ff313541e3fe8 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Mon, 1 Apr 2024 20:17:09 +0530
Subject: [PATCH] [MLIR][ONNX] Fix onnx.gather_nd implementation (#3070)

The indices should be expanded before the torch.gather operation.
Gathering with unit trailing dimensions and expanding the result
afterwards replicates a single gathered element across the trailing
dimensions instead of gathering the whole trailing slice.

Signed-off-by: Gaurav Shukla
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 116 ++++++++++--------
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   |  34 ++---
 2 files changed, 82 insertions(+), 68 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 23203d9b7..0683be040 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -671,7 +671,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         if (!indicesTy || !indicesTy.hasSizes())
           return failure();
 
-        // step1. Get shapes and ranks of data and indices. The last dimension
+        // step 1. Get shapes and ranks of data and indices. The last dimension
         // of indices is expected to be static.
         ArrayRef<int64_t> dataShape = dataTy.getSizes();
         int64_t dataRank = dataShape.size();
@@ -693,25 +693,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           return rewriter.notifyMatchFailure(
               binder.op, "expected last dimension of indices to be static");
 
-        // step2. Get dimension list of data and indices.
+        // step 2. Get dimension list of data.
         SmallVector<int64_t> batchShape;
+        SmallVector<Value> batchDims;
         SmallVector<Value> dataDims;
         for (int64_t i = 0; i < dataRank; ++i) {
           Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
           Value dataDim = rewriter.create<Torch::AtenSizeIntOp>(loc, data, k);
           dataDims.push_back(dataDim);
-          if (i < batchDimCount)
+          if (i < batchDimCount) {
             batchShape.push_back(dataShape[i]);
+            batchDims.push_back(dataDim);
+          }
         }
+
+        // step 3. Get dimension list of indices.
+        Value constZero = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(0));
+        Value constOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
         SmallVector<Value> indicesDimsMinusOne;
         SmallVector<Value> unflattenIndicesDims;
+        Value indicesFlattenDim = constOne;
         for (int64_t i = 0; i < indicesRank - 1; ++i) {
           Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
           Value indicesDim =
               rewriter.create<Torch::AtenSizeIntOp>(loc, indices, k);
           indicesDimsMinusOne.push_back(indicesDim);
-          if (i >= batchDimCount)
+          if (i >= batchDimCount) {
             unflattenIndicesDims.push_back(indicesDim);
+            indicesFlattenDim = rewriter.create<Torch::AtenMulIntOp>(
+                loc, indicesFlattenDim, indicesDim);
+          }
         }
         ArrayRef<int64_t> indicesShapeMinusOne = indicesShape.drop_back();
@@ -719,26 +732,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // the ranks of data(`r`) and indices(`q`) to be same. So we will
         // perform collapse and reshape operations to match the ranks of data
         // and indices(making sure the semantics of the onnx.gather_nd are
-        // preserved), perform torch.gather operation, later expand the gather
-        // result to match onnx.gather_nd output. For example, assuming indices
-        // is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and
+        // preserved), perform torch.gather operation, later unflatten the
+        // gather result to match onnx.gather_nd output. For example, assuming
+        // indices is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and
         // batch_dims(`b`)=1. Firstly, modify indices to 1-D indexing as the
         // torch.gather op supports only single dimensional indexing. (this
         // algorithm would have been simpler if we can get a torch op that
         // supports indexing at multiple dimensions simultaneously). 1-D indexed
         // indices will be of shape (4, 5, 3, 1), now materialize it to
         // `r-b-indices_shape[-1]` dimension of data i.e. reshaping it to the
-        // shape (4, 5, 3, 1, 1). Next step is to flatten the indices and data
-        // to (4, 15, 1, 1) and (4, 110, 7, 4) shapes respectively and then
-        // perform the torch.gather operation. Post the gather operation,
-        // unflatten the indices dimensions of result to (4, 5, 3, 1, 1) and
-        // then expand it to get the final output of shape (4, 5, 3, 7, 4).
+        // shape (4, 5, 3, 1, 1). Next step is to flatten+expand the indices and
+        // flatten the data to (4, 15, 7, 4) and (4, 110, 7, 4) shapes
+        // respectively and then perform the torch.gather operation. Post the
+        // gather operation, unflatten the indices dimensions of result to (4,
+        // 5, 3, 7, 4) which is our required result.
 
-        // step3. Convert indices_shape[-1] dimensional indexing to 1D indexing.
-        Value constZero = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        Value constOne = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(1));
+        // step 4. Convert indices_shape[-1] dimensional indexing to 1D
+        // indexing.
         Value sliceDim = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(indicesRank - 1));
         SmallVector<int64_t> indicesSliceShape(indicesShapeMinusOne);
@@ -774,7 +784,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               dataDims[batchDimCount + i]);
         }
 
-        // step4. Compute all the required result types here.
+        // step 5. Compute all the required result types here.
         SmallVector<int64_t> reshapeIndicesShape(indicesShapeMinusOne);
         SmallVector<Value> reshapeIndicesDims(indicesDimsMinusOne);
         // Determine the collapsed dim size of indices(index_shape[-1] is not
@@ -801,19 +811,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           dataCt *= sz;
         }
         flattenDataShape.push_back(dataCt);
-        // Compute result size of the final expand op.
-        SmallVector<Value> expandResultDims(indicesDimsMinusOne);
+        // Compute the shape of expand op.
+        SmallVector<Value> expandIndicesDims(batchDims);
+        expandIndicesDims.push_back(indicesFlattenDim);
+        SmallVector<int64_t> expandIndicesShape(batchShape);
+        expandIndicesShape.push_back(indicesCt);
         // Append `r-b-indices_shape[-1]` unit or data dims appropriately to all
         // result types.
         for (int64_t i = batchDimCount + indicesLastDim; i < dataRank; ++i) {
           reshapeIndicesShape.push_back(1);
           flattenIndicesShape.push_back(1);
           flattenDataShape.push_back(dataShape[i]);
+          expandIndicesShape.push_back(dataShape[i]);
           reshapeIndicesDims.push_back(constOne);
-          expandResultDims.push_back(dataDims[i]);
+          expandIndicesDims.push_back(dataDims[i]);
         }
 
-        // step5. Reshape 1-D indexed indices to match the rank of flattened
+        // step 6. Reshape 1-D indexed indices to match the rank of flattened
         // data by inserting unit dimensions.
         auto intListTy = rewriter.getType<Torch::ListType>(
             rewriter.getType<Torch::IntType>());
@@ -825,7 +839,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value reshapedIndices = rewriter.create<Torch::AtenViewOp>(
             loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList);
 
-        // step6. Flatten `q-b-1` dimensions of the indices.
+        // step 7. Flatten `q-b-1` dimensions of the indices.
         auto flattenIndicesTy = rewriter.getType<Torch::ValueTensorType>(
             flattenIndicesShape, indicesTy.getOptionalDtype());
         Value batchDimCountVal = rewriter.create<Torch::ConstantIntOp>(
@@ -836,12 +850,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               loc, flattenIndicesTy, reshapedIndices, constZero);
         } else if (indicesRank > 1) {
           Value endDim = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(indicesRank - 1));
+              loc, rewriter.getI64IntegerAttr(indicesRank - 2));
           flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
               loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
         }
 
-        // step7. Flatten indices_shape[-1] dimensions of data.
+        // step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
+        auto expandIndicesTy = rewriter.getType<Torch::ValueTensorType>(
+            expandIndicesShape, indicesTy.getOptionalDtype());
+        Value expandIndicesSizeList =
+            rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
+                                                        expandIndicesDims);
+        Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getType<Torch::BoolType>(),
+            rewriter.getBoolAttr(false));
+        Value expandedIndices = rewriter.create<Torch::AtenExpandOp>(
+            loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList,
+            /*implicit=*/constFalse);
+
+        // step 9. Flatten indices_shape[-1] dimensions of data.
         auto flattenDataTy = rewriter.getType<Torch::ValueTensorType>(
             flattenDataShape, dataTy.getOptionalDtype());
         Value endDim = rewriter.create<Torch::ConstantIntOp>(
@@ -850,39 +877,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             loc, rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1));
         Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
             loc, flattenDataTy, data, batchDimCountVal, endDim);
-        // step8. Now we have flattenedData and flattenedIndices of same rank to
-        // perform gather operation.
+        // step 10. Now we have flattenedData and expandedIndices of same rank
+        // to perform gather operation.
         auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
-            flattenIndicesShape, dataTy.getOptionalDtype());
-        Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
-            binder.getLoc(), rewriter.getType<Torch::BoolType>(),
-            rewriter.getBoolAttr(false));
+            expandIndicesShape, dataTy.getOptionalDtype());
         Value gather = rewriter.create<Torch::AtenGatherOp>(
-            loc, gatherTy, flattenedData, batchDimCountVal, flattenedIndices,
+            loc, gatherTy, flattenedData, batchDimCountVal, expandedIndices,
            /*sparseGrad=*/constFalse);
 
-        // step9. Unflatten the collapsed indices dims of gather result.
-        auto unflattenTy = rewriter.getType<Torch::ValueTensorType>(
-            reshapeIndicesShape, dataTy.getOptionalDtype());
-        Value unflattenedGather = gather;
+        // step 11. Unflatten the collapsed indices dims of gather result.
         if (indicesRank == 1) {
-          unflattenedGather = rewriter.create<Torch::AtenSqueezeDimOp>(
-              loc, unflattenTy, gather, /*dim=*/constZero);
-        } else if (indicesRank > 1) {
-          Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
-              loc, intListTy, unflattenIndicesDims);
-          unflattenedGather = rewriter.create<Torch::AtenUnflattenIntOp>(
-              loc, unflattenTy, gather, batchDimCountVal, unflattenSizeList);
+          rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
+              binder.op, resultType, gather, /*dim=*/constZero);
+          return success();
         }
-
-        // step10. Expand `r-b-indices_shape[-1]` dims of unflattenedGather
-        // result.
-        Value expandResultSizeList =
-            rewriter.create<Torch::PrimListConstructOp>(loc, intListTy,
-                                                        expandResultDims);
-        rewriter.replaceOpWithNewOp<Torch::AtenExpandOp>(
-            binder.op, resultType, unflattenedGather, expandResultSizeList,
-            /*implicit=*/constFalse);
+        Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, intListTy, unflattenIndicesDims);
+        rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
+            binder.op, resultType, gather, batchDimCountVal, unflattenSizeList);
         return success();
       });
   patterns.onOp(
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 3741e5c9f..dc32918f6 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -96,14 +96,16 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1: !torch.vt
   // CHECK: %[[SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int -> !torch.int
   // CHECK: %[[INT3:.+]] = torch.constant.int 3
   // CHECK: %[[SIZE3:.+]] = torch.aten.size.int %arg0, %[[INT3]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[INT0_3:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_4:.+]] = torch.constant.int 1
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INDSIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0_0]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
   // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
   // CHECK: %[[INDSIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_1]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
+  // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1_4]], %[[INDSIZE1]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[INT2_2:.+]] = torch.constant.int 2
   // CHECK: %[[INDSIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2_2]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int -> !torch.int
-  // CHECK: %[[INT0_3:.+]] = torch.constant.int 0
-  // CHECK: %[[INT1_4:.+]] = torch.constant.int 1
+  // CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[INDSIZE2]] : !torch.int, !torch.int -> !torch.int
   // CHECK: %[[INT3_5:.+]] = torch.constant.int 3
   // CHECK: %[[INT1_6:.+]] = torch.constant.int 1
   // CHECK: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg1, %[[INT3_5]], %[[INT0_3]], %[[INT1_6]], %[[INT1_4]] : !torch.vtensor<[2,4,3,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,4,3,1],si64>
@@ -119,17 +121,17 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1: !torch.vt
   // CHECK: %[[LIST0:.+]] = torch.prim.ListConstruct %[[INDSIZE0]], %[[INDSIZE1]], %[[INDSIZE2]], %[[INT1_4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[VIEW:.+]] = torch.aten.view %[[NEWIND]], %[[LIST0]] : !torch.vtensor<[2,4,3,1],si64>, !torch.list<int> -> !torch.vtensor<[2,4,3,1],si64>
   // CHECK: %[[INT1_8:.+]] = torch.constant.int 1
-  // CHECK: %[[INT3_9:.+]] = torch.constant.int 3
-  // CHECK: %[[FLATIND:.+]] = torch.aten.flatten.using_ints %[[VIEW]], %[[INT1_8]], %[[INT3_9]] : !torch.vtensor<[2,4,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,12,1],si64>
+  // CHECK: %[[INT2_9:.+]] = torch.constant.int 2
+  // CHECK: %[[FLATIND:.+]] = torch.aten.flatten.using_ints %[[VIEW]], %[[INT1_8]], %[[INT2_9]] : !torch.vtensor<[2,4,3,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,12,1],si64>
+  // CHECK: %[[EXLIST:.+]] = torch.prim.ListConstruct %[[SIZE0]], %[[MUL2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[EXPANDIND:.+]] = torch.aten.expand %[[FLATIND]], %[[EXLIST]], %[[FALSE]] : !torch.vtensor<[2,12,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,12,5],si64>
   // CHECK: %[[INT2_10:.+]] = torch.constant.int 2
   // CHECK: %[[FLATDATA:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT1_8]], %[[INT2_10]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,48,5],f32>
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT1_8]], %[[FLATIND]], %[[FALSE]]
+  // CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT1_8]], %[[EXPANDIND]], %[[FALSE]]
   // CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[INDSIZE1]], %[[INDSIZE2]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[UNFLAT:.+]] = torch.aten.unflatten.int %[[GATHER]], %[[INT1_8]], %[[LIST1]] : !torch.vtensor<[2,12,1],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,4,3,1],f32>
-  // CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[INDSIZE0]], %[[INDSIZE1]], %[[INDSIZE2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RES:.+]] = torch.aten.expand %[[UNFLAT]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[2,4,3,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,4,3,5],f32>
-  // CHECK: return %[[RES]] : !torch.vtensor<[2,4,3,5],f32>
+  // CHECK: %[[UNFLAT:.+]] = torch.aten.unflatten.int %[[GATHER]], %[[INT1_8]], %[[LIST1]] : !torch.vtensor<[2,12,5],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,4,3,5],f32>
+  // CHECK: return %[[UNFLAT]] : !torch.vtensor<[2,4,3,5],f32>
   %0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,6,8,5],f32>, !torch.vtensor<[2,4,3,2], si64>) -> !torch.vtensor<[2,4,3,5],f32>
   return %0 : !torch.vtensor<[2,4,3,5],f32>
 }
@@ -164,14 +166,14 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
   // CHECK: %[[VIEW:.+]] = torch.aten.view %[[NEWIND]], %[[LIST0]] : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[1,1],si64>
   // CHECK: %[[INT0_5:.+]] = torch.constant.int 0
   // CHECK: %[[FLATIND:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT0_0]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1,1,1],si64>
+  // CHECK: %[[EXLIST:.+]] = torch.prim.ListConstruct %[[INT1_1]], %[[SIZE2]], %[[SIZE3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[EXPANDIND:.+]] = torch.aten.expand %[[FLATIND]], %[[EXLIST]], %[[FALSE]] : !torch.vtensor<[1,1,1],si64>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,8,5],si64>
   // CHECK: %[[INT1_6:.+]] = torch.constant.int 1
   // CHECK: %[[FLATDATA:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT0_5]], %[[INT1_6]] : !torch.vtensor<[2,6,8,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,8,5],f32>
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT0_5]], %[[FLATIND]], %[[FALSE]]
-  // CHECK: %[[SQUEEZE:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT0_0]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
-  // CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[SIZE2]], %[[SIZE3]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RES:.+]] = torch.aten.expand %[[SQUEEZE]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[1,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[8,5],f32>
-  // CHECK: return %[[RES]] : !torch.vtensor<[8,5],f32>
+  // CHECK: %[[GATHER:.+]] = torch.aten.gather %[[FLATDATA]], %[[INT0_5]], %[[EXPANDIND]], %[[FALSE]]
+  // CHECK: %[[SQUEEZE:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT0_0]] : !torch.vtensor<[1,8,5],f32>, !torch.int -> !torch.vtensor<[8,5],f32>
+  // CHECK: return %[[SQUEEZE]] : !torch.vtensor<[8,5],f32>
   %0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 0 : si64} : (!torch.vtensor<[2,6,8,5],f32>, !torch.vtensor<[2], si64>) -> !torch.vtensor<[8,5],f32>
   return %0 : !torch.vtensor<[8,5],f32>
 }
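
For reference, a minimal PyTorch sketch (illustrative only, not part of the patch) of the
1-D-indexing / flatten / expand-before-gather / unflatten sequence the pattern now emits,
using the shapes from the first test case (data (2,6,8,5), indices (2,4,3,2), batch_dims=1).
The tensor contents and variable names below are made up for the example:

import torch

# Shapes from the first test case: data (2,6,8,5), indices (2,4,3,2), batch_dims = 1.
data = torch.randn(2, 6, 8, 5)
indices = torch.stack((torch.randint(0, 6, (2, 4, 3)),
                       torch.randint(0, 8, (2, 4, 3))), dim=-1)

# Reference semantics of onnx.GatherND with batch_dims = 1.
expected = torch.stack([data[b][indices[b, ..., 0], indices[b, ..., 1]]
                        for b in range(2)])                   # (2, 4, 3, 5)

# Sequence emitted by the lowering:
flat_idx = indices[..., 0] * 8 + indices[..., 1]              # step 4: 1-D indexing, (2, 4, 3)
flat_idx = flat_idx.reshape(2, 4, 3, 1).flatten(1, 2)         # steps 6-7: (2, 12, 1)
flat_idx = flat_idx.expand(2, 12, 5)                          # step 8: expand BEFORE the gather
flat_data = data.flatten(1, 2)                                # step 9: (2, 48, 5)
gathered = torch.gather(flat_data, 1, flat_idx)               # step 10: (2, 12, 5)
result = gathered.unflatten(1, (4, 3))                        # step 11: (2, 4, 3, 5)

assert torch.equal(result, expected)

Expanding the index to (2, 12, 5) before the gather is what makes gathered[b, j, k] read
flat_data[b, idx, k]; with the previous order (gather on a (2, 12, 1) index, expand afterwards)
every k would repeat the element gathered at k = 0, which is the bug this patch fixes.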