mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.Gather` for bad expansion (#3625)
A case where unsqueeze was require was missed causing compilation failures.pull/3631/head
parent
9ab93436c4
commit
39307f0462
|
@ -1809,10 +1809,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
loc, flattenIndicesTy, reshapedIndices, constZero);
|
||||
} else if (indicesRank > 1) {
|
||||
Value endDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(indicesRank - 2));
|
||||
flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim);
|
||||
if (batchDimCount > indicesRank - 2) {
|
||||
flattenedIndices = rewriter.create<Torch::AtenUnsqueezeOp>(
|
||||
loc, flattenIndicesTy, reshapedIndices, batchDimCountVal);
|
||||
} else {
|
||||
Value endDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(indicesRank - 2));
|
||||
flattenedIndices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenIndicesTy, reshapedIndices, batchDimCountVal,
|
||||
endDim);
|
||||
}
|
||||
}
|
||||
|
||||
// step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices.
|
||||
|
@ -1834,8 +1840,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value endDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1));
|
||||
Value flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenDataTy, data, batchDimCountVal, endDim);
|
||||
Value flattenedData = data;
|
||||
|
||||
if (indicesLastDim != 1) {
|
||||
flattenedData = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
|
||||
loc, flattenDataTy, data, batchDimCountVal, endDim);
|
||||
}
|
||||
|
||||
// step 10. Now we have flattenedData and expandedIndices of same rank
|
||||
// to perform gather operation.
|
||||
|
@ -1851,6 +1861,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.op, resultType, gather, /*dim=*/constZero);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (unflattenIndicesDims.empty()) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
|
||||
binder.op, resultType, gather, /*dim=*/batchDimCountVal);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value unflattenSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, intListTy, unflattenIndicesDims);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
|
||||
|
|
|
@ -180,6 +180,41 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_gathernd_example_int32_batch_dim1
|
||||
func.func @test_gathernd_example_int32_batch_dim1(%arg0: !torch.vtensor<[2,2,2],si32>, %arg1: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[DIM0:.+]] = torch.aten.size.int %arg0, %[[INT0]]
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[DIM2:.+]] = torch.aten.size.int %arg0, %[[INT2]]
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0_2:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[B0:.+]] = torch.aten.size.int %arg1, %[[INT0_2]]
|
||||
// CHECK: %[[INT1_3:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT1_4:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, %[[INT1_3]], %[[INT0_0]], %[[INT1_4]], %[[INT1_1]]
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %[[SLICE]], %[[INT0_0]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SLICE]], %[[DIM1]], %[[INT1_1]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %[[SLICE]]
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[B0]], %[[INT1_1]]
|
||||
// CHECK: %[[VIEW:.+]] = torch.aten.view %[[WHERE]], %[[LIST]]
|
||||
// CHECK: %[[INT1_5:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[UNSQ:.+]] = torch.aten.unsqueeze %[[VIEW]], %[[INT1_5]]
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[DIM0]], %[[INT1_1]], %[[DIM2]]
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[UNSQ]], %[[LIST]], %[[FALSE]]
|
||||
// CHECK: %[[INT1_6:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT1_5]], %[[EXPAND]], %[[FALSE]]
|
||||
// CHECK: %[[SQ:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT1_5]]
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,2,2],si32>, !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32>
|
||||
return %0 : !torch.vtensor<[2,2],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_gather_elements
|
||||
func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
|
||||
|
|
Loading…
Reference in New Issue