mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for Reduction Ops (#2657)
This commit adds the OnnxToTorch support for ReduceSum, ReduceMean, and ReduceMin ops.pull/2664/merge
parent
deacb8ef38
commit
698ff3a736
|
@ -459,6 +459,316 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, resultType, operand, vAlpha, vScale, vInputScale);
|
binder.op, resultType, operand, vAlpha, vScale, vInputScale);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"ReduceSum", 13,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value data;
|
||||||
|
Value axes;
|
||||||
|
int64_t keepDims;
|
||||||
|
int64_t noop_with_empty_axes;
|
||||||
|
if (binder.tensorOperands(data, axes) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||||
|
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||||
|
0))
|
||||||
|
return failure();
|
||||||
|
Torch::BaseTensorType axesType =
|
||||||
|
axes.getType().cast<Torch::BaseTensorType>();
|
||||||
|
SmallVector<Value> dimList;
|
||||||
|
SmallVector<int64_t> selectSizes;
|
||||||
|
selectSizes.push_back(1);
|
||||||
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||||
|
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
||||||
|
auto sizes =
|
||||||
|
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
||||||
|
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
// Deal with case when axes is empty
|
||||||
|
if (sizes.size() == 1 && sizes[0] == 0) {
|
||||||
|
if (noop_with_empty_axes == 0) {
|
||||||
|
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
||||||
|
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||||
|
binder.getLoc(), keepDimsConstInt);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
|
||||||
|
binder.op, resultType, data, /*dim=*/noneVal,
|
||||||
|
/*keepdim=*/keepDimsBool, /*dtype=*/noneVal);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOp(binder.op, data);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
int64_t adjustmentInt =
|
||||||
|
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||||
|
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
|
adjustmentInt));
|
||||||
|
// convert axes (tensor) into torch int list while dealing with neg axis
|
||||||
|
for (int i = 0; i < sizes[0]; i++) {
|
||||||
|
// Go through the axes list and get each dim in the list
|
||||||
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||||
|
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||||
|
// deal with neg axis: if (axis < 0) axis += rank
|
||||||
|
Value isNegative =
|
||||||
|
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
||||||
|
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||||
|
isNegative);
|
||||||
|
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||||
|
binder.getLoc(), isNegative, adjustment);
|
||||||
|
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||||
|
binder.getLoc(), dim, finalOffset);
|
||||||
|
dimList.push_back(finalDim);
|
||||||
|
}
|
||||||
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
|
dimList);
|
||||||
|
Value keepDimBool;
|
||||||
|
if (keepDims == 1) {
|
||||||
|
keepDimBool =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||||
|
} else {
|
||||||
|
keepDimBool =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
|
||||||
|
binder.op, resultType, data, dimValueList, keepDimBool,
|
||||||
|
/*dtype=*/noneVal);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"ReduceMean", 13,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value data;
|
||||||
|
Value axes;
|
||||||
|
int64_t keepDims;
|
||||||
|
int64_t noop_with_empty_axes;
|
||||||
|
if (binder.tensorOperands(data, axes) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||||
|
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||||
|
0))
|
||||||
|
return failure();
|
||||||
|
Torch::BaseTensorType axesType =
|
||||||
|
axes.getType().cast<Torch::BaseTensorType>();
|
||||||
|
SmallVector<Value> dimList;
|
||||||
|
SmallVector<int64_t> selectSizes;
|
||||||
|
selectSizes.push_back(1);
|
||||||
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||||
|
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
||||||
|
auto sizes =
|
||||||
|
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
||||||
|
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
// deal with case when axes is empty
|
||||||
|
if (sizes.size() == 1 && sizes[0] == 0) {
|
||||||
|
if (noop_with_empty_axes == 0) {
|
||||||
|
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
||||||
|
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||||
|
binder.getLoc(), keepDimsConstInt);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||||
|
binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
|
||||||
|
/*dtype=*/noneVal);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOp(binder.op, data);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
int64_t adjustmentInt =
|
||||||
|
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||||
|
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
|
adjustmentInt));
|
||||||
|
// convert axes (tensor) into torch int list while dealing with neg axis
|
||||||
|
for (int i = 0; i < sizes[0]; i++) {
|
||||||
|
// Go through the axes list and get each dim in the list
|
||||||
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||||
|
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||||
|
// deal with neg axis: if (axis < 0) axis += rank
|
||||||
|
Value isNegative =
|
||||||
|
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
||||||
|
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||||
|
isNegative);
|
||||||
|
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||||
|
binder.getLoc(), isNegative, adjustment);
|
||||||
|
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||||
|
binder.getLoc(), dim, finalOffset);
|
||||||
|
dimList.push_back(finalDim);
|
||||||
|
}
|
||||||
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
|
dimList);
|
||||||
|
Value keepDimBool;
|
||||||
|
if (keepDims == 1) {
|
||||||
|
keepDimBool =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||||
|
} else {
|
||||||
|
keepDimBool =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
|
||||||
|
binder.op, resultType, data, dimValueList, keepDimBool,
|
||||||
|
/*dtype=*/noneVal);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"ReduceMin", 13,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
// AtenAminOp allows us to pass a list of dims
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value data;
|
||||||
|
Value axes;
|
||||||
|
int64_t keepDims;
|
||||||
|
int64_t noop_with_empty_axes;
|
||||||
|
// Deal with case when no axes arg is passed
|
||||||
|
if (binder.op->getNumOperands() == 1) {
|
||||||
|
if (binder.tensorOperand(data) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||||
|
binder.s64IntegerAttr(noop_with_empty_axes,
|
||||||
|
"noop_with_empty_axes", 0))
|
||||||
|
return failure();
|
||||||
|
if (noop_with_empty_axes == 0) {
|
||||||
|
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
||||||
|
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||||
|
binder.getLoc(), keepDimsConstInt);
|
||||||
|
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
|
||||||
|
.getSizes()
|
||||||
|
.size();
|
||||||
|
SmallVector<Value> axesList;
|
||||||
|
for (int i = 0; i < numDims; i++) {
|
||||||
|
Value curr = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
axesList.push_back(curr);
|
||||||
|
}
|
||||||
|
Value axesValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(
|
||||||
|
Torch::IntType::get(binder.op->getContext())),
|
||||||
|
axesList);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
||||||
|
binder.op, resultType, data, axesValueList, keepDimsBool);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOp(binder.op, data);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (binder.tensorOperands(data, axes) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
||||||
|
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
||||||
|
0))
|
||||||
|
return failure();
|
||||||
|
Torch::BaseTensorType axesType =
|
||||||
|
axes.getType().cast<Torch::BaseTensorType>();
|
||||||
|
SmallVector<Value> dimList;
|
||||||
|
SmallVector<int64_t> selectSizes;
|
||||||
|
selectSizes.push_back(1);
|
||||||
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||||
|
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
||||||
|
auto sizes =
|
||||||
|
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
||||||
|
// deal with case when axes is empty
|
||||||
|
if (sizes.size() == 1 && sizes[0] == 0) {
|
||||||
|
if (noop_with_empty_axes == 0) {
|
||||||
|
// create dims list with all dims [0, data.getSizes().size())
|
||||||
|
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
||||||
|
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
||||||
|
binder.getLoc(), keepDimsConstInt);
|
||||||
|
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
|
||||||
|
.getSizes()
|
||||||
|
.size();
|
||||||
|
for (int i = 0; i < numDims; i++) {
|
||||||
|
Value curr = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
dimList.push_back(curr);
|
||||||
|
}
|
||||||
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(
|
||||||
|
Torch::IntType::get(binder.op->getContext())),
|
||||||
|
dimList);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
||||||
|
binder.op, resultType, data, dimValueList, keepDimsBool);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOp(binder.op, data);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
int64_t adjustmentInt =
|
||||||
|
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
||||||
|
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
|
adjustmentInt));
|
||||||
|
// convert axes (tensor) into torch int list while dealing with neg axis
|
||||||
|
for (int i = 0; i < sizes[0]; i++) {
|
||||||
|
// Go through the axes list and get each dim in the list
|
||||||
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
||||||
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
|
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
||||||
|
Value dim = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
||||||
|
// deal with neg axis: if (axis < 0) axis += rank
|
||||||
|
Value isNegative =
|
||||||
|
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
||||||
|
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
||||||
|
isNegative);
|
||||||
|
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||||
|
binder.getLoc(), isNegative, adjustment);
|
||||||
|
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||||
|
binder.getLoc(), dim, finalOffset);
|
||||||
|
dimList.push_back(finalDim);
|
||||||
|
}
|
||||||
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
|
dimList);
|
||||||
|
Value keepDimBool;
|
||||||
|
if (keepDims == 1) {
|
||||||
|
keepDimBool =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||||
|
} else {
|
||||||
|
keepDimBool =
|
||||||
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
||||||
|
binder.op, resultType, data, dimValueList, keepDimBool);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
|
||||||
patterns.onOp("Shape", 9,
|
patterns.onOp("Shape", 9,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
@ -550,7 +860,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(binder.op, operand);
|
rewriter.replaceOp(binder.op, operand);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,8 @@ func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
|
||||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_relu
|
// CHECK-LABEL: func.func @test_relu
|
||||||
func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.relu %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
// CHECK: torch.aten.relu %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
@ -18,6 +20,8 @@ func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
||||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_round
|
// CHECK-LABEL: func.func @test_round
|
||||||
func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
//CHECK: torch.aten.round %arg0 : !torch.vtensor<[15],f32> -> !torch.vtensor<[15],f32>
|
//CHECK: torch.aten.round %arg0 : !torch.vtensor<[15],f32> -> !torch.vtensor<[15],f32>
|
||||||
|
@ -25,6 +29,8 @@ func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f3
|
||||||
return %0 : !torch.vtensor<[15],f32>
|
return %0 : !torch.vtensor<[15],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_scatter_elements_with_axis
|
// CHECK-LABEL: func.func @test_scatter_elements_with_axis
|
||||||
func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
@ -59,6 +65,8 @@ func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],
|
||||||
return %0 : !torch.vtensor<[1,5],f32>
|
return %0 : !torch.vtensor<[1,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sigmoid_example
|
// CHECK-LABEL: func.func @test_sigmoid_example
|
||||||
func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.sigmoid %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
// CHECK: torch.aten.sigmoid %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||||
|
@ -66,6 +74,8 @@ func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtenso
|
||||||
return %0 : !torch.vtensor<[3],f32>
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sin_example
|
// CHECK-LABEL: func.func @test_sin_example
|
||||||
func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.sin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
// CHECK: torch.aten.sin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||||
|
@ -73,6 +83,8 @@ func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3
|
||||||
return %0 : !torch.vtensor<[3],f32>
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_tanh_example
|
// CHECK-LABEL: func.func @test_tanh_example
|
||||||
func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.tanh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
// CHECK: torch.aten.tanh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||||
|
@ -80,6 +92,8 @@ func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
|
||||||
return %0 : !torch.vtensor<[3],f32>
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sqrt_example
|
// CHECK-LABEL: func.func @test_sqrt_example
|
||||||
func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.sqrt %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
// CHECK: torch.aten.sqrt %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||||
|
@ -87,6 +101,8 @@ func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
|
||||||
return %0 : !torch.vtensor<[3],f32>
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sub_bcast
|
// CHECK-LABEL: func.func @test_sub_bcast
|
||||||
func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
@ -119,6 +135,8 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten
|
||||||
return %0 : !torch.vtensor<[3,4,5],ui8>
|
return %0 : !torch.vtensor<[3,4,5],ui8>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sum_example
|
// CHECK-LABEL: func.func @test_sum_example
|
||||||
func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
@ -143,6 +161,8 @@ func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vte
|
||||||
return %0 : !torch.vtensor<[3],f32>
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_where_example
|
// CHECK-LABEL: func.func @test_where_example
|
||||||
func.func @test_where_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],f32>, %arg2: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_where_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],f32>, %arg2: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
// CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32>
|
||||||
|
@ -157,6 +177,8 @@ func.func @test_where_long_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torc
|
||||||
return %0 : !torch.vtensor<[2,2],si64>
|
return %0 : !torch.vtensor<[2,2],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_xor2d
|
// CHECK-LABEL: func.func @test_xor2d
|
||||||
func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
|
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
|
||||||
|
@ -192,6 +214,8 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch
|
||||||
return %0 : !torch.vtensor<[3,4,5,6],i1>
|
return %0 : !torch.vtensor<[3,4,5,6],i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_squeeze
|
// CHECK-LABEL: func.func @test_squeeze
|
||||||
func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
@ -233,6 +257,8 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1:
|
||||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_unsqueeze_axis_0
|
// CHECK-LABEL: func.func @test_unsqueeze_axis_0
|
||||||
func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
@ -421,6 +447,8 @@ func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg
|
||||||
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
|
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_softmax_axis_0
|
// CHECK-LABEL: func.func @test_softmax_axis_0
|
||||||
func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
@ -489,6 +517,275 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
|
||||||
|
func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.sum.dim_IntList %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example
|
||||||
|
func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: torch.aten.sum.dim_IntList %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example
|
||||||
|
func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero
|
||||||
|
func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,0,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_keepdims_example
|
||||||
|
func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example
|
||||||
|
func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_mean_default_axes_keepdims_example
|
||||||
|
func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.mean.dim %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example
|
||||||
|
func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: torch.aten.mean.dim %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example
|
||||||
|
func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example
|
||||||
|
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
|
||||||
|
func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
|
||||||
|
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
|
||||||
|
return %0 : !torch.vtensor<[4,1],i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example
|
||||||
|
func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example
|
||||||
|
func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_min_empty_set
|
||||||
|
func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,1,4],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,1,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_min_keepdims_example
|
||||||
|
func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example
|
||||||
|
func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
|
||||||
|
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
|
||||||
|
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_sinh
|
// CHECK-LABEL: func.func @test_sinh
|
||||||
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
|
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
|
||||||
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||||
|
|
Loading…
Reference in New Issue