diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 6d2fb8153..3637f7f35 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -459,6 +459,316 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.op, resultType, operand, vAlpha, vScale, vInputScale);
         return success();
       });
+  patterns.onOp(
+      "ReduceSum", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value data;
+        Value axes;
+        int64_t keepDims;
+        int64_t noop_with_empty_axes;
+        if (binder.tensorOperands(data, axes) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+        Torch::BaseTensorType axesType =
+            axes.getType().cast<Torch::BaseTensorType>();
+        SmallVector<Value> dimList;
+        SmallVector<int64_t> selectSizes;
+        selectSizes.push_back(1);
+        Type selectResultType = axesType.getWithSizesAndDtype(
+            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
+        auto sizes =
+            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        // Deal with case when axes is empty
+        if (sizes.size() == 1 && sizes[0] == 0) {
+          if (noop_with_empty_axes == 0) {
+            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
+            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
+                binder.getLoc(), keepDimsConstInt);
+            rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
+                binder.op, resultType, data, /*dim=*/noneVal,
+                /*keepdim=*/keepDimsBool, /*dtype=*/noneVal);
+          } else {
+            rewriter.replaceOp(binder.op, data);
+          }
+          return success();
+        }
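+        // The axis values live in a tensor and are only known at runtime, so
+        // the negative-axis fixup below is emitted as IR arithmetic rather
+        // than folded at compile time: finalDim = dim + (dim < 0) * rank.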
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        int64_t adjustmentInt =
+            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
+        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    adjustmentInt));
+        // convert axes (tensor) into torch int list while dealing with neg axis
+        for (int i = 0; i < sizes[0]; i++) {
+          // Go through the axes list and get each dim in the list
+          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+              binder.getLoc(), selectResultType, axes, zero, selectIndex);
+          Value dim = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+          // deal with neg axis: if (axis < 0) axis += rank
+          Value isNegative =
+              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
+          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
+                                                             isNegative);
+          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
+              binder.getLoc(), isNegative, adjustment);
+          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
+              binder.getLoc(), dim, finalOffset);
+          dimList.push_back(finalDim);
+        }
+        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            dimList);
+        Value keepDimBool;
+        if (keepDims == 1) {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        } else {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
+            binder.op, resultType, data, dimValueList, keepDimBool,
+            /*dtype=*/noneVal);
+        return success();
+      });
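+  // ReduceMean mirrors the ReduceSum lowering above; only the terminal op
+  // differs (AtenMeanDimOp instead of AtenSumDimIntListOp).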
+  patterns.onOp(
+      "ReduceMean", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value data;
+        Value axes;
+        int64_t keepDims;
+        int64_t noop_with_empty_axes;
+        if (binder.tensorOperands(data, axes) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+        Torch::BaseTensorType axesType =
+            axes.getType().cast<Torch::BaseTensorType>();
+        SmallVector<Value> dimList;
+        SmallVector<int64_t> selectSizes;
+        selectSizes.push_back(1);
+        Type selectResultType = axesType.getWithSizesAndDtype(
+            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
+        auto sizes =
+            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        // deal with case when axes is empty
+        if (sizes.size() == 1 && sizes[0] == 0) {
+          if (noop_with_empty_axes == 0) {
+            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
+            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
+                binder.getLoc(), keepDimsConstInt);
+            rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
+                binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
+                /*dtype=*/noneVal);
+          } else {
+            rewriter.replaceOp(binder.op, data);
+          }
+          return success();
+        }
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        int64_t adjustmentInt =
+            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
+        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    adjustmentInt));
+        // convert axes (tensor) into torch int list while dealing with neg axis
+        for (int i = 0; i < sizes[0]; i++) {
+          // Go through the axes list and get each dim in the list
+          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+              binder.getLoc(), selectResultType, axes, zero, selectIndex);
+          Value dim = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+          // deal with neg axis: if (axis < 0) axis += rank
+          Value isNegative =
+              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
+          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
+                                                             isNegative);
+          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
+              binder.getLoc(), isNegative, adjustment);
+          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
+              binder.getLoc(), dim, finalOffset);
+          dimList.push_back(finalDim);
+        }
+        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            dimList);
+        Value keepDimBool;
+        if (keepDims == 1) {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        } else {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
+            binder.op, resultType, data, dimValueList, keepDimBool,
+            /*dtype=*/noneVal);
+        return success();
+      });
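+  // ReduceMin lowers to AtenAminOp, which takes an explicit dim list; an
+  // absent or empty axes operand means every dim of the input is reduced.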
+  patterns.onOp(
+      "ReduceMin", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // AtenAminOp allows us to pass a list of dims
+        Torch::ValueTensorType resultType;
+        Value data;
+        Value axes;
+        int64_t keepDims;
+        int64_t noop_with_empty_axes;
+        // Deal with case when no axes arg is passed
+        if (binder.op->getNumOperands() == 1) {
+          if (binder.tensorOperand(data) ||
+              binder.tensorResultType(resultType) ||
+              binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+              binder.s64IntegerAttr(noop_with_empty_axes,
+                                    "noop_with_empty_axes", 0))
+            return failure();
+          if (noop_with_empty_axes == 0) {
+            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
+            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
+                binder.getLoc(), keepDimsConstInt);
+            int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
+                                  .getSizes()
+                                  .size();
+            SmallVector<Value> axesList;
+            for (int i = 0; i < numDims; i++) {
+              Value curr = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                  rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+              axesList.push_back(curr);
+            }
+            Value axesValueList = rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                axesList);
+            rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
+                binder.op, resultType, data, axesValueList, keepDimsBool);
+          } else {
+            rewriter.replaceOp(binder.op, data);
+          }
+          return success();
+        }
+        if (binder.tensorOperands(data, axes) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+        Torch::BaseTensorType axesType =
+            axes.getType().cast<Torch::BaseTensorType>();
+        SmallVector<Value> dimList;
+        SmallVector<int64_t> selectSizes;
+        selectSizes.push_back(1);
+        Type selectResultType = axesType.getWithSizesAndDtype(
+            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
+        auto sizes =
+            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
+        // deal with case when axes is empty
+        if (sizes.size() == 1 && sizes[0] == 0) {
+          if (noop_with_empty_axes == 0) {
+            // create dims list with all dims [0, data.getSizes().size())
+            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
+            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
+                binder.getLoc(), keepDimsConstInt);
+            int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
+                                  .getSizes()
+                                  .size();
+            for (int i = 0; i < numDims; i++) {
+              Value curr = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                  rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+              dimList.push_back(curr);
+            }
+            Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                dimList);
+            rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
+                binder.op, resultType, data, dimValueList, keepDimsBool);
+          } else {
+            rewriter.replaceOp(binder.op, data);
+          }
+          return success();
+        }
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        int64_t adjustmentInt =
+            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
+        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    adjustmentInt));
+        // convert axes (tensor) into torch int list while dealing with neg axis
+        for (int i = 0; i < sizes[0]; i++) {
+          // Go through the axes list and get each dim in the list
+          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+              binder.getLoc(), selectResultType, axes, zero, selectIndex);
+          Value dim = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+          // deal with neg axis: if (axis < 0) axis += rank
+          Value isNegative =
+              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
+          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
+                                                             isNegative);
+          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
+              binder.getLoc(), isNegative, adjustment);
+          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
+              binder.getLoc(), dim, finalOffset);
+          dimList.push_back(finalDim);
+        }
+        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            dimList);
+        Value keepDimBool;
+        if (keepDims == 1) {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        } else {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
+            binder.op, resultType, data, dimValueList, keepDimBool);
+        return success();
+      });
   patterns.onOp("Shape", 9,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -550,7 +860,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         }
 
         rewriter.replaceOp(binder.op, operand);
-
         return success();
       });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index e4c95fe2b..da2a5c44a 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -11,6 +11,8 @@ func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_relu
 func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
@@ -18,6 +20,8 @@ func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_round
 func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   //CHECK: torch.aten.round %arg0 : !torch.vtensor<[15],f32> -> !torch.vtensor<[15],f32>
@@ -25,6 +29,8 @@ func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f3
   return %0 : !torch.vtensor<[15],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_scatter_elements_with_axis
 func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
@@ -59,6 +65,8 @@ func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],
   return %0 : !torch.vtensor<[1,5],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_sigmoid_example
 func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.sigmoid %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
@@ -66,6 +74,8 @@ func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtenso
   return %0 : !torch.vtensor<[3],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_sin_example
 func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.sin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
@@ -73,6 +83,8 @@ func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
   return %0 : !torch.vtensor<[3],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_tanh_example
 func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.tanh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
@@ -80,6 +92,8 @@ func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
   return %0 : !torch.vtensor<[3],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_sqrt_example
 func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.sqrt %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
@@ -87,6 +101,8 @@ func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
   return %0 : !torch.vtensor<[3],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_sub_bcast
 func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
@@ -119,6 +135,8 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten
   return %0 : !torch.vtensor<[3,4,5],ui8>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_sum_example
 func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
@@ -143,6 +161,8 @@ func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vte
   return %0 : !torch.vtensor<[3],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_where_example
 func.func @test_where_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],f32>, %arg2: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
@@ -157,6 +177,8 @@ func.func @test_where_long_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torc
   return %0 : !torch.vtensor<[2,2],si64>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_xor2d
 func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
@@ -192,6 +214,8 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch
   return %0 : !torch.vtensor<[3,4,5,6],i1>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_squeeze
 func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT0:.*]] = torch.constant.int 0
@@ -233,6 +257,8 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1:
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_unsqueeze_axis_0
 func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT0:.*]] = torch.constant.int 0
@@ -421,6 +447,8 @@ func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg
   return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @test_softmax_axis_0
 func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT0:.*]] = torch.constant.int 0
@@ -489,6 +517,275 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
+func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
+  // CHECK: torch.aten.sum.dim_IntList %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1],f32>
+}
+
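+// The select.int/item/lt.int/Int.bool/mul.int/add.int sequence checked in the
+// tests below normalizes a possibly negative axis: dim + (dim < 0) * rank.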
+// CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example
+func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
+  %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example
+func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32>
+  return %0 : !torch.vtensor<[3,2,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero
+func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
+  %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
+  return %0 : !torch.vtensor<[2,0,1],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_sum_keepdims_example
+func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
+  %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example
+func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
+  %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_mean_default_axes_keepdims_example
+func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
+  // CHECK: torch.aten.mean.dim %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example
+func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: torch.aten.mean.dim %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
+  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example
+func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
+  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example
+func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
+  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
+func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
+  return %0 : !torch.vtensor<[4,1],i1>
+}
+
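+// When ReduceMin has no axes operand, the lowering enumerates every input dim
+// and reduces across all of them.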
+// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example
+func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1],f32>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example
+func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,2],f32>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_min_empty_set
+func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,1,4],f32>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
+  return %0 : !torch.vtensor<[2,1,4],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_min_keepdims_example
+func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example
+func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
+  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
+  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_sinh
 func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
   // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>