diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 9550e982b..e39c42b50 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -308,12 +308,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return rewriter.notifyMatchFailure(
               binder.op, "kernel list size does not match the number of axes");
         }
-        if (binder.s64IntegerArrayAttr(padding, "pads", {0})) {
+        SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
+        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
           return failure();
         }
-        if (padding.size() != 1 && padding.size() != rank - 2) {
+        if (padding.size() != 2 * (rank - 2)) {
           return rewriter.notifyMatchFailure(
-              binder.op, "padding list size does not match the number of axes");
+              binder.op,
+              "padding list size does not match twice the number of axes");
         }
         if (binder.s64IntegerArrayAttr(strides, "strides", {1})) {
           return failure();
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 5ebba10c9..1760a0a20 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
   patterns.onOp(
       "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
-        Value data, pads, constantValue, axes;
+        Value data, pads, axes;
         std::string mode;
 
         // TODO: The `axes` parameter is not supported yet.
@@ -871,12 +871,40 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
         if (binder.tensorOperandAtIndex(data, 0) ||
            binder.tensorOperandAtIndex(pads, 1) ||
-            binder.tensorOperandAtIndex(constantValue, 2) ||
            binder.tensorResultType(resultType) ||
            binder.customOpNameStringAttr(mode, "mode", "constant"))
          return failure();
 
        Location loc = binder.getLoc();
 
+        Value constantValue;
+        if (binder.getNumOperands() >= 3) {
+          if (binder.tensorOperandAtIndex(constantValue, 2)) {
+            return failure();
+          }
+        } else {
+          auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
+
+          auto maybeZeroAttr = [&]() -> std::optional<Attribute> {
+            if (dataTensorType.getDtype().isa<IntegerType>()) {
+              return rewriter.getI64IntegerAttr(0);
+            }
+            if (dataTensorType.getDtype().isa<FloatType>()) {
+              return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f);
+            }
+            return std::nullopt;
+          }();
+
+          if (!maybeZeroAttr) {
+            return rewriter.notifyMatchFailure(
+                binder.op, "expected integer or float data tensor");
+          }
+
+          auto shapedType = dataTensorType.toBuiltinTensor();
+          auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr);
+          constantValue = rewriter.create<Torch::ValueTensorLiteralOp>(
+              loc, dataTensorType, splat);
+        }
+
        // Get pads shape and rank. The pads tensor is expected to be 1-D
        // tensor.
        auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 764cfc247..8e46aa9ec 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1531,18 +1531,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            return failure();
          }
        } else {
-          // The default axes value is the range from 0 to the number of
-          // dimensions
+          // The default axes value is the range from 0 to the size of first
+          // dimension of `starts` and `ends`.
          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-          auto defaultAxesType = Torch::ValueTensorType::get(
-              context, ArrayRef<int64_t>{operandTy.getRank()},
-              rewriter.getIntegerType(64, /*signed*/ 1));
          Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      operandTy.getRank()));
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
          axes = rewriter.create<Torch::AtenArangeOp>(
-              loc, defaultAxesType, arangeLength, none, none, none, none);
+              loc, startsTorchTy, arangeLength, none, none, none, none);
        }
 
        // Binding `steps` from its arguments or through a default value
@@ -1553,22 +1549,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
          }
        } else {
          // The default `steps` value is a 1d tensor filled with ones with a
-          // size of the dimension of the operand
+          // size equal to the size of `starts` and `ends`.
          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-          auto defaultStepsType = Torch::ValueTensorType::get(
-              context, ArrayRef<int64_t>{operandTy.getRank()},
-              rewriter.getIntegerType(64, /*signed*/ 1));
          Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      operandTy.getRank()));
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
          Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
              loc,
              Torch::ListType::get(
                  Torch::IntType::get(binder.op->getContext())),
              sizeStepInput);
          steps = rewriter.create<Torch::AtenOnesOp>(
-              loc, defaultStepsType, sizeStepsInput, none, none, none, none);
+              loc, startsTorchTy, sizeStepsInput, none, none, none, none);
        }
 
        if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index a71d4e428..2ee21c1e3 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -699,13 +699,25 @@ func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !to
 
 // CHECK-LABEL: @test_averagepool_3d_default
 func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false_2, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32>
+  // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false{{.*}}, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32>
  %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
  return %0 : !torch.vtensor<[1,3,31,31,31],f32>
 }
 
 // -----
 
+// CHECK-LABEL: @test_averagepool_with_padding
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,20,64,48],f32>
+// CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,20,64,48],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,20,32,24],f32>
+
+func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+
+  %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32>
+  return %0 : !torch.vtensor<[1,20,32,24],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_conv_with_strides_no_padding
 func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 92fbe86ca..bbef289ff 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -447,6 +447,24 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
 
 // -----
 
+// CHECK-LABEL: @test_pad_optional_constant
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[SEVEN:.*]] = torch.constant.int 7
+// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
+// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+
+func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_pow
 func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 8a5d5b1ef..9f2354d13 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1205,6 +1205,31 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
 
 // -----
 
+// CHECK-LABEL: @test_slice_default_axes_and_steps
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[20,10,5],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1],si64>,
+// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64>
+
+// CHECK: %[[ZERO0:.*]] = torch.constant.int 0
+// CHECK: %[[ZERO1:.*]] = torch.constant.int 0
+// CHECK: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
+// CHECK: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: %[[SELECT2:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT2]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ITEM2]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
+
+func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32>
+  return %0 : !torch.vtensor<[20,10,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_slice_default_steps
 func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  //CHECK: %[[NONE:.*]] = torch.constant.none
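Note on the AveragePool default-`pads` change (illustrative, not part of the patch): ONNX lays `pads` out as [x1_begin, x2_begin, ..., x1_end, x2_end] over the spatial axes only, so a 2-D pool over an NCHW input carries 2 * (rank - 2) = 4 entries. That is why the default grows from {0} to a zero vector of that length and why the size check now requires exactly twice the number of spatial axes. With ceil_mode = 0, each spatial output extent is floor((in + pad_begin + pad_end - kernel) / stride) + 1; for @test_averagepool_with_padding above, (64 - 2) / 2 + 1 = 32 and (48 - 2) / 2 + 1 = 24, matching !torch.vtensor<[1,20,32,24],f32>. A minimal sketch of an asymmetric case in the same test style (the function name and shapes here are hypothetical, not a test in this patch):

// Pads [0, 1, 0, 1] pad only the width, one column on each side: H stays 4,
// W grows to 6; a 2x2 kernel at the default stride of 1 then yields
// (4 - 2 + 1) x (6 - 2 + 1) = 3 x 5 spatial dims.
func.func @example_averagepool_asymmetric_pads(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,3,5],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
  %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 1 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,3,5],f32>
  return %0 : !torch.vtensor<[1,1,3,5],f32>
}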