diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 03cf60589..957700d1a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1615,17 +1615,48 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); - patterns.onOp("Shape", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + int64_t start, end; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(start, "start", 0) || + binder.s64IntegerAttr(end, "end", -1)) + return failure(); + + auto inputType = dyn_cast(operand.getType()); + int64_t inputRank = inputType.getSizes().size(); + + auto shapeType = Torch::ValueTensorType::get( + binder.op->getContext(), SmallVector{inputRank}, + resultType.getOptionalDtype()); + + Value shape = rewriter.create( + binder.getLoc(), shapeType, operand); + + if (start == 0 && end == -1) { + rewriter.replaceOp(binder.op, shape); + return success(); + } + + Value sv = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(start)); + + Value ev = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(end)); + + Value step = rewriter.create(binder.getLoc(), 1); + + Value dim = rewriter.create(binder.getLoc(), 0); + + shape = rewriter.create( + binder.getLoc(), resultType, shape, dim, sv, ev, step); + + rewriter.replaceOp(binder.op, shape); + return success(); + }); patterns.onOp("Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 41e4391a8..e57cd605b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2715,6 +2715,21 @@ func.func @test_sequence_map_extract_shapes(%arg0: !torch.list> } +// ----- + +// CHECK-LABEL: func.func @test_shape_start_1_end_negative_1 +func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64} { + // CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT2_0:.+]] = torch.constant.int -1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[SHAPE]], %[[INT0_0]], %[[INT1_0]], %[[INT2_0]], %[[INT1_1]] + %0 = torch.operator "onnx.Shape"(%arg0) {torch.onnx.end = -1 : si64, torch.onnx.start = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + + // ----- // CHECK-LABEL: func.func @test_upsample_nearest @@ -3133,7 +3148,7 @@ func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch. return %0 : !torch.vtensor<[4,4,4],f32> } -// ---- +// ----- // CHECK-LABEL: func.func @test_split_to_sequence_1 func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -3151,7 +3166,7 @@ func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !to return %1 : !torch.list> } -// ---- +// ----- // CHECK-LABEL: func.func @test_split_to_sequence_2 func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -3169,7 +3184,7 @@ func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !to return %1 : !torch.list> } -// ---- +// ----- // CHECK-LABEL: func.func @test_split_to_sequence_with_list( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>,