mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for Maxpool Op (#2695)
Add MaxPool ONNX op support. Add Utils.h/cpp files to create a constant int list for ONNX.

parent 670a99ae19
commit c7452af4fa
@@ -0,0 +1,23 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"

namespace mlir::torch::onnx_c {

Value createConstantIntList(OpBinder binder,
                            ConversionPatternRewriter &rewriter,
                            SmallVector<int64_t> cstInput);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
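For a list such as {2, 2}, this helper emits one torch.constant.int per element and groups them with a torch.prim.ListConstruct (the lit tests further down pin the exact form). A minimal sketch of the resulting IR, with illustrative SSA names:

%int2 = torch.constant.int 2
%int2_0 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int2, %int2_0 : (!torch.int, !torch.int) -> !torch.list<int>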
@@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch
  Passes.cpp
  Patterns.cpp
  TorchOnnxToTorch.cpp
  Utils.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
@@ -148,6 +149,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                      binder.op, resultType, lhs, rhs);
                  return success();
                });
  patterns.onOp(
      "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        std::string autoPad;
        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
          return rewriter.notifyMatchFailure(binder.op,
                                             "auto_pad bind failure");
        if (autoPad != "NOTSET")
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported conversion: auto_pad != NOTSET");

        Torch::ValueTensorType resultType;
        Value operand;
        bool ceilMode;
        int64_t storageOrder;
        // TODO: Add support for indices output and storage_order
        if (binder.tensorOperand(operand) ||
            binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
            binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
            binder.tensorResultType(resultType))
          return rewriter.notifyMatchFailure(
              binder.op,
              "operand/ceil_mode/storage_order/resultType bind failure");
        if (storageOrder != 0)
          return rewriter.notifyMatchFailure(
              binder.op, "storage_order setting is not supported.");
        // Determine the rank of the input tensor.
        std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
        if (!maybeRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unimplemented: unranked tensor");
        unsigned rank = *maybeRank;

        SmallVector<int64_t> kernel, padding, strides, dilations;
        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "kernel_shape bind failure");
        if (kernel.size() != rank - 2)
          return rewriter.notifyMatchFailure(
              binder.op, "kernel list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
        if (padding.size() != 1 && padding.size() != rank - 2)
          return rewriter.notifyMatchFailure(
              binder.op, "padding list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
        if (strides.size() != 1 && strides.size() != rank - 2)
          return rewriter.notifyMatchFailure(
              binder.op, "strides list size does not match the number of axes");
        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
          return rewriter.notifyMatchFailure(binder.op,
                                             "dilations bind failure");

        Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
        Value paddingList = createConstantIntList(binder, rewriter, padding);
        Value stridesList = createConstantIntList(binder, rewriter, strides);
        Value dilationsList =
            createConstantIntList(binder, rewriter, dilations);
        Value cstCeilMode =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);

        if (rank == 3)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unimplemented: AtenMaxPool1dOp");
        if (rank == 4) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        if (rank == 5) {
          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
              binder.op, resultType, operand, kernelSizeList, stridesList,
              paddingList, dilationsList, cstCeilMode);
          return success();
        }
        return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
      });
  patterns.onOp("Greater", 16,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                  Torch::ValueTensorType resultType;
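Net effect of the pattern for a rank-4 input: a single onnx.MaxPool custom op is rewritten into torch.aten.max_pool2d plus the constant lists it consumes. A before/after sketch mirroring the test_maxpool_2d_default case in the tests below (SSA names illustrative):

// Before:
%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>

// After:
%int2 = torch.constant.int 2
%int2_0 = torch.constant.int 2
%kernel = torch.prim.ListConstruct %int2, %int2_0 : (!torch.int, !torch.int) -> !torch.list<int>
%int0 = torch.constant.int 0
%pads = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
%int1 = torch.constant.int 1
%strides = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%dilations = torch.prim.ListConstruct : () -> !torch.list<int>
%false = torch.constant.bool false
%0 = torch.aten.max_pool2d %arg0, %kernel, %strides, %pads, %dilations, %false : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>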
@@ -0,0 +1,28 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

Value mlir::torch::onnx_c::createConstantIntList(
    OpBinder binder, ConversionPatternRewriter &rewriter,
    SmallVector<int64_t> cstInput) {
  SmallVector<Value> cstValue;
  for (int64_t i : cstInput) {
    cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
        binder.getLoc(), rewriter.getI64IntegerAttr(i)));
  }
  return rewriter.create<Torch::PrimListConstructOp>(
      binder.getLoc(),
      Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
      cstValue);
}
@@ -13,6 +13,8 @@ func.func @test_greater(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtenso
  return %0 : !torch.vtensor<[3,4,5],i1>
}

// -----

// CHECK-LABEL: func.func @test_greater_or_equal
func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>

@@ -22,6 +24,8 @@ func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor
  return %0 : !torch.vtensor<[3,4,5],i1>
}

// -----

// CHECK-LABEL: func.func @test_less
func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>

@@ -31,6 +35,8 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[
  return %0 : !torch.vtensor<[3,4,5],i1>
}

// -----

// CHECK-LABEL: func.func @test_gather_elements
func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
  // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
@@ -99,7 +105,7 @@ func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtenso
  return %0 : !torch.vtensor<[3,4],f32>
}

// -----

// CHECK-LABEL: func.func @test_gemm_alpha_beta
func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
@@ -137,6 +143,8 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: @test_matmul_2d
func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
@@ -173,6 +181,62 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt

// -----

// CHECK-LABEL: func.func @test_maxpool_2d_default
func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[I2:.*]] = torch.constant.int 2
  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
  // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[I0:.*]] = torch.constant.int 0
  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[I1:.*]] = torch.constant.int 1
  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
  return %0 : !torch.vtensor<[1,3,31,31],f32>
}

// -----

// CHECK-LABEL: func.func @test_maxpool_2d_ceil
func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[I3:.*]] = torch.constant.int 3
  // CHECK: %[[I3_1:.*]] = torch.constant.int 3
  // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[I0:.*]] = torch.constant.int 0
  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[I2:.*]] = torch.constant.int 2
  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
  // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,2,2],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32>
  return %0 : !torch.vtensor<[1,1,2,2],f32>
}

// -----

// CHECK-LABEL: func.func @test_maxpool_3d_default
func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[I2:.*]] = torch.constant.int 2
  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
  // CHECK: %[[I2_2:.*]] = torch.constant.int 2
  // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[I0:.*]] = torch.constant.int 0
  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[I1:.*]] = torch.constant.int 1
  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
  return %0 : !torch.vtensor<[1,3,31,31,31],f32>
}

// -----

// CHECK-LABEL: @test_gelu_default_1
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[STR1:.*]] = torch.constant.str "none"
@@ -222,6 +286,8 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
  return %0 : !torch.vtensor<[3,4,5],i1>
}

// -----

// CHECK-LABEL: func.func @test_pow
func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>

@@ -229,6 +295,8 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: @test_hardsigmoid_example
func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01

@@ -252,6 +320,8 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt
  return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: @test_hardsigmoid
func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01

@@ -274,6 +344,8 @@ func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: @test_hardsigmoid_default
func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224

@@ -331,6 +403,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
  return %0 : !torch.vtensor<[1,1,1,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_max_example
func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>

@@ -338,6 +412,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
  return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @test_min_example
func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>

@@ -345,6 +421,7 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
  return %0 : !torch.vtensor<[3],f32>
}

// -----

// CHECK-LABEL: func.func @test_log
func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {

@@ -353,6 +430,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_neg
func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>

@@ -360,6 +439,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_not_2d
func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>

@@ -367,6 +448,8 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
  return %0 : !torch.vtensor<[3,4],i1>
}

// -----

// CHECK-LABEL: func.func @test_or2d
func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>