mirror of https://github.com/llvm/torch-mlir
[onnx] Fix onnx.ReduceMean lowering (#3002)
Reduce mean lowerings did not succesfully lower to `linalg` via torched. There were two separate paths that could be consolidated to a single simpler pass. This resulted in a significant improvement in test coverage.pull/2920/head
parent
229ca3a9e1
commit
8fb28661f9
|
@ -845,157 +845,96 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
/*dtype=*/noneVal);
|
||||
return success();
|
||||
});
|
||||
// onnx.ReduceMean with axes provided as argument introduced in opset 18
|
||||
patterns.onOp(
|
||||
"ReduceMean", 18,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
Value axes;
|
||||
int64_t keepDims;
|
||||
int64_t noop_with_empty_axes;
|
||||
if (binder.tensorOperands(data, axes) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||
0))
|
||||
return failure();
|
||||
Torch::BaseTensorType axesType =
|
||||
axes.getType().cast<Torch::BaseTensorType>();
|
||||
SmallVector<Value> dimList;
|
||||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
||||
auto sizes =
|
||||
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
// deal with case when axes is empty
|
||||
if (sizes.size() == 1 && sizes[0] == 0) {
|
||||
if (noop_with_empty_axes == 0) {
|
||||
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
||||
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||
binder.getLoc(), keepDimsConstInt);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||
binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
|
||||
/*dtype=*/noneVal);
|
||||
} else {
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
int64_t adjustmentInt =
|
||||
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
adjustmentInt));
|
||||
// convert axes (tensor) into torch int list while dealing with neg axis
|
||||
for (int i = 0; i < sizes[0]; i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
// deal with neg axis: if (axis < 0) axis += rank
|
||||
Value isNegative =
|
||||
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
||||
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||
isNegative);
|
||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), isNegative, adjustment);
|
||||
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), dim, finalOffset);
|
||||
dimList.push_back(finalDim);
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
dimList);
|
||||
Value keepDimBool;
|
||||
if (keepDims == 1) {
|
||||
keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
} else {
|
||||
keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||
binder.op, resultType, data, dimValueList, keepDimBool,
|
||||
/*dtype=*/noneVal);
|
||||
return success();
|
||||
});
|
||||
|
||||
// onnx.ReduceMean with axes provided as attribute
|
||||
patterns.onOp(
|
||||
"ReduceMean", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value data;
|
||||
llvm::SmallVector<int64_t> axes;
|
||||
int64_t keepDims;
|
||||
int64_t noop_with_empty_axes;
|
||||
if (binder.tensorOperand(data) || binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerArrayAttr(axes, "axes", 0) ||
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||
0))
|
||||
return failure();
|
||||
SmallVector<Value> dimList;
|
||||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
// deal with case when axes is empty
|
||||
if (axes.size() == 0) {
|
||||
if (noop_with_empty_axes == 0) {
|
||||
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
||||
SmallVector<Value> axesList;
|
||||
|
||||
Value axesVal;
|
||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||
Torch::BaseTensorType axesType =
|
||||
axesVal.getType().cast<Torch::BaseTensorType>();
|
||||
SmallVector<Value> dimList;
|
||||
SmallVector<int64_t> selectSizes{1};
|
||||
auto selType = rewriter.getType<Torch::ValueTensorType>(
|
||||
selectSizes, axesType.getOptionalDtype());
|
||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||
auto axesShape = axesTy.getSizes();
|
||||
|
||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||
return failure();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
int64_t numAxes = axesShape[0];
|
||||
for (int64_t i = 0; i < numAxes; ++i) {
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
||||
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||
binder.getLoc(), keepDimsConstInt);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||
binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
|
||||
/*dtype=*/noneVal);
|
||||
} else {
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
rewriter.getI64IntegerAttr(i));
|
||||
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
binder.getLoc(), selType, axesVal, zero, iv);
|
||||
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||
axesList.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<int64_t> axesInts;
|
||||
if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
|
||||
for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(axesInts[i]));
|
||||
axesList.push_back(iv);
|
||||
}
|
||||
}
|
||||
|
||||
// deal with case when axes is empty
|
||||
if (axesList.empty() && noop_with_empty_axes) {
|
||||
rewriter.replaceOp(binder.op, data);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
int64_t adjustmentInt =
|
||||
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||
// convert axes (tensor) into torch int list while dealing with neg axis
|
||||
for (uint64_t i = 0; i < axes.size(); i++) {
|
||||
// Go through the axes list and get each dim in the list
|
||||
int64_t dim = axes[i];
|
||||
if (dim < 0) {
|
||||
dim += adjustmentInt;
|
||||
}
|
||||
// deal with neg axis: if (axis < 0) axis += rank
|
||||
Value finalDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
|
||||
dimList.push_back(finalDim);
|
||||
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(adjustmentInt));
|
||||
|
||||
// Handle if the axes value is less than zero:
|
||||
for (int i = 0, s = axesList.size(); i < s; i++) {
|
||||
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
|
||||
binder.getLoc(), axesList[i], zero);
|
||||
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||
isNegative);
|
||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
binder.getLoc(), isNegative, adjustment);
|
||||
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), axesList[i], finalOffset);
|
||||
axesList[i] = finalDim;
|
||||
}
|
||||
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
dimList);
|
||||
Value keepDimBool;
|
||||
if (keepDims == 1) {
|
||||
keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
} else {
|
||||
keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
}
|
||||
axesList);
|
||||
Value keepDimBool =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
|
||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||
binder.op, resultType, data, dimValueList, keepDimBool,
|
||||
/*dtype=*/noneVal);
|
||||
|
|
|
@ -1474,15 +1474,7 @@ LTC_XFAIL_SET = {
|
|||
|
||||
ONNX_XFAIL_SET = {
|
||||
# Failure - cast error
|
||||
"MeanDimNoneDimModule_basic",
|
||||
"MeanDtypeModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanModule_basic",
|
||||
"MseLossMeanReductionModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"StdBiasedModule_basic",
|
||||
"VarBiasedModule_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
|
||||
# Failure - incorrect numerics
|
||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||
|
@ -1992,17 +1984,7 @@ ONNX_XFAIL_SET = {
|
|||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"ReduceProdDimIntFloatModule_basic",
|
||||
"StdCorrectionLargeInputModule_basic",
|
||||
"StdCorrectionModule_basic",
|
||||
"StdCorrectionNoneModule_basic",
|
||||
"StdDimNoneDimModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"VarCorrectionLargeInputModule_basic",
|
||||
"VarCorrectionModule_basic",
|
||||
"VarCorrectionNoneModule_basic",
|
||||
"VarDimNoneDimModule_basic",
|
||||
"VarMeanCorrectionNoneModule_basic",
|
||||
"VarMeanUnbiasedModule_basic",
|
||||
"VarUnbiasedModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.ReduceSum
|
||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||
|
|
|
@ -969,77 +969,71 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_default_axes_keepdims_example
|
||||
func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.mean.dim %arg0, %[[NONE]], %0, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
||||
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||
// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example
|
||||
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIM:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[A0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]]
|
||||
// CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]]
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[RANK:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]]
|
||||
// CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]]
|
||||
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]]
|
||||
// CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]]
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]]
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[TRUE]], %[[NONE]]
|
||||
// CHECK: return %[[SUM]]
|
||||
%cst = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example
|
||||
func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: torch.aten.mean.dim %arg0, %6, %false, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||
// CHECK-LABEL: @test_reduce_mean_one_axes_dropdims_example
|
||||
func.func @test_reduce_mean_one_axes_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
// CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIM:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[A0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]]
|
||||
// CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]]
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[RANK:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]]
|
||||
// CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]]
|
||||
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]]
|
||||
// CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]]
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]]
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]]
|
||||
// CHECK: return %[[SUM]]
|
||||
%cst = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||
return %0 : !torch.vtensor<[3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example
|
||||
func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK-LABEL: @test_reduce_mean_one_axesattr_dropdims_example
|
||||
func.func @test_reduce_mean_one_axesattr_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example
|
||||
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]]
|
||||
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]]
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[INT3]]
|
||||
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]]
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]]
|
||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||
// CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||
// CHECK: %[[MEAN:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]]
|
||||
// CHECK: return %[[MEAN]]
|
||||
%0 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes = [1 : si64]} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32>
|
||||
return %0 : !torch.vtensor<[3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -1387,11 +1381,11 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3
|
|||
// CHECK: %[[ZERO0:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[ZERO1:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK-NEXT: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ZERO1]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
|
||||
|
||||
|
|
Loading…
Reference in New Issue