[MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op (#2954)

This commit adds the support for scalar conversion to byte. 
This commit also fixes the OnnxToLinalg lowering issue for Onnx.Sub and
Onnx.Sum op.
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/466 
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/467

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/2968/head
Vivek Khandelwal 2024-02-29 21:48:46 +05:30 committed by GitHub
parent 76b81e0ccd
commit 579ac8b666
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 104 additions and 14 deletions

View File

@ -88,7 +88,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt,
std::optional<Type> dstOriginalDtype = std::nullopt);
std::optional<Type> dstOriginalDtype = std::nullopt,
std::optional<Value> originalScalar = std::nullopt);
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,

View File

@ -489,8 +489,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success();
}
// When binder.op->getNumOperands() > 2
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
binder.op->getContext());
Value curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, valList[0], valList[1], const1);
for (int i = 2; i < numOperands; i++) {
@ -498,6 +496,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, curr, valList[i], const1);
} else {
SmallVector<int64_t> resultBroadcastShapeInt;
SmallVector<Value> resultBroadcastShapeValue;
Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
valList[i], resultBroadcastShapeInt,
resultBroadcastShapeValue);
auto baseType = Torch::ValueTensorType::get(
binder.op->getContext(), resultBroadcastShapeInt,
resultType.getOptionalDtype());
curr = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), baseType, curr, valList[i], const1);
}

View File

@ -645,7 +645,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
/*dstOriginalDtype=*/resultElementType,
/*originalScalar=*/sub.getAlpha());
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);

View File

@ -245,12 +245,20 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
elementType, encoding);
}
static std::optional<int64_t> getIntegerValue(Value scalar) {
if (auto constOp = scalar.getDefiningOp<Torch::ConstantIntOp>()) {
return std::optional<int64_t>(constOp.getValue());
}
return std::optional<int64_t>();
}
// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype,
std::optional<Type> dstOriginalDtype) {
std::optional<Type> dstOriginalDtype,
std::optional<Value> originalScalar) {
Type scalarType = scalar.getType();
if (scalarType == dtype)
return scalar;
@ -262,7 +270,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return false;
};
// We don't support conversion to Byte dtype.
// We support conversion to Byte dtype only if the original scalar is an
// integer constant with value lying between 0 - 63.
if (isByteOrChar(dtype)) {
if (!dstOriginalDtype.has_value()) {
mlir::emitError(loc)
@ -271,12 +280,24 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
return nullptr;
}
if (dstOriginalDtype->isUnsignedInteger()) {
if (originalScalar.has_value()) {
std::optional<int64_t> optConstVal =
getIntegerValue(originalScalar.value());
if (optConstVal.has_value()) {
int64_t constVal = optConstVal.value();
if (constVal < 0 || constVal > 63) {
// Do the conversion only if the original integer value is between
// 0 - 63.
mlir::emitError(loc)
<< "unsupported: conversion to byte type for convertScalarToDtype "
<< "unsupported: conversion to byte type for "
"convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
}
}
}
}
// If the dtype is i1, i.e., a boolean type.
if (dtype.isSignlessInteger(1)) {

View File

@ -223,6 +223,8 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
return %0 : !torch.vtensor<[1,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -232,6 +234,8 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1
return %0 : !torch.vtensor<[1,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_scatter_elements_without_axis
func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -240,6 +244,8 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
return %0 : !torch.vtensor<[3,3],f32>
}
// -----
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -295,6 +301,8 @@ func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_sub_example
func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -303,6 +311,8 @@ func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtenso
return %0 : !torch.vtensor<[3],f32>
}
// -----
// CHECK-LABEL: func.func @test_sub
func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -311,6 +321,8 @@ func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_sub_uint8
func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -324,19 +336,23 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten
// CHECK-LABEL: func.func @test_sum_example
func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.add.Tensor %0, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor
// CHECK: torch.aten.add.Tensor %1, %arg3, %int1 : !torch.vtensor, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SUM:.*]] = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SUM_1:.*]] = torch.aten.add.Tensor %[[SUM]], %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SUM_2:.*]] = torch.aten.add.Tensor %[[SUM_1]], %arg3, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// -----
// CHECK-LABEL: func.func @test_sum_one_input
func.func @test_sum_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.Sum"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// -----
// CHECK-LABEL: func.func @test_sum_two_inputs
func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -370,6 +386,8 @@ func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,
return %0 : !torch.vtensor<[3,4],i1>
}
// -----
// CHECK-LABEL: func.func @test_xor3d
func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1> -> !torch.vtensor<[3,4,5],i1>
@ -377,6 +395,8 @@ func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[
return %0 : !torch.vtensor<[3,4,5],i1>
}
// -----
// CHECK-LABEL: func.func @test_xor4d
func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1>
@ -384,6 +404,8 @@ func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor
return %0 : !torch.vtensor<[3,4,5,6],i1>
}
// -----
// CHECK-LABEL: func.func @test_xor_bcast3v1d
func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1>
@ -391,6 +413,8 @@ func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.v
return %0 : !torch.vtensor<[3,4,5],i1>
}
// -----
// CHECK-LABEL: func.func @test_xor_bcast4v4d
func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1>
@ -417,6 +441,8 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_squeeze_two_axes
func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -467,6 +493,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
return %0 : !torch.vtensor<[1,3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_axis_1
func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -491,6 +519,8 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
return %0 : !torch.vtensor<[3,1,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_axis_2
func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -515,6 +545,8 @@ func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
return %0 : !torch.vtensor<[3,4,1,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_negative_axes
func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -539,6 +571,8 @@ func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %a
return %0 : !torch.vtensor<[1,3,1,1,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_three_axes
func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -585,6 +619,8 @@ func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1:
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes
func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
@ -642,6 +678,8 @@ func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_softmax_axis_1
func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -651,6 +689,8 @@ func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_softmax_axis_2
func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -660,6 +700,8 @@ func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_softmax_default_axis
func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -669,6 +711,8 @@ func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !tor
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_softmax_large_number
func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
@ -678,6 +722,8 @@ func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch
return %0 : !torch.vtensor<[2,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_softmax_negative_axis
func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT2:.*]] = torch.constant.int 2
@ -773,6 +819,8 @@ func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[
return %0 : !torch.vtensor<[1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example
func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -792,12 +840,16 @@ func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2]
return %0 : !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example
func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32>
return %0 : !torch.vtensor<[3,2,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero
func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -817,6 +869,8 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor
return %0 : !torch.vtensor<[2,0,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_keepdims_example
func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -836,6 +890,8 @@ func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>,
return %0 : !torch.vtensor<[3,1,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example
func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -867,6 +923,8 @@ func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<
return %0 : !torch.vtensor<[1,1,1],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example
func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -886,6 +944,8 @@ func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2
return %0 : !torch.vtensor<[3,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example
func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -905,6 +965,8 @@ func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>,
return %0 : !torch.vtensor<[3,1,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
@ -958,7 +1020,6 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a
// -----
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[IDX:.+]] = torch.constant.int 0