mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.ScatterElements` for negative indices (#3582)
We need to adjust for negative scatter index values. Added materializing out the in-bounds adjustment.
pull/3587/head
parent
306ed62edd
commit
3d33c5a206
|
@ -560,6 +560,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto loc = binder.getLoc();
|
||||
Value data = valList[0];
|
||||
Value indices = valList[1];
|
||||
Value updates = valList[2];
|
||||
|
@ -570,9 +571,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||
|
||||
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
loc, rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
|
||||
Value axisSize = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), data,
|
||||
constAxis);
|
||||
|
||||
auto indicesTy = cast<Torch::ValueTensorType>(indices.getType());
|
||||
Value indicesAdd = rewriter.create<Torch::AtenAddScalarOp>(
|
||||
loc, indicesTy, indices, axisSize, one);
|
||||
|
||||
Value inputNeg = rewriter.create<Torch::AtenLtScalarOp>(
|
||||
loc,
|
||||
rewriter.getType<Torch::ValueTensorType>(indicesTy.getSizes(),
|
||||
rewriter.getI1Type()),
|
||||
indices, zero);
|
||||
|
||||
indices = rewriter.create<Torch::AtenWhereSelfOp>(
|
||||
loc, indicesTy, inputNeg, indicesAdd, indices);
|
||||
|
||||
if (reduction == "none") {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenScatterSrcOp>(
|
||||
binder.op, resultType, data, constAxis, indices, updates);
|
||||
|
|
|
@ -228,8 +228,14 @@ func.func @test_scatter(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<
|
|||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_axis
|
||||
func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32> -> !torch.vtensor<[1,5],f32>
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
||||
// CHECK: torch.aten.scatter.src %arg0, %[[AXIS]], %[[WHERE]], %arg2
|
||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
@ -238,9 +244,15 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
|
|||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
|
||||
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "add"
|
||||
// CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
||||
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
@ -249,8 +261,14 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1
|
|||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_without_axis
|
||||
func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32>
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
||||
// CHECK: torch.aten.scatter.src %arg0, %[[AXIS]], %[[WHERE]], %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32>
|
||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32>
|
||||
return %0 : !torch.vtensor<[3,3],f32>
|
||||
}
|
||||
|
@ -259,9 +277,15 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
|
|||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
|
||||
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[ONE:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
|
||||
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
|
||||
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "multiply"
|
||||
// CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
||||
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
|
||||
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
|
||||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue