[onnx] Fix `onnx.Gather` for bad expansion (#3625)

A case where an unsqueeze was required was missed, causing compilation
failures.
Rob Suderman 2024-08-13 09:38:55 -07:00 committed by GitHub
parent 9ab93436c4
commit 39307f0462
2 changed files with 58 additions and 6 deletions
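
The failing pattern is the `onnx.GatherND` lowering with a nonzero `batch_dims`. The op added to the lit tests below (copied verbatim from the new test) is the shape of input that previously failed to compile:

%0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,2,2],si32>, !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32>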

@@ -1809,10 +1809,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
               loc, flattenIndicesTy, reshapedIndices, constZero);
         } else if (indicesRank > 1) {
-          Value endDim = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(indicesRank - 2));
-          flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-              loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
+          if (batchDimCount > indicesRank - 2) {
+            flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
+                loc, flattenIndicesTy, reshapedIndices, batchDimCountVal);
+          } else {
+            Value endDim = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(indicesRank - 2));
+            flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+                loc, flattenIndicesTy, reshapedIndices, batchDimCountVal,
+                endDim);
+          }
         }

         // step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
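
The guard is needed because `AtenFlattenUsingIntsOp` flattens the dimension range [batchDimCountVal, endDim] with endDim = indicesRank - 2. In the new test below, the indices reshape to [2,1] (indicesRank = 2) while batchDimCount = 1, so the old code asked for the reversed range [1, 0]; the fix emits an unsqueeze at dim batchDimCount instead. A sketch of the IR this produces, with shapes taken from that test (the SSA names here are illustrative, not emitted by the pass):

%int1 = torch.constant.int 1
%unsq = torch.aten.unsqueeze %view, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1,1],si64>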
@@ -1834,8 +1840,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value endDim = rewriter.create<Torch::ConstantIntOp>(
             loc,
             rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1));
-        Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-            loc, flattenDataTy, data, batchDimCountVal, endDim);
+        Value flattenedData = data;
+        if (indicesLastDim != 1) {
+          flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+              loc, flattenDataTy, data, batchDimCountVal, endDim);
+        }

         // step 10. Now we have flattenedData and expandedIndices of same rank
         // to perform gather operation.
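
The data-side flatten gets the analogous guard: when the last dim of indices is 1, there are no trailing data dims to collapse. A worked trace under the new test's shapes, mirroring the arithmetic above:

// data : [2,2,2], batchDimCount = 1, indicesLastDim = 1
// endDim = batchDimCount + indicesLastDim - 1 = 1 + 1 - 1 = 1 == batchDimCount
// flatten(1, 1) would collapse nothing, so `data` is used unchanged.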
@@ -1851,6 +1861,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.op, resultType, gather, /*dim=*/constZero);
           return success();
         }
+
+        if (unflattenIndicesDims.empty()) {
+          rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
+              binder.op, resultType, gather, /*dim=*/batchDimCountVal);
+          return success();
+        }
+
         Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
             loc, intListTy, unflattenIndicesDims);
         rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
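
When `indicesLastDim == 1`, `unflattenIndicesDims` stays empty and there is nothing to unflatten, so the new early exit squeezes the size-1 gathered dim at `batchDimCountVal` instead. For the test below this amounts to (SSA names illustrative; shapes follow the CHECK lines):

%gather = torch.aten.gather %arg0, %int1, %expand, %false : !torch.vtensor<[2,2,2],si32>, !torch.int, !torch.vtensor<[2,1,2],si64>, !torch.bool -> !torch.vtensor<[2,1,2],si32>
%result = torch.aten.squeeze.dim %gather, %int1 : !torch.vtensor<[2,1,2],si32>, !torch.int -> !torch.vtensor<[2,2],si32>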

@@ -180,6 +180,41 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
 // -----

+// CHECK-LABEL: func.func @test_gathernd_example_int32_batch_dim1
+func.func @test_gathernd_example_int32_batch_dim1(%arg0: !torch.vtensor<[2,2,2],si32>, %arg1: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[DIM0:.+]] = torch.aten.size.int %arg0, %[[INT0]]
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[DIM1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[DIM2:.+]] = torch.aten.size.int %arg0, %[[INT2]]
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[B0:.+]] = torch.aten.size.int %arg1, %[[INT0_2]]
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_4:.+]] = torch.constant.int 1
+  // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, %[[INT1_3]], %[[INT0_0]], %[[INT1_4]], %[[INT1_1]]
+  // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %[[SLICE]], %[[INT0_0]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SLICE]], %[[DIM1]], %[[INT1_1]]
+  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %[[SLICE]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[B0]], %[[INT1_1]]
+  // CHECK: %[[VIEW:.+]] = torch.aten.view %[[WHERE]], %[[LIST]]
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[UNSQ:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT1_5]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[DIM0]], %[[INT1_1]], %[[DIM2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[UNSQ]], %[[LIST]], %[[FALSE]]
+  // CHECK: %[[INT1_6:.+]] = torch.constant.int 1
+  // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT1_5]], %[[EXPAND]], %[[FALSE]]
+  // CHECK: %[[SQ:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT1_5]]
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,2,2],si32>, !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32>
+  return %0 : !torch.vtensor<[2,2],si32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_gather_elements
 func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
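
The `// -----` separators mean this file runs under MLIR's split-input-file mode. The RUN line sits at the top of the test file, outside the hunk shown here, but for TorchOnnxToTorch lit tests it is typically of this form (an assumption, not copied from this commit):

// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s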