mirror of https://github.com/llvm/torch-mlir
[onnx] Handle negative indices for `onnx.GatherElements` (#3599)
Add a check for negative indices and offset appropriately for `onnx.GatherElements`.pull/3607/head
parent
78d0fa8998
commit
b48e55c2f7
|
@ -1605,6 +1605,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
|
||||
|
||||
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
Value axisSize = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
|
||||
data, constAxis);
|
||||
Value indicesAdd = rewriter.create<Torch::AtenAddScalarOp>(
|
||||
binder.getLoc(), indicesTy, indices, axisSize, constOne);
|
||||
|
||||
auto boolTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
indicesTy.getSizes(), rewriter.getI1Type());
|
||||
Value lt = rewriter.create<Torch::AtenLtScalarOp>(
|
||||
binder.getLoc(), boolTy, indices, constZero);
|
||||
indices = rewriter.create<Torch::AtenWhereSelfOp>(
|
||||
binder.getLoc(), indicesTy, lt, indicesAdd, indices);
|
||||
|
||||
Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
|
||||
rewriter.getBoolAttr(false));
|
||||
|
|
|
@ -183,8 +183,13 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1
|
|||
// CHECK-LABEL: func.func @test_gather_elements
|
||||
func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[DIM:.+]] = torch.aten.size.int %arg0, %[[INT0]]
|
||||
// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[DIM]], %[[ONE]]
|
||||
// CHECK-DAG: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[INT0]]
|
||||
// CHECK-DAG: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
|
||||
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %arg1, %[[FALSE]]
|
||||
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %[[WHERE]], %[[FALSE]]
|
||||
%0 = torch.operator "onnx.GatherElements"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue