[onnx] Handle negative indices for `onnx.GatherElements` (#3599)

Add a check for negative indices and offset appropriately for
`onnx.GatherElements`.
pull/3607/head
Rob Suderman 2024-08-06 18:54:01 -07:00 committed by GitHub
parent 78d0fa8998
commit b48e55c2f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 1 deletions

View File

@@ -1605,6 +1605,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Materialize the gather axis as a torch constant int for use by the ops below.
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
// Normalize negative indices: per the commit intent, onnx.GatherElements
// permits negative indices (counting from the end of the axis), while the
// torch gather op the conversion targets expects non-negative indices.
// The normalization is: indices < 0 ? indices + size(data, axis) : indices.
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
Value constZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value constOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
// Runtime extent of `data` along the gather axis (aten.size.int).
Value axisSize = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
data, constAxis);
// indices + axisSize, with alpha = 1; only consumed where indices < 0.
Value indicesAdd = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), indicesTy, indices, axisSize, constOne);
// Boolean (i1) tensor with the same static sizes as `indices`,
// holding the per-element "is negative" mask.
auto boolTy = rewriter.getType<Torch::ValueTensorType>(
indicesTy.getSizes(), rewriter.getI1Type());
Value lt = rewriter.create<Torch::AtenLtScalarOp>(
binder.getLoc(), boolTy, indices, constZero);
// where(indices < 0, indices + axisSize, indices) — rebinds `indices`
// so all downstream uses see the normalized values.
indices = rewriter.create<Torch::AtenWhereSelfOp>(
binder.getLoc(), indicesTy, lt, indicesAdd, indices);
// sparse_grad operand for the eventual gather op; constant false here
// (gradient behavior is not relevant to this conversion).
Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false));

View File

@@ -183,8 +183,13 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
// CHECK-LABEL: func.func @test_gather_elements
// FileCheck test for the onnx.GatherElements -> torch lowering with axis = 0.
// The expectations below verify that negative indices are normalized before
// the gather: compute indices + dim-size, compare indices < 0, and select the
// offset value via where.self.
func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[DIM:.+]] = torch.aten.size.int %arg0, %[[INT0]]
// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[DIM]], %[[ONE]]
// CHECK-DAG: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[INT0]]
// CHECK-DAG: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// NOTE(review): the next two lines are the pre-change and post-change
// expectations interleaved by the diff view; only the %[[WHERE]] form
// should exist in the updated test file — confirm against the repository.
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %arg1, %[[FALSE]]
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %[[WHERE]], %[[FALSE]]
%0 = torch.operator "onnx.GatherElements"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}