From 8fb28661f9168c7b76a691125d8ebdff1732f920 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 11 Mar 2024 11:32:53 -0700
Subject: [PATCH] [onnx] Fix onnx.ReduceMean lowering (#3002)

Reduce mean lowerings did not successfully lower to `linalg` via
`torch`. There were two separate paths that could be consolidated into
a single, simpler pass. This resulted in a significant improvement in
test coverage.
---
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp  | 199 ++++++------
 projects/pt1/e2e_testing/xfail_sets.py      |  18 --
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 120 +++++------
 3 files changed, 126 insertions(+), 211 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 34282bfef..b5e9162bc 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -845,157 +845,96 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*dtype=*/noneVal);
         return success();
       });
-  // onnx.ReduceMean with axes provided as argument introduced in opset 18
-  patterns.onOp(
-      "ReduceMean", 18,
-      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value data;
-        Value axes;
-        int64_t keepDims;
-        int64_t noop_with_empty_axes;
-        if (binder.tensorOperands(data, axes) ||
-            binder.tensorResultType(resultType) ||
-            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
-            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
-                                  0))
-          return failure();
-        Torch::BaseTensorType axesType =
-            axes.getType().cast<Torch::BaseTensorType>();
-        SmallVector<Value> dimList;
-        SmallVector<int64_t> selectSizes;
-        selectSizes.push_back(1);
-        Type selectResultType = axesType.getWithSizesAndDtype(
-            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
-        auto sizes =
-            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
-        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        // deal with case when axes is empty
-        if (sizes.size() == 1 && sizes[0] == 0) {
-          if (noop_with_empty_axes == 0) {
-            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
-            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
-                binder.getLoc(), keepDimsConstInt);
-            rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
-                binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
-                /*dtype=*/noneVal);
-          } else {
-            rewriter.replaceOp(binder.op, data);
-          }
-          return success();
-        }
-        Value zero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        int64_t adjustmentInt =
-            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
-        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    adjustmentInt));
-        // convert axes (tensor) into torch int list while dealing with neg axis
-        for (int i = 0; i < sizes[0]; i++) {
-          // Go through the axes list and get each dim in the list
-          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
-              binder.getLoc(), selectResultType, axes, zero, selectIndex);
-          Value dim = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
-          // deal with neg axis: if (axis < 0) axis += rank
-          Value isNegative =
-              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
-          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
-                                                             isNegative);
-          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
-              binder.getLoc(), isNegative, adjustment);
-          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
-              binder.getLoc(), dim, finalOffset);
-          dimList.push_back(finalDim);
-        }
-        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            dimList);
-        Value keepDimBool;
-        if (keepDims == 1) {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
-        } else {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        }
-        rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
-            binder.op, resultType, data, dimValueList, keepDimBool,
-            /*dtype=*/noneVal);
-        return success();
-      });
-
-  // onnx.ReduceMean with axes provided as attribute
   patterns.onOp(
       "ReduceMean", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value data;
-        llvm::SmallVector<int64_t> axes;
         int64_t keepDims;
         int64_t noop_with_empty_axes;
-        if (binder.tensorOperand(data) || binder.tensorResultType(resultType) ||
-            binder.s64IntegerArrayAttr(axes, "axes", 0) ||
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorResultType(resultType) ||
             binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
             binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                   0))
           return failure();
-        SmallVector<Value> dimList;
-        SmallVector<int64_t> selectSizes;
-        selectSizes.push_back(1);
-        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-        // deal with case when axes is empty
-        if (axes.size() == 0) {
-          if (noop_with_empty_axes == 0) {
-            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
+
+        SmallVector<Value> axesList;
+
+        Value axesVal;
+        if (!binder.tensorOperandAtIndex(axesVal, 1)) {
+          Torch::BaseTensorType axesType =
+              axesVal.getType().cast<Torch::BaseTensorType>();
+          SmallVector<Value> dimList;
+          SmallVector<int64_t> selectSizes{1};
+          auto selType = rewriter.getType<Torch::ValueTensorType>(
+              selectSizes, axesType.getOptionalDtype());
+          auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
+          auto axesShape = axesTy.getSizes();
+
+          if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
+            return failure();
+
+          Value zero = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getI64IntegerAttr(0));
+          int64_t numAxes = axesShape[0];
+          for (int64_t i = 0; i < numAxes; ++i) {
+            Value iv = rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
-            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
-                binder.getLoc(), keepDimsConstInt);
-            rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
-                binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
-                /*dtype=*/noneVal);
-          } else {
-            rewriter.replaceOp(binder.op, data);
+                rewriter.getI64IntegerAttr(i));
+            Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+                binder.getLoc(), selType, axesVal, zero, iv);
+            Value dim = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+            axesList.push_back(dim);
           }
+        }
+
+        SmallVector<int64_t> axesInts;
+        if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
+          for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
+            Value iv = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getI64IntegerAttr(axesInts[i]));
+            axesList.push_back(iv);
+          }
+        }
+
+        // deal with case when axes is empty
+        if (axesList.empty() && noop_with_empty_axes) {
+          rewriter.replaceOp(binder.op, data);
           return success();
         }
+
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getI64IntegerAttr(0));
         int64_t adjustmentInt =
            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
-        // convert axes (tensor) into torch int list while dealing with neg axis
-        for (uint64_t i = 0; i < axes.size(); i++) {
-          // Go through the axes list and get each dim in the list
-          int64_t dim = axes[i];
-          if (dim < 0) {
-            dim += adjustmentInt;
-          }
-          // deal with neg axis: if (axis < 0) axis += rank
-          Value finalDim = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
-          dimList.push_back(finalDim);
+        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getI64IntegerAttr(adjustmentInt));
+
+        // Handle if the axes value is less than zero:
+        for (int i = 0, s = axesList.size(); i < s; i++) {
+          Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
+              binder.getLoc(), axesList[i], zero);
+          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
+                                                             isNegative);
+          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
+              binder.getLoc(), isNegative, adjustment);
+          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
+              binder.getLoc(), axesList[i], finalOffset);
+          axesList[i] = finalDim;
         }
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            dimList);
-        Value keepDimBool;
-        if (keepDims == 1) {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
-        } else {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        }
+            axesList);
+        Value keepDimBool =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
         rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
             binder.op, resultType, data, dimValueList, keepDimBool,
             /*dtype=*/noneVal);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index dd4976018..428c19788 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1474,15 +1474,7 @@ LTC_XFAIL_SET = {
 ONNX_XFAIL_SET = {
     # Failure - cast error
-    "MeanDimNoneDimModule_basic",
-    "MeanDtypeModule_basic",
-    "MeanDynamicSizesModule_basic",
-    "MeanModule_basic",
-    "MseLossMeanReductionModule_basic",
     "PermuteNegativeIndexModule_basic",
-    "StdBiasedModule_basic",
-    "VarBiasedModule_basic",
-    "VarMeanBiasedModule_basic",
 
     # Failure - incorrect numerics
     "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
@@ -1992,17 +1984,7 @@ ONNX_XFAIL_SET = {
     "NativeDropoutTrainStaticShapeModule_basic",
     "ReduceProdDimIntFloatModule_basic",
     "StdCorrectionLargeInputModule_basic",
-    "StdCorrectionModule_basic",
-    "StdCorrectionNoneModule_basic",
-    "StdDimNoneDimModule_basic",
-    "StdUnbiasedModule_basic",
     "VarCorrectionLargeInputModule_basic",
-    "VarCorrectionModule_basic",
-    "VarCorrectionNoneModule_basic",
-    "VarDimNoneDimModule_basic",
-    "VarMeanCorrectionNoneModule_basic",
-    "VarMeanUnbiasedModule_basic",
-    "VarUnbiasedModule_basic",
 
     # Failure - onnx_lowering: onnx.ReduceSum
     "MseLossSumReductionWithDifferentElemTypeModule_basic",
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index bba74b6d9..508ed55d3 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -969,77 +969,71 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
 
 // -----
 
-// CHECK-LABEL: func.func @test_reduce_mean_default_axes_keepdims_example
-func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
-  // CHECK: torch.aten.mean.dim %arg0, %[[NONE]], %0, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
-  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
-  return %0 : !torch.vtensor<[1,1,1],f32>
+// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
+func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  // CHECK: %[[DIM:.+]] = torch.constant.int 0
+  // CHECK: %[[A0:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]]
+  // CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]]
+  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK: %[[RANK:.+]] = torch.constant.int 3
+  // CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]]
+  // CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]]
+  // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]]
+  // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]]
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[TRUE]], %[[NONE]]
+  // CHECK: return %[[SUM]]
+  %cst = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
+  return %0 : !torch.vtensor<[3,1,2],f32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example
-func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: torch.aten.mean.dim %arg0, %6, %false, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
-  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+// CHECK-LABEL: @test_reduce_mean_one_axes_dropdims_example
+func.func @test_reduce_mean_one_axes_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+  // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  // CHECK: %[[DIM:.+]] = torch.constant.int 0
+  // CHECK: %[[A0:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]]
+  // CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]]
+  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK: %[[RANK:.+]] = torch.constant.int 3
+  // CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]]
+  // CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]]
+  // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]]
+  // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]]
+  // CHECK: return %[[SUM]]
+  %cst = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
   return %0 : !torch.vtensor<[3,2],f32>
 }
 
 // -----
 
-// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example
-func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[NONE:.+]] = torch.constant.none
+// CHECK-LABEL: @test_reduce_mean_one_axesattr_dropdims_example
+func.func @test_reduce_mean_one_axesattr_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
-  // CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
-  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
-  return %0 : !torch.vtensor<[3,1,2],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example
-func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]]
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]]
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[INT3]]
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]]
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
   // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
-  // CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
-  %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
-  return %0 : !torch.vtensor<[3,1,2],f32>
+  // CHECK: %[[MEAN:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]]
+  // CHECK: return %[[MEAN]]
+  %0 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes = [1 : si64]} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
 }
 
 // -----
 
@@ -1387,11 +1381,11 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
 // CHECK: %[[ZERO0:.*]] = torch.constant.int 0
 // CHECK-NEXT: %[[ZERO1:.*]] = torch.constant.int 0
 // CHECK-NEXT: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
-// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
 // CHECK-NEXT: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
-// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
 // CHECK-NEXT: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
-// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
 // CHECK-NEXT: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
 // CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ZERO1]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
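Note on the consolidated lowering above: both the axes-as-operand path (opset 18) and the axes-as-attribute path (opset 1) now feed a single `axesList`, and negative axes are normalized with emitted torch ops rather than folded at conversion time, so axis values only known at runtime take the same code path. The standalone sketch below illustrates the arithmetic those ops compute; `normalizeAxis` is a hypothetical helper written for this note, not code from the patch.

// normalize_axis_sketch.cpp -- illustrative only. Computes
// finalDim = dim + (dim < 0) * rank, which the pattern emits as
// torch.aten.lt.int, torch.aten.Int.bool, torch.aten.mul.int and
// torch.aten.add.int on !torch.int values.
#include <cassert>
#include <cstdint>

int64_t normalizeAxis(int64_t dim, int64_t rank) {
  int64_t isNegative = (dim < 0) ? 1 : 0; // lt.int then Int.bool
  int64_t offset = isNegative * rank;     // mul.int
  return dim + offset;                    // add.int
}

int main() {
  // Mirrors @test_reduce_mean_negative_axes_keepdims_example: axis -2 of a
  // rank-3 tensor maps to dimension 1.
  assert(normalizeAxis(-2, 3) == 1);
  // Non-negative axes pass through unchanged, so no branch is needed in the
  // emitted IR.
  assert(normalizeAxis(1, 3) == 1);
  return 0;
}

Because the normalization is branch-free, the same emitted op sequence is valid whether the axis comes from a constant attribute, a literal tensor, or a value produced at runtime.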