mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.Shape` to include `start` and `end` processing (#3580)
`onnx.Shape` can select only a subset of indices using attributes. Add support for these attributes. --------- Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>pull/3595/head
parent
839fe90f86
commit
b1a232222f
|
@ -1615,17 +1615,48 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp("Shape", 9,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
|
||||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
int64_t start, end;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(start, "start", 0) ||
|
||||
binder.s64IntegerAttr(end, "end", -1))
|
||||
return failure();
|
||||
|
||||
auto inputType = dyn_cast<Torch::ValueTensorType>(operand.getType());
|
||||
int64_t inputRank = inputType.getSizes().size();
|
||||
|
||||
auto shapeType = Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), SmallVector<int64_t>{inputRank},
|
||||
resultType.getOptionalDtype());
|
||||
|
||||
Value shape = rewriter.create<Torch::Aten_ShapeAsTensorOp>(
|
||||
binder.getLoc(), shapeType, operand);
|
||||
|
||||
if (start == 0 && end == -1) {
|
||||
rewriter.replaceOp(binder.op, shape);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value sv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(start));
|
||||
|
||||
Value ev = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(end));
|
||||
|
||||
Value step = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 1);
|
||||
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), 0);
|
||||
|
||||
shape = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(), resultType, shape, dim, sv, ev, step);
|
||||
|
||||
rewriter.replaceOp(binder.op, shape);
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp("Sinh", 9,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
|
|
|
@ -2715,6 +2715,21 @@ func.func @test_sequence_map_extract_shapes(%arg0: !torch.list<vtensor<[?,?,?],f
|
|||
return %0 : !torch.list<vtensor<[3],si64>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_shape_start_1_end_negative_1
|
||||
func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64} {
|
||||
// CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0
|
||||
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2_0:.+]] = torch.constant.int -1
|
||||
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[SHAPE]], %[[INT0_0]], %[[INT1_0]], %[[INT2_0]], %[[INT1_1]]
|
||||
%0 = torch.operator "onnx.Shape"(%arg0) {torch.onnx.end = -1 : si64, torch.onnx.start = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64>
|
||||
return %0 : !torch.vtensor<[1],si64>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_upsample_nearest
|
||||
|
@ -3133,7 +3148,7 @@ func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.
|
|||
return %0 : !torch.vtensor<[4,4,4],f32>
|
||||
}
|
||||
|
||||
// ----
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_split_to_sequence_1
|
||||
func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list<vtensor<[3,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
|
@ -3151,7 +3166,7 @@ func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !to
|
|||
return %1 : !torch.list<vtensor<[3,6],f32>>
|
||||
}
|
||||
|
||||
// ----
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_split_to_sequence_2
|
||||
func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list<vtensor<[1,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
|
@ -3169,7 +3184,7 @@ func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !to
|
|||
return %1 : !torch.list<vtensor<[1,6],f32>>
|
||||
}
|
||||
|
||||
// ----
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_split_to_sequence_with_list(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>,
|
||||
|
|
Loading…
Reference in New Issue