[onnx] Fix `onnx.Shape` to include `start` and `end` processing (#3580)

`onnx.Shape` can select only a subset of indices using attributes. Add
support for these attributes.

---------

Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
pull/3595/head
Rob Suderman 2024-08-05 13:56:07 -07:00 committed by GitHub
parent 839fe90f86
commit b1a232222f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 14 deletions

View File

@ -1615,15 +1615,46 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
});
patterns.onOp("Shape", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
patterns.onOp(
"Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t start, end;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(start, "start", 0) ||
binder.s64IntegerAttr(end, "end", -1))
return failure();
rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
binder.op, resultType, operand);
auto inputType = dyn_cast<Torch::ValueTensorType>(operand.getType());
int64_t inputRank = inputType.getSizes().size();
auto shapeType = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{inputRank},
resultType.getOptionalDtype());
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
binder.getLoc(), shapeType, operand);
if (start == 0 && end == -1) {
rewriter.replaceOp(binder.op, shape);
return success();
}
Value sv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(start));
Value ev = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(end));
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
shape = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), resultType, shape, dim, sv, ev, step);
rewriter.replaceOp(binder.op, shape);
return success();
});

View File

@ -2715,6 +2715,21 @@ func.func @test_sequence_map_extract_shapes(%arg0: !torch.list<vtensor<[?,?,?],f
return %0 : !torch.list<vtensor<[3],si64>>
}
// -----
// CHECK-LABEL: func.func @test_shape_start_1_end_negative_1
func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64} {
// CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK: %[[INT2_0:.+]] = torch.constant.int -1
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[SHAPE]], %[[INT0_0]], %[[INT1_0]], %[[INT2_0]], %[[INT1_1]]
%0 = torch.operator "onnx.Shape"(%arg0) {torch.onnx.end = -1 : si64, torch.onnx.start = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}
// -----
// CHECK-LABEL: func.func @test_upsample_nearest
@ -3133,7 +3148,7 @@ func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.
return %0 : !torch.vtensor<[4,4,4],f32>
}
// ----
// -----
// CHECK-LABEL: func.func @test_split_to_sequence_1
func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list<vtensor<[3,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
@ -3151,7 +3166,7 @@ func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !to
return %1 : !torch.list<vtensor<[3,6],f32>>
}
// ----
// -----
// CHECK-LABEL: func.func @test_split_to_sequence_2
func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list<vtensor<[1,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
@ -3169,7 +3184,7 @@ func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !to
return %1 : !torch.list<vtensor<[1,6],f32>>
}
// ----
// -----
// CHECK-LABEL: func.func @test_split_to_sequence_with_list(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>,