2023-11-22 13:02:55 +08:00
|
|
|
//===------------------------------------------------------------*- C++ -*-===//
|
|
|
|
//
|
|
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
2023-12-21 23:29:22 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2023-11-22 13:02:55 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::onnx_c;
|
|
|
|
|
|
|
|
// Simple rewrites for the default domain.
|
|
|
|
// See: https://onnx.ai/onnx/operators/
|
|
|
|
// For operators that are effectively version invariant, we register with
|
|
|
|
// sinceVersion==1. We interpret this to include the following spec
|
|
|
|
// diffs that are irrelevant to this level of lowering:
|
|
|
|
// * Supported element types.
|
|
|
|
// * Limited broadcasting to full broadcasting support.
|
|
|
|
//
|
|
|
|
// Many spec revisions merely generalized elementwise operators, replacing
// earlier special cases with behavior that translates directly. As a result,
// many ONNX test cases reduce to exactly the same lowering here, so we do not
// distinguish between those versions.
|
|
|
|
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
2023-12-14 01:35:32 +08:00
|
|
|
OnnxCustomOpConversionPattern &patterns) {
|
2023-12-21 23:29:22 +08:00
|
|
|
patterns.onOp("HardSigmoid", 6,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value tensorOperand;
|
|
|
|
float alpha, beta;
|
|
|
|
if (binder.tensorOperand(tensorOperand) ||
|
|
|
|
binder.f32FloatAttr(alpha, "alpha", 0.2) ||
|
|
|
|
binder.f32FloatAttr(beta, "beta", 0.5) ||
|
|
|
|
binder.tensorResultType(resultType))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
// HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta))
|
|
|
|
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
|
|
|
rewriter.getF64FloatAttr(alpha));
|
|
|
|
|
|
|
|
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
|
|
|
rewriter.getF64FloatAttr(beta));
|
|
|
|
|
|
|
|
// Expression: alpha * x + beta
|
|
|
|
Value alpha_x_plus_beta = rewriter.create<Torch::AtenAddScalarOp>(
|
|
|
|
binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha);
|
|
|
|
|
|
|
|
// Expression: min(1, alpha * x + beta)
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
|
|
|
Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(),
|
|
|
|
resultType, constantOne);
|
|
|
|
Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
|
|
|
|
binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta);
|
2023-12-19 02:57:08 +08:00
|
|
|
|
2023-12-21 23:29:22 +08:00
|
|
|
// Expression: max(0, min(1, alpha * x + beta))
|
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(),
|
|
|
|
resultType, constantZero);
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
|
|
|
|
binder.op, resultType, zeroTensor, minExpression);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-19 02:57:08 +08:00
|
|
|
patterns.onOp(
|
|
|
|
"Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Value operand;
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
std::string approximate;
|
|
|
|
|
|
|
|
if (binder.tensorOperand(operand) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
binder.customOpNameStringAttr(approximate, "approximate", "none"))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
Value vApproximate = rewriter.create<Torch::ConstantStrOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::StringType>(),
|
|
|
|
rewriter.getStringAttr(approximate));
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenGeluOp>(binder.op, resultType,
|
|
|
|
operand, vApproximate);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-28 03:07:35 +08:00
|
|
|
patterns.onOp("Less", 13,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenLtTensorOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
|
|
|
|
patterns.onOp("LessOrEqual", 1,
|
2023-12-14 01:35:32 +08:00
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
2023-12-28 03:07:35 +08:00
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenLeTensorOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
patterns.onOp("Log", 1,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value operand;
|
|
|
|
if (binder.tensorOperand(operand) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(
|
|
|
|
binder.op, resultType, operand);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
patterns.onOp("MatMul", 13,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
2023-12-14 01:35:32 +08:00
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType))
|
|
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-28 02:50:08 +08:00
|
|
|
patterns.onOp("Mul", 7,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenMulTensorOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-17 01:42:11 +08:00
|
|
|
patterns.onOp("Greater", 16,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
std::string direction;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType))
|
|
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenGtTensorOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
patterns.onOp("GreaterOrEqual", 16,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
std::string direction;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType))
|
|
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenGeTensorOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-28 03:07:35 +08:00
|
|
|
patterns.onOp("Max", 1,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
llvm::SmallVector<Value> operands;
|
|
|
|
if (binder.tensorOperandsList(operands) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
operands.size() == 0) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
Value result = operands[0];
|
|
|
|
for (int i = 1; i < operands.size(); i++) {
|
|
|
|
result = rewriter.create<Torch::AtenMaximumOp>(
|
|
|
|
binder.getLoc(), resultType, result, operands[i]);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(
|
|
|
|
binder.op, result.getDefiningOp());
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
patterns.onOp("Min", 1,
|
2023-12-17 01:42:11 +08:00
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
2023-12-28 03:07:35 +08:00
|
|
|
llvm::SmallVector<Value> operands;
|
|
|
|
if (binder.tensorOperandsList(operands) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
operands.size() == 0) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
Value result = operands[0];
|
|
|
|
for (int i = 1; i < operands.size(); i++) {
|
|
|
|
result = rewriter.create<Torch::AtenMinimumOp>(
|
|
|
|
binder.getLoc(), resultType, result, operands[i]);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(
|
|
|
|
binder.op, result.getDefiningOp());
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
patterns.onOp("Neg", 1,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value operand;
|
|
|
|
if (binder.tensorOperand(operand) ||
|
2023-12-17 01:42:11 +08:00
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
2023-12-28 03:07:35 +08:00
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenNegOp>(
|
|
|
|
binder.op, resultType, operand);
|
|
|
|
return success();
|
2023-12-17 01:42:11 +08:00
|
|
|
});
|
2023-12-28 03:07:35 +08:00
|
|
|
patterns.onOp("Not", 1,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value operand;
|
|
|
|
if (binder.tensorOperand(operand) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
|
|
|
|
binder.op, resultType, operand);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
patterns.onOp("Or", 1,
|
2023-12-15 11:23:23 +08:00
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
2023-12-28 03:07:35 +08:00
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
|
2023-12-15 11:23:23 +08:00
|
|
|
binder.op, resultType, lhs, rhs);
|
2023-12-28 03:07:35 +08:00
|
|
|
return success();
|
2023-12-15 11:23:23 +08:00
|
|
|
});
|
2023-12-16 03:18:28 +08:00
|
|
|
patterns.onOp(
|
|
|
|
"GatherElements", 13,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value data, indices;
|
|
|
|
int64_t axis;
|
|
|
|
if (binder.tensorOperandAtIndex(data, 0) ||
|
|
|
|
binder.tensorOperandAtIndex(indices, 1) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
return failure();
|
|
|
|
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
|
|
|
|
Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
|
|
|
|
rewriter.getBoolAttr(false));
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenGatherOp>(
|
|
|
|
binder.op, resultType, data, constAxis, indices, sparseGrad);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-17 02:47:58 +08:00
|
|
|
patterns.onOp(
|
2023-12-28 02:08:09 +08:00
|
|
|
"Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
2023-12-17 02:47:58 +08:00
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value a, b, c;
|
|
|
|
float alpha, beta;
|
|
|
|
int64_t transA, transB;
|
|
|
|
if (binder.tensorOperandAtIndex(a, 0) ||
|
|
|
|
binder.tensorOperandAtIndex(b, 1) ||
|
|
|
|
binder.tensorOperandAtIndex(c, 2) ||
|
|
|
|
binder.s64IntegerAttr(transA, "transA", 0) ||
|
|
|
|
binder.s64IntegerAttr(transB, "transB", 0) ||
|
|
|
|
binder.f32FloatAttr(alpha, "alpha", 1.0) ||
|
|
|
|
binder.f32FloatAttr(beta, "beta", 1.0) ||
|
|
|
|
binder.tensorResultType(resultType))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
|
|
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
|
|
|
|
|
|
|
|
auto transpose = [&](Value m) -> Value {
|
|
|
|
auto tty = m.getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto shape = tty.getOptionalSizes();
|
|
|
|
if (shape.has_value()) {
|
|
|
|
llvm::SmallVector<int64_t> newShape(shape.value());
|
|
|
|
std::reverse(newShape.begin(), newShape.end());
|
|
|
|
shape = std::move(newShape);
|
|
|
|
}
|
|
|
|
auto oty = Torch::ValueTensorType::get(tty.getContext(), shape,
|
|
|
|
tty.getOptionalDtype());
|
|
|
|
return rewriter.create<Torch::AtenTransposeIntOp>(binder.getLoc(),
|
|
|
|
oty, m, zero, one);
|
|
|
|
};
|
|
|
|
|
|
|
|
if (transA) {
|
|
|
|
a = transpose(a);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (transB) {
|
|
|
|
b = transpose(b);
|
|
|
|
}
|
|
|
|
|
|
|
|
Value mm =
|
|
|
|
rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
|
|
|
|
if (alpha == 1.0 && beta == 1.0) {
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
|
|
|
|
binder.op, resultType, mm, c, one);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
if (alpha != 1.0 && beta != 1.0) {
|
|
|
|
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
|
|
|
rewriter.getF64FloatAttr(alpha));
|
|
|
|
mm = rewriter.create<Torch::AtenMulScalarOp>(
|
|
|
|
binder.getLoc(), resultType, mm, constAlpha);
|
|
|
|
alpha = 1.0;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (alpha != 1.0) {
|
|
|
|
std::swap(alpha, beta);
|
|
|
|
std::swap(mm, c);
|
|
|
|
}
|
|
|
|
|
|
|
|
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
|
|
|
rewriter.getF64FloatAttr(beta));
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
|
|
|
|
binder.op, resultType, mm, c, constBeta);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-27 02:25:31 +08:00
|
|
|
patterns.onOp(
|
|
|
|
"GlobalAveragePool", 1,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value operand;
|
|
|
|
if (binder.tensorOperand(operand) ||
|
|
|
|
binder.tensorResultType(resultType))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
|
|
|
|
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Expected input type having sizes");
|
|
|
|
}
|
|
|
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
|
|
|
unsigned inputRank = inputShape.size();
|
|
|
|
if (!resultType || !resultType.hasSizes()) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Expected result type having sizes");
|
|
|
|
}
|
|
|
|
ArrayRef<int64_t> resultShape = resultType.getSizes();
|
|
|
|
|
|
|
|
SmallVector<Value> cstKernel, cstPadding, cstStrides;
|
|
|
|
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
|
|
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
|
|
|
for (unsigned i = 2; i < inputRank; i++) {
|
|
|
|
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
|
|
|
|
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
|
|
|
|
cstPadding.push_back(cstZero);
|
|
|
|
cstStrides.push_back(cstOne);
|
|
|
|
}
|
|
|
|
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
|
|
|
cstKernel);
|
|
|
|
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
|
|
|
cstPadding);
|
|
|
|
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
|
|
|
cstStrides);
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
|
|
|
Value cstCeilMode = cstFalse;
|
|
|
|
Value cstCountIncludePad = cstFalse;
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
|
|
|
|
|
|
|
if (inputRank == 3) {
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
|
|
|
|
binder.op, resultType, operand, kernelSizeList, stridesList,
|
|
|
|
paddingList, cstCeilMode, cstCountIncludePad);
|
|
|
|
return success();
|
|
|
|
} else if (inputRank == 4) {
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
|
|
|
|
binder.op, resultType, operand, kernelSizeList, stridesList,
|
|
|
|
paddingList, cstCeilMode, cstCountIncludePad,
|
|
|
|
/*divisor_override=*/cstNone);
|
|
|
|
return success();
|
|
|
|
} else if (inputRank == 5) {
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
|
|
|
|
binder.op, resultType, operand, kernelSizeList, stridesList,
|
|
|
|
paddingList, cstCeilMode, cstCountIncludePad,
|
|
|
|
/*divisor_override=*/cstNone);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
return failure();
|
|
|
|
});
|
2023-12-28 02:08:09 +08:00
|
|
|
patterns.onOp("LeakyRelu", 1,
|
2023-12-16 03:18:28 +08:00
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value operand;
|
|
|
|
float alpha;
|
|
|
|
if (binder.tensorOperand(operand) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
binder.f32FloatAttr(alpha, "alpha", 0.01))
|
|
|
|
return failure();
|
|
|
|
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
|
|
|
rewriter.getF64FloatAttr(alpha));
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenLeakyReluOp>(
|
|
|
|
binder.op, resultType, operand, constAlpha);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-28 01:34:48 +08:00
|
|
|
patterns.onOp("Pow", 1,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value lhs, rhs;
|
|
|
|
if (binder.tensorOperands(lhs, rhs) ||
|
|
|
|
binder.tensorResultType(resultType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenPowTensorTensorOp>(
|
|
|
|
binder.op, resultType, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
}
|