mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op (#2954)
This commit adds the support for scalar conversion to byte. This commit also fixes the OnnxToLinalg lowering issue for Onnx.Sub and Onnx.Sum op. Fixes https://github.com/nod-ai/SHARK-Turbine/issues/466 Fixes https://github.com/nod-ai/SHARK-Turbine/issues/467 Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/2968/head
parent
76b81e0ccd
commit
579ac8b666
|
@ -88,7 +88,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
|
|||
// should be converted builtin types.
|
||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||
std::optional<Type> srcOriginalDtype = std::nullopt,
|
||||
std::optional<Type> dstOriginalDtype = std::nullopt);
|
||||
std::optional<Type> dstOriginalDtype = std::nullopt,
|
||||
std::optional<Value> originalScalar = std::nullopt);
|
||||
|
||||
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value torchOptionalInt, Value builtinInt,
|
||||
|
|
|
@ -489,8 +489,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
return success();
|
||||
}
|
||||
// When binder.op->getNumOperands() > 2
|
||||
auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
|
||||
binder.op->getContext());
|
||||
Value curr = rewriter.create<Torch::AtenAddTensorOp>(
|
||||
binder.getLoc(), resultType, valList[0], valList[1], const1);
|
||||
for (int i = 2; i < numOperands; i++) {
|
||||
|
@ -498,6 +496,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
curr = rewriter.create<Torch::AtenAddTensorOp>(
|
||||
binder.getLoc(), resultType, curr, valList[i], const1);
|
||||
} else {
|
||||
SmallVector<int64_t> resultBroadcastShapeInt;
|
||||
SmallVector<Value> resultBroadcastShapeValue;
|
||||
Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
|
||||
valList[i], resultBroadcastShapeInt,
|
||||
resultBroadcastShapeValue);
|
||||
auto baseType = Torch::ValueTensorType::get(
|
||||
binder.op->getContext(), resultBroadcastShapeInt,
|
||||
resultType.getOptionalDtype());
|
||||
curr = rewriter.create<Torch::AtenAddTensorOp>(
|
||||
binder.getLoc(), baseType, curr, valList[i], const1);
|
||||
}
|
||||
|
|
|
@ -645,7 +645,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
/*dstOriginalDtype=*/resultElementType);
|
||||
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
/*dstOriginalDtype=*/resultElementType,
|
||||
/*originalScalar=*/sub.getAlpha());
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||
return b.create<arith::SubFOp>(loc, lhs, scaled);
|
||||
|
|
|
@ -245,12 +245,20 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
|
|||
elementType, encoding);
|
||||
}
|
||||
|
||||
static std::optional<int64_t> getIntegerValue(Value scalar) {
|
||||
if (auto constOp = scalar.getDefiningOp<Torch::ConstantIntOp>()) {
|
||||
return std::optional<int64_t>(constOp.getValue());
|
||||
}
|
||||
return std::optional<int64_t>();
|
||||
}
|
||||
|
||||
// Convert a scalar value to the target type. The scalar value can be an element
|
||||
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
|
||||
// should be converted builtin types.
|
||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||
std::optional<Type> srcOriginalDtype,
|
||||
std::optional<Type> dstOriginalDtype) {
|
||||
std::optional<Type> dstOriginalDtype,
|
||||
std::optional<Value> originalScalar) {
|
||||
Type scalarType = scalar.getType();
|
||||
if (scalarType == dtype)
|
||||
return scalar;
|
||||
|
@ -262,7 +270,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
return false;
|
||||
};
|
||||
|
||||
// We don't support conversion to Byte dtype.
|
||||
// We support conversion to Byte dtype only if the original scalar is an
|
||||
// integer constant with value lying between 0 - 63.
|
||||
if (isByteOrChar(dtype)) {
|
||||
if (!dstOriginalDtype.has_value()) {
|
||||
mlir::emitError(loc)
|
||||
|
@ -271,10 +280,22 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
return nullptr;
|
||||
}
|
||||
if (dstOriginalDtype->isUnsignedInteger()) {
|
||||
mlir::emitError(loc)
|
||||
<< "unsupported: conversion to byte type for convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
|
||||
return nullptr;
|
||||
if (originalScalar.has_value()) {
|
||||
std::optional<int64_t> optConstVal =
|
||||
getIntegerValue(originalScalar.value());
|
||||
if (optConstVal.has_value()) {
|
||||
int64_t constVal = optConstVal.value();
|
||||
if (constVal < 0 || constVal > 63) {
|
||||
// Do the conversion only if the original integer value is between
|
||||
// 0 - 63.
|
||||
mlir::emitError(loc)
|
||||
<< "unsupported: conversion to byte type for "
|
||||
"convertScalarToDtype "
|
||||
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -223,6 +223,8 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
|
|||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
|
||||
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -232,6 +234,8 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1
|
|||
return %0 : !torch.vtensor<[1,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_without_axis
|
||||
func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -240,6 +244,8 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
|
|||
return %0 : !torch.vtensor<[3,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
|
||||
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -295,6 +301,8 @@ func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_sub_example
|
||||
func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -303,6 +311,8 @@ func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtenso
|
|||
return %0 : !torch.vtensor<[3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_sub
|
||||
func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -311,6 +321,8 @@ func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_sub_uint8
|
||||
func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -324,19 +336,23 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten
|
|||
// CHECK-LABEL: func.func @test_sum_example
|
||||
func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
|
||||
// CHECK: torch.aten.add.Tensor %0, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor
|
||||
// CHECK: torch.aten.add.Tensor %1, %arg3, %int1 : !torch.vtensor, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
|
||||
// CHECK: %[[SUM:.*]] = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
|
||||
// CHECK: %[[SUM_1:.*]] = torch.aten.add.Tensor %[[SUM]], %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
|
||||
// CHECK: %[[SUM_2:.*]] = torch.aten.add.Tensor %[[SUM_1]], %arg3, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
|
||||
%0 = torch.operator "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
|
||||
return %0 : !torch.vtensor<[3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_sum_one_input
|
||||
func.func @test_sum_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%0 = torch.operator "onnx.Sum"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
|
||||
return %0 : !torch.vtensor<[3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_sum_two_inputs
|
||||
func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -370,6 +386,8 @@ func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,
|
|||
return %0 : !torch.vtensor<[3,4],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_xor3d
|
||||
func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1> -> !torch.vtensor<[3,4,5],i1>
|
||||
|
@ -377,6 +395,8 @@ func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[
|
|||
return %0 : !torch.vtensor<[3,4,5],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_xor4d
|
||||
func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1>
|
||||
|
@ -384,6 +404,8 @@ func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor
|
|||
return %0 : !torch.vtensor<[3,4,5,6],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_xor_bcast3v1d
|
||||
func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1>
|
||||
|
@ -391,6 +413,8 @@ func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.v
|
|||
return %0 : !torch.vtensor<[3,4,5],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_xor_bcast4v4d
|
||||
func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1>
|
||||
|
@ -417,6 +441,8 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_squeeze_two_axes
|
||||
func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -467,6 +493,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
|
|||
return %0 : !torch.vtensor<[1,3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_axis_1
|
||||
func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -491,6 +519,8 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
|
|||
return %0 : !torch.vtensor<[3,1,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_axis_2
|
||||
func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -515,6 +545,8 @@ func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
|
|||
return %0 : !torch.vtensor<[3,4,1,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_negative_axes
|
||||
func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -539,6 +571,8 @@ func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %a
|
|||
return %0 : !torch.vtensor<[1,3,1,1,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_three_axes
|
||||
func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -585,6 +619,8 @@ func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1:
|
|||
return %0 : !torch.vtensor<[3,4,1,5,1,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes
|
||||
func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
@ -642,6 +678,8 @@ func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_softmax_axis_1
|
||||
func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -651,6 +689,8 @@ func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_softmax_axis_2
|
||||
func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
|
@ -660,6 +700,8 @@ func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_softmax_default_axis
|
||||
func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
|
@ -669,6 +711,8 @@ func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !tor
|
|||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_softmax_large_number
|
||||
func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
|
@ -678,6 +722,8 @@ func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch
|
|||
return %0 : !torch.vtensor<[2,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_softmax_negative_axis
|
||||
func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
|
@ -773,6 +819,8 @@ func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[
|
|||
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example
|
||||
func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -792,12 +840,16 @@ func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2]
|
|||
return %0 : !torch.vtensor<[3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example
|
||||
func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32>
|
||||
return %0 : !torch.vtensor<[3,2,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero
|
||||
func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -817,6 +869,8 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor
|
|||
return %0 : !torch.vtensor<[2,0,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_sum_keepdims_example
|
||||
func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -836,6 +890,8 @@ func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>,
|
|||
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example
|
||||
func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -867,6 +923,8 @@ func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<
|
|||
return %0 : !torch.vtensor<[1,1,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example
|
||||
func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -886,6 +944,8 @@ func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2
|
|||
return %0 : !torch.vtensor<[3,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example
|
||||
func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -905,6 +965,8 @@ func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>,
|
|||
return %0 : !torch.vtensor<[3,1,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example
|
||||
func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
|
@ -958,7 +1020,6 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a
|
|||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
|
||||
func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
|
|
Loading…
Reference in New Issue