[onnx] Fix onnx.gather lowering for rank-0 indices (#2973)

We assumed the indices rank was at least 1; however, it can be rank-0, which
generated an illegal pair of flatten / unflatten operations. This is now fixed.
Rob Suderman 2024-03-04 08:25:19 -08:00 committed by GitHub
parent 916554f270
commit d51e80b648
3 changed files with 69 additions and 17 deletions
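
For context, onnx.Gather produces a result of rank rank(data) + rank(indices) - 1, so rank-0 indices drop the gathered axis entirely and there are no indices dimensions to flatten or unflatten. A minimal numpy sketch of the reference semantics (illustrative only, not code from this patch):

    import numpy as np

    def onnx_gather(data, indices, axis=0):
        # onnx.Gather reference rule: out rank = data.ndim + indices.ndim - 1
        return np.take(data, indices, axis=axis)

    data = np.arange(60).reshape(3, 4, 5)
    print(onnx_gather(data, np.array([[0, 2], [1, 0]])).shape)  # (2, 2, 4, 5)
    print(onnx_gather(data, np.array([1])).shape)               # (1, 4, 5)
    print(onnx_gather(data, np.array(1)).shape)                 # (4, 5): rank-0 drops the axis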


@@ -572,10 +572,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         auto ctx = binder.op->getContext();
         auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
         auto dataTy = cast<Torch::ValueTensorType>(data.getType());
-        if (!dataTy || !dataTy.hasSizes())
+        if (!dataTy || !dataTy.hasSizes() || !indicesTy.hasSizes())
           return failure();
-        if (axis < 0)
-          axis += dataTy.getSizes().size();
+        int64_t dataRank = dataTy.getSizes().size();
+        int64_t indicesRank = indicesTy.getSizes().size();
+        axis = axis < 0 ? axis + dataRank : axis;

         Value index = rewriter.create<Torch::ConstantIntOp>(
             loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis));
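
The axis handling above is ordinary negative-axis normalization against the cached data rank. In Python terms (variable names mirror the C++ locals, sketch only):

    def normalize_axis(axis, data_rank):
        # axis = axis < 0 ? axis + dataRank : axis
        return axis + data_rank if axis < 0 else axis

    assert normalize_axis(-1, 3) == 2
    assert normalize_axis(1, 3) == 1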
@@ -599,8 +601,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         auto intListTy = rewriter.getType<Torch::ListType>(
             rewriter.getType<Torch::IntType>());
-        auto indicesSize =
-            rewriter.create<Torch::AtenSizeOp>(loc, intListTy, indices);
+        llvm::SmallVector<Value> indicesDims;
+        for (int i = 0, s = indicesTy.getSizes().size(); i < s; ++i) {
+          Value k = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), i);
+          indicesDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
+              binder.getLoc(), indices, k));
+        }
+        Value indicesSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(), intListTy, indicesDims);

         // Determine the collapsed dim size:
         auto indicesCt = 1;
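
Instead of one opaque torch.aten.size list, the indices shape is now materialized with one torch.aten.size.int per dimension and rebuilt via torch.prim.ListConstruct, so the later unflatten receives an explicit per-dimension size list even when sizes are dynamic. Roughly, in Python (a model of the computation, not the torch-mlir API):

    import numpy as np

    indices = np.zeros((8, 10, 20, 40), dtype=np.int64)
    indices_dims = [indices.shape[i] for i in range(indices.ndim)]  # [8, 10, 20, 40]
    indices_ct = int(np.prod(indices_dims))  # collapsed dim size used below: 64000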
@@ -615,20 +625,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         auto flattenTy = rewriter.getType<Torch::ValueTensorType>(
             SmallVector<int64_t>{indicesCt}, indicesTy.getOptionalDtype());

-        Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
-        Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
-        indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
-            loc, flattenTy, indices, zero, end);
+        if (indicesRank == 0) {
+          indices = rewriter.create<Torch::AtenUnsqueezeOp>(
+              binder.getLoc(), flattenTy, indices, zero);
+        } else if (indicesRank > 1) {
+          Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
+          Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
+          indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+              loc, flattenTy, indices, zero, end);
+        }

         llvm::SmallVector<int64_t> gatherShape(dataTy.getSizes());
         gatherShape[axis] = indicesCt;
         auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
             gatherShape, dataTy.getOptionalDtype());
         Value gather = rewriter.create<Torch::AtenIndexSelectOp>(
             loc, gatherTy, data, index, indices);
-        rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
-            binder.op, resultType, gather, index, indicesSize);
+
+        if (indicesRank == 1) {
+          rewriter.replaceOp(binder.op, gather);
+          return success();
+        }
+
+        if (indicesRank > 1) {
+          gather = rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
+              binder.op, resultType, gather, index, indicesSizeList);
+          return success();
+        }
+
+        rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op, resultType,
+                                                          gather);
         return success();
       });
   patterns.onOp(
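
The lowering now dispatches on the indices rank: rank-0 indices are unsqueezed to a one-element vector, gathered via torch.aten.index_select, then squeezed back out; rank-1 indices feed index_select directly; higher ranks keep the flatten / index_select / unflatten path. A compact numpy model of that dispatch (a sketch of the decomposition assuming in-bounds indices, not the torch-mlir API):

    import numpy as np

    def lowered_gather(data, indices, axis=0):
        axis = axis + data.ndim if axis < 0 else axis
        flat = indices.reshape(-1)                 # unsqueeze (rank-0) or flatten (rank > 1)
        gathered = np.take(data, flat, axis=axis)  # stands in for torch.aten.index_select
        if indices.ndim == 0:
            return np.squeeze(gathered, axis=axis) # rank-0: drop the unit dim again
        if indices.ndim == 1:
            return gathered                        # rank-1: already the final shape
        return gathered.reshape(                   # rank > 1: unflatten to indices shape
            data.shape[:axis] + indices.shape + data.shape[axis + 1:])

    data = np.arange(60.0).reshape(3, 4, 5)
    assert lowered_gather(data, np.array(1)).shape == (4, 5)
    assert lowered_gather(data, np.array([2, 0])).shape == (2, 4, 5)
    assert lowered_gather(data, np.ones((2, 3), dtype=np.int64)).shape == (2, 3, 4, 5)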


@@ -2190,14 +2190,12 @@ ONNX_XFAIL_SET = {
"ElementwiseUnaryIntModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseWhereScalarModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleF16_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"FlattenDynamicModule_basic",
"GluStaticModule_basic",
"GroupNormModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexTensorHackedTwinModule3dInput_basic",
"IndexTensorHackedTwinModule_basic",


@@ -37,8 +37,8 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[
 // -----
-// CHECK-LABEL: func.func @test_gather
-func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
+// CHECK-LABEL: func.func @test_gather_nd
+func.func @test_gather_nd(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK: %[[AXIS:.+]] = torch.constant.int 0
   // CHECK: %[[ZERO:.+]] = torch.constant.int 0
   // CHECK: %[[ONE:.+]] = torch.constant.int 1
@@ -46,7 +46,15 @@ func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor
   // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
   // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
   // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
-  // CHECK: %[[SZ:.+]] = torch.aten.size %[[SEL]]
+  // CHECK: %[[D0:.+]] = torch.constant.int 0
+  // CHECK: %[[SZ0:.+]] = torch.aten.size.int %[[SEL]], %[[D0]]
+  // CHECK: %[[D1:.+]] = torch.constant.int 1
+  // CHECK: %[[SZ1:.+]] = torch.aten.size.int %[[SEL]], %[[D1]]
+  // CHECK: %[[D2:.+]] = torch.constant.int 2
+  // CHECK: %[[SZ2:.+]] = torch.aten.size.int %[[SEL]], %[[D2]]
+  // CHECK: %[[D3:.+]] = torch.constant.int 3
+  // CHECK: %[[SZ3:.+]] = torch.aten.size.int %[[SEL]], %[[D3]]
+  // CHECK: %[[SZ:.+]] = torch.prim.ListConstruct %[[SZ0]], %[[SZ1]], %[[SZ2]], %[[SZ3]]
   // CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]]
   // CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]]
   // CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]]
@@ -59,6 +67,25 @@ func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor
 // -----
+// CHECK-LABEL: func.func @test_gather_scalar
+func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 0
+  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]]
+  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
+  // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
+  // CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]]
+  // CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32>
+  // CHECK: return %[[RES]]
+  %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32>
+  return %0 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
 // CHECK-LABEL: func.func @test_gather_elements
 func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
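
As a numpy cross-check of the result types asserted by the updated tests: [3,4,5] data gathered on axis 0 with [8,10,20,40] indices yields [8,10,20,40,4,5] (test_gather_nd), and a scalar index yields [4,5] (test_gather_scalar). Sketch:

    import numpy as np

    data = np.zeros((3, 4, 5), dtype=np.float32)
    nd = np.take(data, np.zeros((8, 10, 20, 40), dtype=np.int64), axis=0)
    assert nd.shape == (8, 10, 20, 40, 4, 5)   # test_gather_nd result type
    scalar = np.take(data, np.array(0), axis=0)
    assert scalar.shape == (4, 5)              # test_gather_scalar result type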