From 579ac8b66628b5707ca1a7c4c41fbf4c829b30b5 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 29 Feb 2024 21:48:46 +0530 Subject: [PATCH] [MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op (#2954) This commit adds the support for scalar conversion to byte. This commit also fixes the OnnxToLinalg lowering issue for Onnx.Sub and Onnx.Sum op. Fixes https://github.com/nod-ai/SHARK-Turbine/issues/466 Fixes https://github.com/nod-ai/SHARK-Turbine/issues/467 Signed-Off By: Vivek Khandelwal --- include/torch-mlir/Conversion/Utils/Utils.h | 3 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 10 ++- .../TorchToLinalg/Uncategorized.cpp | 3 +- lib/Conversion/Utils/Utils.cpp | 33 +++++++-- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 69 +++++++++++++++++-- 5 files changed, 104 insertions(+), 14 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 516954b88..b76efe869 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -88,7 +88,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype = std::nullopt, - std::optional dstOriginalDtype = std::nullopt); + std::optional dstOriginalDtype = std::nullopt, + std::optional originalScalar = std::nullopt); Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 3deba85a6..b697a4fa2 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -489,8 +489,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } // When binder.op->getNumOperands() > 2 - auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( - binder.op->getContext()); Value curr = rewriter.create( binder.getLoc(), resultType, valList[0], valList[1], const1); for (int i = 2; i < numOperands; i++) { @@ -498,6 +496,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( curr = rewriter.create( binder.getLoc(), resultType, curr, valList[i], const1); } else { + SmallVector resultBroadcastShapeInt; + SmallVector resultBroadcastShapeValue; + Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr, + valList[i], resultBroadcastShapeInt, + resultBroadcastShapeValue); + auto baseType = Torch::ValueTensorType::get( + binder.op->getContext(), resultBroadcastShapeInt, + resultType.getOptionalDtype()); curr = rewriter.create( binder.getLoc(), baseType, curr, valList[i], const1); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d28369cc5..8b4297a62 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -645,7 +645,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, - /*dstOriginalDtype=*/resultElementType); + /*dstOriginalDtype=*/resultElementType, + /*originalScalar=*/sub.getAlpha()); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 3df9da94b..064215c51 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -245,12 +245,20 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, elementType, encoding); } +static std::optional getIntegerValue(Value scalar) { + if (auto constOp = scalar.getDefiningOp()) { + return std::optional(constOp.getValue()); + } + return std::optional(); +} + // Convert a scalar value to the target type. The scalar value can be an element // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype, - std::optional dstOriginalDtype) { + std::optional dstOriginalDtype, + std::optional originalScalar) { Type scalarType = scalar.getType(); if (scalarType == dtype) return scalar; @@ -262,7 +270,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return false; }; - // We don't support conversion to Byte dtype. + // We support conversion to Byte dtype only if the original scalar is an + // integer constant with value lying between 0 - 63. if (isByteOrChar(dtype)) { if (!dstOriginalDtype.has_value()) { mlir::emitError(loc) @@ -271,10 +280,22 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return nullptr; } if (dstOriginalDtype->isUnsignedInteger()) { - mlir::emitError(loc) - << "unsupported: conversion to byte type for convertScalarToDtype " - << scalarType << "(scalar type) -> " << dtype << "(dtype)"; - return nullptr; + if (originalScalar.has_value()) { + std::optional optConstVal = + getIntegerValue(originalScalar.value()); + if (optConstVal.has_value()) { + int64_t constVal = optConstVal.value(); + if (constVal < 0 || constVal > 63) { + // Do the conversion only if the original integer value is between + // 0 - 63. + mlir::emitError(loc) + << "unsupported: conversion to byte type for " + "convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; + } + } + } } } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 42be32166..58b4287a4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -223,6 +223,8 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar return %0 : !torch.vtensor<[1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -232,6 +234,8 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1 return %0 : !torch.vtensor<[1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_without_axis func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -240,6 +244,8 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, return %0 : !torch.vtensor<[3,3],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -295,6 +301,8 @@ func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub_example func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -303,6 +311,8 @@ func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -311,6 +321,8 @@ func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub_uint8 func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -324,19 +336,23 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten // CHECK-LABEL: func.func @test_sum_example func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: torch.aten.add.Tensor %0, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor - // CHECK: torch.aten.add.Tensor %1, %arg3, %int1 : !torch.vtensor, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM:.*]] = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM_1:.*]] = torch.aten.add.Tensor %[[SUM]], %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM_2:.*]] = torch.aten.add.Tensor %[[SUM_1]], %arg3, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sum_one_input func.func @test_sum_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.Sum"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sum_two_inputs func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -370,6 +386,8 @@ func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3, return %0 : !torch.vtensor<[3,4],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor3d func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1> -> !torch.vtensor<[3,4,5],i1> @@ -377,6 +395,8 @@ func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[ return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor4d func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> @@ -384,6 +404,8 @@ func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor return %0 : !torch.vtensor<[3,4,5,6],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor_bcast3v1d func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> @@ -391,6 +413,8 @@ func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor_bcast4v4d func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> @@ -417,6 +441,8 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_squeeze_two_axes func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -467,6 +493,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[1,3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_1 func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -491,6 +519,8 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,1,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_2 func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -515,6 +545,8 @@ func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,4,1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_negative_axes func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -539,6 +571,8 @@ func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %a return %0 : !torch.vtensor<[1,3,1,1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_three_axes func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -585,6 +619,8 @@ func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -642,6 +678,8 @@ func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_axis_1 func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -651,6 +689,8 @@ func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_axis_2 func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT2:.*]] = torch.constant.int 2 @@ -660,6 +700,8 @@ func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_default_axis func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT2:.*]] = torch.constant.int 2 @@ -669,6 +711,8 @@ func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !tor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_large_number func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -678,6 +722,8 @@ func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch return %0 : !torch.vtensor<[2,4],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_negative_axis func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT2:.*]] = torch.constant.int 2 @@ -773,6 +819,8 @@ func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[ return %0 : !torch.vtensor<[1,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -792,12 +840,16 @@ func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2] return %0 : !torch.vtensor<[3,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> return %0 : !torch.vtensor<[3,2,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -817,6 +869,8 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor return %0 : !torch.vtensor<[2,0,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_keepdims_example func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -836,6 +890,8 @@ func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, return %0 : !torch.vtensor<[3,1,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -867,6 +923,8 @@ func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor< return %0 : !torch.vtensor<[1,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -886,6 +944,8 @@ func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2 return %0 : !torch.vtensor<[3,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_mean_keepdims_example func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -905,6 +965,8 @@ func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, return %0 : !torch.vtensor<[3,1,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -958,7 +1020,6 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // ----- - // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[IDX:.+]] = torch.constant.int 0