mirror of https://github.com/llvm/torch-mlir
1068 lines
47 KiB
C++
1068 lines
47 KiB
C++
//===------------------------------------------------------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::torch;
|
|
using namespace mlir::torch::onnx_c;
|
|
|
|
// Simple rewrites for the default domain.
|
|
// See: https://onnx.ai/onnx/operators/
|
|
// For operators that are effectively version invariant, we register with
|
|
// sinceVersion==1. We interpret this to include the following spec
|
|
// diffs that are irrelevant to this level of lowering:
|
|
// * Supported element types.
|
|
// * Limited broadcasting to full broadcasting support.
|
|
//
|
|
// There are a lot of spec revisions that basically generalized elementwise
|
|
// to be more normal and a direct translation vs a special case. This
|
|
// results in a lot of ONNX test cases that all reduce to the exact same
|
|
// thing here, so we simplify.
|
|
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|
OnnxCustomOpConversionPattern &patterns) {
|
|
// onnx.HardSigmoid: y = max(0, min(1, alpha * x + beta)).
// Attribute defaults bound below: alpha = 0.2, beta = 0.5.
patterns.onOp(
    "HardSigmoid", 6,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value tensorOperand;
      float alpha, beta;
      if (binder.tensorOperand(tensorOperand) ||
          binder.f32FloatAttr(alpha, "alpha", 0.2f) ||
          binder.f32FloatAttr(beta, "beta", 0.5f) ||
          binder.tensorResultType(resultType))
        return failure();

      Location loc = binder.getLoc();
      Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(alpha));
      Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(beta));

      // Expression: alpha * x + beta
      // aten.add.Scalar(self, other, alpha) evaluates `self + alpha * other`,
      // so the `alpha` operand of the add cannot be used to scale x. Scale x
      // explicitly and add beta with a unit multiplier.
      Value alphaMulX = rewriter.create<Torch::AtenMulScalarOp>(
          loc, resultType, tensorOperand, constAlpha);
      Value constUnit = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(1.0));
      Value alphaMulXPlusBeta = rewriter.create<Torch::AtenAddScalarOp>(
          loc, resultType, alphaMulX, constBeta, /*alpha=*/constUnit);

      // Expression: min(1, alpha * x + beta)
      Value constantOne = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      Value oneTensor =
          createRank0Tensor(rewriter, loc, resultType, constantOne);
      Value minExpression = rewriter.create<Torch::AtenMinimumOp>(
          loc, resultType, oneTensor, alphaMulXPlusBeta);

      // Expression: max(0, min(1, alpha * x + beta))
      Value constantZero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Value zeroTensor =
          createRank0Tensor(rewriter, loc, resultType, constantZero);
      rewriter.replaceOpWithNewOp<Torch::AtenMaximumOp>(
          binder.op, resultType, zeroTensor, minExpression);
      return success();
    });
|
|
// onnx.Gelu -> aten.gelu. The string attribute `approximate` (default
// "none") is forwarded unchanged as the aten op's approximation mode.
patterns.onOp(
    "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      if (binder.tensorOperand(input))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      std::string approxMode;
      if (binder.customOpNameStringAttr(approxMode, "approximate", "none"))
        return failure();

      auto strTy = rewriter.getType<Torch::StringType>();
      Value approxModeStr = rewriter.create<Torch::ConstantStrOp>(
          binder.getLoc(), strTy, rewriter.getStringAttr(approxMode));

      rewriter.replaceOpWithNewOp<Torch::AtenGeluOp>(binder.op, outTy, input,
                                                     approxModeStr);
      return success();
    });
|
|
// onnx.Less -> aten.lt.Tensor (elementwise comparison of two tensors).
patterns.onOp(
    "Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value left, right;
      if (binder.tensorOperands(left, right))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenLtTensorOp>(binder.op, outTy,
                                                         left, right);
      return success();
    });
|
|
// onnx.LessOrEqual -> aten.le.Tensor (elementwise comparison of two tensors).
patterns.onOp(
    "LessOrEqual", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value left, right;
      if (binder.tensorOperands(left, right))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenLeTensorOp>(binder.op, outTy,
                                                         left, right);
      return success();
    });
|
|
// onnx.Log -> aten.log (unary elementwise op).
patterns.onOp(
    "Log", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      if (binder.tensorOperand(input))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, outTy, input);
      return success();
    });
|
|
// onnx.MatMul -> aten.matmul.
patterns.onOp(
    "MatMul", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value left, right;
      if (binder.tensorOperands(left, right))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(binder.op, outTy, left,
                                                       right);
      return success();
    });
|
|
// onnx.MatMulInteger: integer matrix multiply with optional zero points.
// Both inputs are wrapped as per-tensor quantized tensors (scale 1.0, the
// given zero point) and multiplied with aten.mm.
patterns.onOp(
    "MatMulInteger", 10,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value lhs, rhs;
      if (binder.tensorOperandAtIndex(lhs, 0) ||
          binder.tensorOperandAtIndex(rhs, 1) ||
          binder.tensorResultType(resultType))
        return failure();

      Location loc = binder.getLoc();
      auto lhsTy = dyn_cast<Torch::ValueTensorType>(lhs.getType());
      auto rhsTy = dyn_cast<Torch::ValueTensorType>(rhs.getType());
      if (!lhsTy || !rhsTy || !lhsTy.hasDtype() || !rhsTy.hasDtype())
        return rewriter.notifyMatchFailure(
            binder.op, "expected value tensor inputs with known dtype");

      // Maps an integer element type to the matching torch quantized type.
      // Returns a null type for anything that has no quantized counterpart.
      auto q = [&](Type qty) -> Type {
        if (qty.isSignedInteger(8))
          return rewriter.getType<Torch::QInt8Type>();
        if (qty.isUnsignedInteger(8))
          return rewriter.getType<Torch::QUInt8Type>();
        if (qty.isSignedInteger(32))
          return rewriter.getType<Torch::QInt32Type>();
        return {};
      };

      // Reject unsupported element types up front instead of building a
      // quantized tensor type with a null dtype.
      Type lhsQDtype = q(lhsTy.getDtype());
      Type rhsQDtype = q(rhsTy.getDtype());
      if (!lhsQDtype || !rhsQDtype)
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported element type for quantized matmul");

      // Produces the scalar zero point for the operand at `operandIdx`:
      //  * operand absent       -> constant 0
      //  * single-element tensor -> aten.item of that tensor
      //  * anything else         -> null Value (caller reports failure)
      auto materializeZeroPoint = [&](int64_t operandIdx) -> Value {
        Value zp;
        if (binder.tensorOperandAtIndex(zp, operandIdx)) {
          Value zeroCst = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getType<Torch::IntType>(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
          return zeroCst;
        }
        auto zpTy = dyn_cast<Torch::ValueTensorType>(zp.getType());
        if (!zpTy || !zpTy.hasSizes())
          return Value();
        for (int64_t dim : zpTy.getSizes())
          if (dim != 1)
            return Value();
        Value item = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), zp);
        return item;
      };

      Value lhsZp = materializeZeroPoint(2);
      if (!lhsZp)
        return failure();
      Value rhsZp = materializeZeroPoint(3);
      if (!rhsZp)
        return failure();

      Value scale = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(1.0));

      Type lhsQTy = rewriter.getType<Torch::ValueTensorType>(
          lhsTy.getOptionalSizes(), lhsQDtype);
      Type rhsQTy = rewriter.getType<Torch::ValueTensorType>(
          rhsTy.getOptionalSizes(), rhsQDtype);

      lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          loc, lhsQTy, lhs, scale, lhsZp);
      rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          loc, rhsQTy, rhs, scale, rhsZp);

      rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType, lhs,
                                                   rhs);
      return success();
    });
|
|
// onnx.Mul -> aten.mul.Tensor (elementwise product of two tensors).
patterns.onOp(
    "Mul", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value left, right;
      if (binder.tensorOperands(left, right))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenMulTensorOp>(binder.op, outTy,
                                                          left, right);
      return success();
    });
|
|
// onnx.NonZero -> aten.nonzero followed by a transpose.
// ONNX returns the indices as a [rank, numNonZero] tensor, while aten.nonzero
// returns [numNonZero, rank], so the two dimensions have to be swapped.
patterns.onOp(
    "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }
      if (!resultType.hasSizes() || resultType.getSizes().size() != 2)
        return rewriter.notifyMatchFailure(
            binder.op, "expected a rank-2 result type with sizes");

      Location loc = binder.getLoc();

      // Result type of aten.nonzero: the ONNX result shape reversed.
      ArrayRef<int64_t> onnxShape = resultType.getSizes();
      SmallVector<int64_t> torchShape(onnxShape.rbegin(), onnxShape.rend());
      auto torchResultTy = rewriter.getType<Torch::ValueTensorType>(
          torchShape, resultType.getOptionalDtype());

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Value one = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      Value nonZero =
          rewriter.create<Torch::AtenNonzeroOp>(loc, torchResultTy, operand);
      rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
          binder.op, resultType, nonZero, zero, one);
      return success();
    });
|
|
// onnx.MaxPool -> aten.max_pool2d / aten.max_pool3d.
// Only auto_pad == "NOTSET" and storage_order == 0 are supported, and only
// the first (values) output is produced. Asymmetric ONNX padding is
// materialized with an explicit aten.constant_pad_nd before pooling.
patterns.onOp(
    "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      std::string autoPad;
      if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
        return rewriter.notifyMatchFailure(binder.op,
                                           "auto_pad bind failure");
      if (autoPad != "NOTSET")
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported conversion: auto_pad != NOTSET");

      Torch::ValueTensorType resultType;
      Value operand;
      bool ceilMode;
      int64_t storageOrder;
      // TODO: Add support for indices output and storage_order
      if (binder.tensorOperand(operand) ||
          binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
          binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
          binder.tensorResultType(resultType))
        return rewriter.notifyMatchFailure(
            binder.op,
            "operand/ceil_mode/storage_order/resultType bind failure");
      if (storageOrder != 0)
        return rewriter.notifyMatchFailure(
            binder.op, "storage_order setting is not supported.");

      // Determine the rank of input tensor and reject unsupported ranks
      // before any IR is created.
      std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
      if (!maybeRank)
        return rewriter.notifyMatchFailure(binder.op,
                                           "Unimplemented: unranked tensor");
      int64_t rank = *maybeRank;
      if (rank == 3)
        return rewriter.notifyMatchFailure(binder.op,
                                           "Unimplemented: AtenMaxPool1dOp");
      if (rank != 4 && rank != 5)
        return rewriter.notifyMatchFailure(binder.op, "No rank is matched.");
      int64_t spatial = rank - 2;

      SmallVector<int64_t> kernel, padding, strides, dilations;
      if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
        return rewriter.notifyMatchFailure(binder.op,
                                           "kernel_shape bind failure");
      if (kernel.size() != static_cast<size_t>(spatial))
        return rewriter.notifyMatchFailure(
            binder.op, "kernel list size does not match the number of axes");
      if (binder.s64IntegerArrayAttr(padding, "pads", {}))
        return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
      if (!padding.empty() &&
          padding.size() != static_cast<size_t>(2 * spatial))
        return rewriter.notifyMatchFailure(
            binder.op, "padding list must contain (begin,end) pair for each "
                       "spatial axis");
      if (binder.s64IntegerArrayAttr(strides, "strides", {}))
        return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
      if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
        return rewriter.notifyMatchFailure(
            binder.op, "strides list size does not match the number of axes");
      if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
        return rewriter.notifyMatchFailure(binder.op,
                                           "dilations bind failure");
      if (!dilations.empty() &&
          dilations.size() != static_cast<size_t>(spatial))
        return rewriter.notifyMatchFailure(
            binder.op,
            "dilations list size does not match the number of axes");

      // Fill in the defaults for omitted attributes.
      if (padding.empty())
        padding.resize(spatial, 0);
      if (strides.empty())
        strides.resize(spatial, 1);
      if (dilations.empty())
        dilations.resize(spatial, 1);

      // If the padding is symmetric we can push the padding operation to the
      // torch operator. ONNX stores pads as [begin_0.., end_0..].
      if (padding.size() == static_cast<size_t>(2 * spatial)) {
        bool equal = true;
        for (int64_t i = 0; i < spatial; ++i)
          equal = equal && (padding[i] == padding[i + spatial]);
        if (equal)
          padding.resize(spatial);
      }

      // Torch pool operators require equal padding on each side of each
      // dimension so we materialize the padding behavior explicitly and set
      // the padding to 0.
      if (padding.size() == static_cast<size_t>(2 * spatial)) {
        auto operandTy = cast<Torch::ValueTensorType>(operand.getType());

        // Padding must never win the max, so pad with the lowest value.
        Value padValue;
        if (resultType.getDtype().isa<FloatType>()) {
          padValue = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getF64FloatAttr(
                  std::numeric_limits<double>::lowest()));
        } else if (resultType.getDtype().isa<IntegerType>()) {
          padValue = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(
                                   std::numeric_limits<int64_t>::lowest()));
        } else {
          return rewriter.notifyMatchFailure(
              binder.op, "expected float or integer result element type");
        }

        // aten.constant_pad_nd takes (begin, end) pairs starting from the
        // LAST dimension, so the ONNX pads are re-ordered accordingly.
        llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
        llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
        for (int64_t i = 0; i < spatial; ++i) {
          // Dynamic extents stay dynamic after padding.
          if (paddedShape[i + 2] != Torch::kUnknownSize)
            paddedShape[i + 2] += padding[i] + padding[i + spatial];
          shuffledPadding[2 * i] = padding[spatial - i - 1];
          shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
        }

        Value shuffledPaddingList =
            createConstantIntList(binder, rewriter, shuffledPadding);
        auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
            paddedShape, operandTy.getDtype());
        operand = rewriter.create<Torch::AtenConstantPadNdOp>(
            binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
            padValue);
        padding.clear();
        padding.resize(spatial, 0);
      }

      Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
      Value paddingList = createConstantIntList(binder, rewriter, padding);
      Value stridesList = createConstantIntList(binder, rewriter, strides);
      Value dilationsList =
          createConstantIntList(binder, rewriter, dilations);
      Value cstCeilMode =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);

      if (rank == 4) {
        rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
            binder.op, resultType, operand, kernelSizeList, stridesList,
            paddingList, dilationsList, cstCeilMode);
        return success();
      }
      // rank == 5 (all other ranks were rejected above).
      rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>(
          binder.op, resultType, operand, kernelSizeList, stridesList,
          paddingList, dilationsList, cstCeilMode);
      return success();
    });
|
|
// onnx.Greater -> aten.gt.Tensor (elementwise comparison of two tensors).
patterns.onOp(
    "Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value lhs, rhs;
      if (binder.tensorOperands(lhs, rhs) ||
          binder.tensorResultType(resultType))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenGtTensorOp>(binder.op, resultType,
                                                         lhs, rhs);
      return success();
    });
|
|
// onnx.GreaterOrEqual -> aten.ge.Tensor (elementwise comparison of two
// tensors).
patterns.onOp(
    "GreaterOrEqual", 16,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value lhs, rhs;
      if (binder.tensorOperands(lhs, rhs) ||
          binder.tensorResultType(resultType))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenGeTensorOp>(binder.op, resultType,
                                                         lhs, rhs);
      return success();
    });
|
|
// onnx.InstanceNormalization -> aten.instance_norm.
// Exactly three tensor operands are bound and passed as input, weight and
// bias. No running statistics are supplied (both are `none`), input stats
// are used, momentum is 0.0 and the cudnn flag is false.
patterns.onOp(
    "InstanceNormalization", 6,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      llvm::SmallVector<Value> inputs;
      if (binder.tensorOperands(inputs, 3))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      if (inputs.size() != 3)
        return failure();

      float epsilon;
      if (binder.f32FloatAttr(epsilon, "epsilon", 1e-05f))
        return failure();

      Location loc = binder.getLoc();
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
      Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
      Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      auto cstEps = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(epsilon));
      auto cstMomentum = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(0.0f));

      Value input = inputs[0];
      Value weight = inputs[1];
      Value bias = inputs[2];
      rewriter.replaceOpWithNewOp<Torch::AtenInstanceNormOp>(
          binder.op, outTy, input, weight, bias,
          /* running mean */ cstNone,
          /* running var */ cstNone,
          /* use input stats */ cstTrue, cstMomentum, cstEps,
          /* cudnn enabled */ cstFalse);
      return success();
    });
|
|
// onnx.Max: variadic elementwise maximum, lowered as a left fold of
// aten.maximum over the operands. A single operand is forwarded as-is.
patterns.onOp(
    "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      if (binder.tensorOperandsList(operands) ||
          binder.tensorResultType(resultType) || operands.empty()) {
        return failure();
      }
      Value result = operands[0];
      for (uint64_t i = 1; i < operands.size(); i++) {
        result = rewriter.create<Torch::AtenMaximumOp>(
            binder.getLoc(), resultType, result, operands[i]);
      }
      // Replace with the Value itself: with a single operand `result` may be
      // a block argument (no defining op) or one result of a multi-result op.
      rewriter.replaceOp(binder.op, result);
      return success();
    });
|
|
// onnx.Min: variadic elementwise minimum, lowered as a left fold of
// aten.minimum over the operands. A single operand is forwarded as-is.
patterns.onOp(
    "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      if (binder.tensorOperandsList(operands) ||
          binder.tensorResultType(resultType) || operands.empty()) {
        return failure();
      }
      Value result = operands[0];
      for (uint64_t i = 1; i < operands.size(); i++) {
        result = rewriter.create<Torch::AtenMinimumOp>(
            binder.getLoc(), resultType, result, operands[i]);
      }
      // Replace with the Value itself: with a single operand `result` may be
      // a block argument (no defining op) or one result of a multi-result op.
      rewriter.replaceOp(binder.op, result);
      return success();
    });
|
|
// onnx.Neg -> aten.neg (unary elementwise op).
patterns.onOp(
    "Neg", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      if (binder.tensorOperand(input))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenNegOp>(binder.op, outTy, input);
      return success();
    });
|
|
// onnx.Not -> aten.bitwise_not (unary elementwise op).
patterns.onOp(
    "Not", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      if (binder.tensorOperand(input))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(binder.op, outTy,
                                                           input);
      return success();
    });
|
|
// onnx.Or -> aten.bitwise_or.Tensor (binary elementwise op).
patterns.onOp(
    "Or", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value left, right;
      if (binder.tensorOperands(left, right))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
          binder.op, outTy, left, right);
      return success();
    });
|
|
// onnx.Gather: gathers slices of `data` along `axis` using `indices`.
// Lowering: wrap negative indices, flatten the indices to 1-D, run
// aten.index_select along `axis`, then unflatten that dimension back to the
// original indices shape.
patterns.onOp(
    "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value data, indices;
      int64_t axis;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorOperandAtIndex(indices, 1) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(axis, "axis", 0))
        return failure();
      Location loc = binder.getLoc();
      auto ctx = binder.op->getContext();
      auto indicesTy = dyn_cast<Torch::ValueTensorType>(indices.getType());
      auto dataTy = dyn_cast<Torch::ValueTensorType>(data.getType());
      if (!dataTy || !dataTy.hasSizes() || !indicesTy ||
          !indicesTy.hasSizes())
        return failure();

      // Normalize and validate the axis; it indexes `gatherShape` below.
      const int64_t dataRank = dataTy.getSizes().size();
      if (axis < 0)
        axis += dataRank;
      if (axis < 0 || axis >= dataRank)
        return rewriter.notifyMatchFailure(
            binder.op, "axis is out of range for the data tensor");

      Value index = rewriter.create<Torch::ConstantIntOp>(
          loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis));

      // Wrap negative indices: idx < 0 becomes idx + size(data, axis).
      // The comparison must be strict, otherwise a valid index 0 would be
      // remapped to size(data, axis), which is out of bounds.
      auto intTy = rewriter.getType<Torch::IntType>();
      auto boolTy = rewriter.getType<Torch::ValueTensorType>(
          indicesTy.getSizes(), rewriter.getI1Type());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, intTy, rewriter.getI64IntegerAttr(0));
      Value one = rewriter.create<Torch::ConstantIntOp>(
          loc, intTy, rewriter.getI64IntegerAttr(1));
      Value lt =
          rewriter.create<Torch::AtenLtScalarOp>(loc, boolTy, indices, zero);
      Value dim =
          rewriter.create<Torch::AtenSizeIntOp>(loc, intTy, data, index);
      Value add = rewriter.create<Torch::AtenAddScalarOp>(loc, indicesTy,
                                                          indices, dim, one);
      indices = rewriter.create<Torch::AtenWhereSelfOp>(loc, indicesTy, lt,
                                                        add, indices);

      // Remember the original indices shape for the final unflatten.
      auto intListTy = rewriter.getType<Torch::ListType>(
          rewriter.getType<Torch::IntType>());
      auto indicesSize =
          rewriter.create<Torch::AtenSizeOp>(loc, intListTy, indices);

      // Determine the collapsed dim size: the product of all index dims, or
      // unknown as soon as one of them is dynamic.
      int64_t indicesCt = 1;
      for (int64_t sz : indicesTy.getSizes()) {
        if (sz == Torch::kUnknownSize) {
          indicesCt = Torch::kUnknownSize;
          break;
        }
        indicesCt *= sz;
      }

      // Flatten dims [0, rank - 1] of the indices into a single dim.
      auto flattenTy = rewriter.getType<Torch::ValueTensorType>(
          SmallVector<int64_t>{indicesCt}, indicesTy.getOptionalDtype());
      Value rank = rewriter.create<Torch::AtenDimOp>(loc, intTy, indices);
      Value end = rewriter.create<Torch::AtenSubIntOp>(loc, rank, one);
      indices = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
          loc, flattenTy, indices, zero, end);

      // index_select keeps the data shape, except that `axis` now has one
      // entry per flattened index.
      llvm::SmallVector<int64_t> gatherShape(dataTy.getSizes());
      gatherShape[axis] = indicesCt;

      auto gatherTy = rewriter.getType<Torch::ValueTensorType>(
          gatherShape, dataTy.getOptionalDtype());
      Value gather = rewriter.create<Torch::AtenIndexSelectOp>(
          loc, gatherTy, data, index, indices);
      rewriter.replaceOpWithNewOp<Torch::AtenUnflattenIntOp>(
          binder.op, resultType, gather, index, indicesSize);
      return success();
    });
|
|
// onnx.GatherElements -> aten.gather along `axis` (default 0), with
// sparse_grad set to false.
patterns.onOp(
    "GatherElements", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value source, indexTensor;
      if (binder.tensorOperandAtIndex(source, 0))
        return failure();
      if (binder.tensorOperandAtIndex(indexTensor, 1))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      int64_t gatherAxis;
      if (binder.s64IntegerAttr(gatherAxis, "axis", 0))
        return failure();

      Location loc = binder.getLoc();
      auto axisAttr =
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), gatherAxis);
      Value axisCst = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(), axisAttr);
      Value sparseGradCst = rewriter.create<Torch::ConstantBoolOp>(
          loc, rewriter.getType<Torch::BoolType>(),
          rewriter.getBoolAttr(false));

      rewriter.replaceOpWithNewOp<Torch::AtenGatherOp>(
          binder.op, outTy, source, axisCst, indexTensor, sparseGradCst);
      return success();
    });
|
|
// onnx.Gemm: Y = alpha * A' * B' + beta * C, where A'/B' are optionally
// transposed (transA / transB) and C is an optional bias input.
patterns.onOp(
    "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value a, b, c;
      float alpha, beta;
      int64_t transA, transB;
      if (binder.tensorOperandAtIndex(a, 0) ||
          binder.tensorOperandAtIndex(b, 1) ||
          binder.s64IntegerAttr(transA, "transA", 0) ||
          binder.s64IntegerAttr(transB, "transB", 0) ||
          binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
          binder.f32FloatAttr(beta, "beta", 1.0f) ||
          binder.tensorResultType(resultType))
        return failure();

      // C is optional; when it cannot be bound the bias term is dropped.
      const bool hasBias = !binder.tensorOperandAtIndex(c, 2);

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      Value one = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      // Swaps dims 0 and 1 of `m`; the result type carries the reversed
      // static shape when one is known.
      auto transpose = [&](Value m) -> Value {
        auto tty = cast<Torch::ValueTensorType>(m.getType());
        std::optional<ArrayRef<int64_t>> shape = tty.getOptionalSizes();
        // `shape` is a non-owning view, so the reversed storage must stay
        // alive until the type has been created.
        llvm::SmallVector<int64_t> newShape;
        if (shape.has_value()) {
          newShape.append(shape->begin(), shape->end());
          std::reverse(newShape.begin(), newShape.end());
          shape = newShape;
        }
        auto oty = Torch::ValueTensorType::get(tty.getContext(), shape,
                                               tty.getOptionalDtype());
        return rewriter.create<Torch::AtenTransposeIntOp>(binder.getLoc(),
                                                          oty, m, zero, one);
      };

      if (transA) {
        a = transpose(a);
      }

      if (transB) {
        b = transpose(b);
      }

      Value mm =
          rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);

      // No bias: Y = alpha * (A' * B').
      if (!hasBias) {
        if (alpha == 1.0) {
          rewriter.replaceOp(binder.op, mm);
          return success();
        }
        Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(alpha));
        rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
            binder.op, resultType, mm, constAlpha);
        return success();
      }

      if (alpha == 1.0 && beta == 1.0) {
        rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
            binder.op, resultType, mm, c, one);
        return success();
      }

      // Both factors are non-trivial: fold alpha into the product first so
      // only beta is left for the add below.
      if (alpha != 1.0 && beta != 1.0) {
        Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getF64FloatAttr(alpha));
        mm = rewriter.create<Torch::AtenMulScalarOp>(
            binder.getLoc(), resultType, mm, constAlpha);
        alpha = 1.0;
      }

      // aten.add.Tensor(x, y, s) computes x + s * y. If only alpha is
      // non-trivial, swap the roles so the scaled term is the second operand.
      if (alpha != 1.0) {
        std::swap(alpha, beta);
        std::swap(mm, c);
      }

      Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(beta));
      rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
          binder.op, resultType, mm, c, constBeta);
      return success();
    });
|
|
// onnx.GlobalAveragePool -> aten.avg_pool{1,2,3}d with a kernel covering the
// spatial dimensions (dims 2..rank-1), stride 1 and no padding.
patterns.onOp(
    "GlobalAveragePool", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType))
        return failure();

      auto inputTensorType =
          dyn_cast<Torch::ValueTensorType>(operand.getType());
      if (!inputTensorType || !inputTensorType.hasSizes()) {
        return rewriter.notifyMatchFailure(
            binder.op, "Expected input type having sizes");
      }
      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
      unsigned inputRank = inputShape.size();
      if (!resultType || !resultType.hasSizes()) {
        return rewriter.notifyMatchFailure(
            binder.op, "Expected result type having sizes");
      }
      ArrayRef<int64_t> resultShape = resultType.getSizes();

      // Validate ranks before creating any IR: only 1-D/2-D/3-D pooling is
      // available, and the shapes are indexed in lockstep below.
      if (inputRank < 3 || inputRank > 5)
        return failure();
      if (resultShape.size() != inputRank)
        return rewriter.notifyMatchFailure(
            binder.op, "Expected input and result to have the same rank");

      Location loc = binder.getLoc();
      SmallVector<Value> cstKernel, cstPadding, cstStrides;
      Value cstZero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Value cstOne = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      for (unsigned i = 2; i < inputRank; i++) {
        Value kernelDim;
        if (inputShape[i] == Torch::kUnknownSize) {
          // Dynamic extent: no static arithmetic is possible, query the
          // size of the input dimension instead.
          Value dimIdx = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(i));
          kernelDim = rewriter.create<Torch::AtenSizeIntOp>(
              loc, rewriter.getType<Torch::IntType>(), operand, dimIdx);
        } else {
          // With stride 1 and no padding: kernel = in - out + 1. If the
          // result extent is not static, fall back to the full input extent.
          int64_t kernelSize = resultShape[i] == Torch::kUnknownSize
                                   ? inputShape[i]
                                   : inputShape[i] - resultShape[i] + 1;
          kernelDim = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(kernelSize));
        }
        cstKernel.push_back(kernelDim);
        cstPadding.push_back(cstZero);
        cstStrides.push_back(cstOne);
      }

      auto intListTy =
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext()));
      Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
          loc, intListTy, cstKernel);
      Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
          loc, intListTy, cstPadding);
      Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
          loc, intListTy, cstStrides);
      Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      Value cstCeilMode = cstFalse;
      Value cstCountIncludePad = cstFalse;
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);

      if (inputRank == 3) {
        rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
            binder.op, resultType, operand, kernelSizeList, stridesList,
            paddingList, cstCeilMode, cstCountIncludePad);
        return success();
      }
      if (inputRank == 4) {
        rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
            binder.op, resultType, operand, kernelSizeList, stridesList,
            paddingList, cstCeilMode, cstCountIncludePad,
            /*divisor_override=*/cstNone);
        return success();
      }
      // inputRank == 5 (other ranks were rejected above).
      rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
          binder.op, resultType, operand, kernelSizeList, stridesList,
          paddingList, cstCeilMode, cstCountIncludePad,
          /*divisor_override=*/cstNone);
      return success();
    });
|
|
// onnx.LayerNormalization -> aten.native_layer_norm.
// Normalization runs over dims [axis, rank) of X. The ONNX op has either one
// result (Y) or three (Y, Mean, InvStdDev); other result counts are rejected.
patterns.onOp(
    "LayerNormalization", 17,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType yType, meanType, invStdDevType;
      Value x, scale, b;
      int64_t axis, stashType;
      float epsilon;
      if (binder.tensorOperandAtIndex(x, 0) ||
          binder.tensorOperandAtIndex(scale, 1) ||
          binder.tensorOperandAtIndex(b, 2) ||
          binder.tensorResultTypeAtIndex(yType, 0) ||
          binder.s64IntegerAttr(axis, "axis", -1) ||
          binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
          binder.s64IntegerAttr(stashType, "stash_type", 1))
        return failure();

      // The static shape of X is required before rank and axis can be used.
      auto xType = dyn_cast<Torch::ValueTensorType>(x.getType());
      if (!xType || !xType.hasSizes()) {
        return rewriter.notifyMatchFailure(
            binder.op, "Expected input (X) to have sizes");
      }
      ArrayRef<int64_t> xShape = xType.getSizes();
      const int64_t rank = static_cast<int64_t>(xShape.size());

      // Normalize and validate the axis; it indexes `xShape` below.
      axis = Torch::toPositiveDim(axis, rank);
      if (axis < 0 || axis >= rank)
        return rewriter.notifyMatchFailure(
            binder.op, "axis is out of range for the input rank");

      Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getF64FloatAttr(epsilon));

      // normalized_shape = trailing dims of X starting at `axis`.
      SmallVector<Value> normalized;
      for (int64_t n = axis; n < rank; n++) {
        normalized.push_back(rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n])));
      }
      Value normalized_shape = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          normalized);

      int64_t numResults = binder.op->getNumResults();
      if (numResults == 1) {
        // Mean / InvStdDev are not requested, but the aten op still needs
        // result types for them: leading dims of X, normalized dims set to 1.
        SmallVector<int64_t> reducedShape(static_cast<size_t>(rank),
                                          int64_t{1});
        for (int64_t i = 0; i < axis; i++)
          reducedShape[i] = xShape[i];
        auto reducedType = xType.getWithSizesAndDtype(
            reducedShape, xType.getOptionalDtype());
        Value y = rewriter
                      .create<Torch::AtenNativeLayerNormOp>(
                          binder.getLoc(), yType, /*meanType=*/reducedType,
                          /*invStdDevType=*/reducedType, x, normalized_shape,
                          scale, b, constEpsilon)
                      .getResult0();
        rewriter.replaceOp(binder.op, y);
        return success();
      }
      if (numResults == 3) {
        if (binder.tensorResultTypeAtIndex(meanType, 1) ||
            binder.tensorResultTypeAtIndex(invStdDevType, 2))
          return failure();
        rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
            binder.op, yType, meanType, invStdDevType, x, normalized_shape,
            scale, b, constEpsilon);
        return success();
      }
      return rewriter.notifyMatchFailure(
          binder.op, "Unimplemented: expected either 1 or 3 results");
    });
|
|
// onnx.LeakyRelu -> aten.leaky_relu; the `alpha` attribute (default 0.01)
// is passed as the negative slope.
patterns.onOp(
    "LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      if (binder.tensorOperand(input))
        return failure();

      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      float slope;
      if (binder.f32FloatAttr(slope, "alpha", 0.01f))
        return failure();

      auto floatTy = rewriter.getType<Torch::FloatType>();
      Value slopeCst = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), floatTy, rewriter.getF64FloatAttr(slope));
      rewriter.replaceOpWithNewOp<Torch::AtenLeakyReluOp>(binder.op, outTy,
                                                          input, slopeCst);
      return success();
    });
|
|
patterns.onOp(
    "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.Pad to aten.pad.
      //
      // Operands: data (0), pads (1), optional constant_value (2) and
      // optional axes (3, not supported). The `mode` string attribute is
      // bound with "constant" passed as the fallback value.
      //
      // The match is rejected (no IR is replaced) when `axes` is present,
      // when `pads` is not a statically sized 1-D tensor with an even number
      // of entries, or when a zero fill value cannot be built for the data
      // element type.
      Torch::ValueTensorType resultType;
      Value data, pads, axes;
      std::string mode;

      // TODO: The `axes` parameter is not supported yet.
      if (!binder.tensorOperandAtIndex(axes, 3)) {
        return rewriter.notifyMatchFailure(
            binder.op, "The axes parameter is not supported yet");
      }
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorOperandAtIndex(pads, 1) ||
          binder.tensorResultType(resultType) ||
          binder.customOpNameStringAttr(mode, "mode", "constant"))
        return failure();
      Location loc = binder.getLoc();

      // ONNX and torch use different names for two of the padding modes:
      // ONNX "edge" is torch "replicate" and ONNX "wrap" is torch "circular".
      if (mode == "edge")
        mode = "replicate";
      else if (mode == "wrap")
        mode = "circular";

      // Shape of a 0-d (scalar) tensor; shared by the default fill value, the
      // per-element pad values and the f64 fill value below.
      SmallVector<int64_t> emptyShape;

      Value constantValue;
      if (binder.getNumOperands() >= 3) {
        if (binder.tensorOperandAtIndex(constantValue, 2)) {
          return rewriter.notifyMatchFailure(
              binder.op, "failed to bind constant_value operand at index 2");
        }
      } else {
        // No fill value given: synthesize a 0-d zero tensor with the data
        // element type. It must be 0-d because it is reduced to a scalar with
        // aten.item further down; building it with the data tensor's shape
        // would also require that shape to be fully static.
        auto dataTensorType = dyn_cast<Torch::ValueTensorType>(data.getType());
        if (!dataTensorType || !dataTensorType.hasDtype()) {
          return rewriter.notifyMatchFailure(
              binder.op, "expected data tensor with a known dtype");
        }
        auto zeroTensorType = Torch::ValueTensorType::get(
            dataTensorType.getContext(), emptyShape, dataTensorType.getDtype());
        auto zeroBuiltinType = zeroTensorType.toBuiltinTensor();
        if (!zeroBuiltinType) {
          return rewriter.notifyMatchFailure(
              binder.op, "expected integer or float data tensor");
        }

        // The attribute type has to equal the builtin element type of the
        // literal, so derive it from that type instead of hard-coding i64.
        Type zeroElemType = zeroBuiltinType.getElementType();
        Attribute zeroAttr;
        if (isa<IntegerType>(zeroElemType))
          zeroAttr = rewriter.getIntegerAttr(zeroElemType, 0);
        else if (isa<FloatType>(zeroElemType))
          zeroAttr = rewriter.getFloatAttr(zeroElemType, 0.0);
        if (!zeroAttr) {
          return rewriter.notifyMatchFailure(
              binder.op, "expected integer or float data tensor");
        }

        auto zeroSplat = DenseElementsAttr::get(zeroBuiltinType, zeroAttr);
        constantValue = rewriter.create<Torch::ValueTensorLiteralOp>(
            loc, zeroTensorType, zeroSplat);
      }

      // Get pads shape and rank. The pads tensor is expected to be 1-D
      // tensor.
      auto padsTensorType = dyn_cast<Torch::ValueTensorType>(pads.getType());
      if (!padsTensorType || !padsTensorType.hasSizes()) {
        return rewriter.notifyMatchFailure(binder.op,
                                           "Expect non empty pad tensor");
      }
      ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
      int64_t padsRank = padsShape.size();
      if (padsRank != 1) {
        return rewriter.notifyMatchFailure(binder.op,
                                           "Expect 1-D pad tensor");
      }

      // The pad values are unpacked one by one below, which needs a static
      // element count. A dynamic extent would otherwise silently produce an
      // empty pad list.
      int64_t padsSize = padsShape[0];
      if (padsSize == Torch::kUnknownSize) {
        return rewriter.notifyMatchFailure(binder.op,
                                           "pad length is unknown");
      }
      // Pads come as (begin..., end...) pairs, so the count must be even.
      if (padsSize % 2 != 0) {
        return rewriter.notifyMatchFailure(
            binder.op, "Expect an even number of pad values");
      }

      // Extract all the values of 1-D pad tensor and create a list of all
      // these values as torch.pad op expects pad list.
      Value constZero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Type padsElemType =
          Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
                                      padsTensorType.getOptionalDtype());
      SmallVector<Value> padsTensorValue;
      padsTensorValue.reserve(padsSize);
      for (int64_t i = 0; i < padsSize; ++i) {
        Value index = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(i));
        padsTensorValue.emplace_back(rewriter.create<Torch::AtenSelectIntOp>(
            loc, padsElemType, pads, constZero, index));
      }

      // The torch.pad op expects a different arrangement of padding pairs for
      // each dimension as compared to the onnx.pad op. So, rearranging pad
      // tensor to satisfy torch.pad op semantics: onnx lists all begins
      // followed by all ends, while the list built here holds (begin, end)
      // pairs starting from the last dimension.
      const int64_t numPaddedDims = padsSize / 2;
      SmallVector<Value> padsRearrange;
      padsRearrange.reserve(padsSize);
      for (int64_t i = 0; i < numPaddedDims; ++i) {
        padsRearrange.emplace_back(padsTensorValue[numPaddedDims - 1 - i]);
        padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]);
      }

      Value padsSizeList =
          rewriter
              .create<Torch::PrimTolistOp>(
                  loc,
                  Torch::ListType::get(rewriter.getType<Torch::IntType>()),
                  padsRearrange)
              .getResult(0);
      Value modeVal = rewriter.create<Torch::ConstantStrOp>(
          loc, rewriter.getStringAttr(mode));

      // The constant value is a 0-d tensor, which needs to be converted to a
      // float scalar as torch.pad op expects a float scalar.
      auto constValueType =
          dyn_cast<Torch::ValueTensorType>(constantValue.getType());
      if (!constValueType) {
        return rewriter.notifyMatchFailure(binder.op,
                                           "Expect non-none constant value");
      }
      auto resultTensorType = Torch::ValueTensorType::get(
          constValueType.getContext(), emptyShape, rewriter.getF64Type());
      Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
      Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      Value constFloatValue = rewriter.create<Torch::AtenToDtypeOp>(
          loc, resultTensorType, constantValue,
          Torch::getDtypeIntValueForType(rewriter, loc,
                                         resultTensorType.getOptionalDtype()),
          /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
          /*memory_format=*/none);
      Value constScalar = rewriter.create<Torch::AtenItemOp>(
          loc, rewriter.getType<Torch::FloatType>(), constFloatValue);

      rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
          binder.op, resultType, data, padsSizeList, modeVal, constScalar);
      return success();
    });
|
|
patterns.onOp(
    "Pow", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Binds the two tensor operands and the result type, then forwards the
      // operands in their original order to aten.pow.Tensor_Tensor.
      Value first, second;
      if (binder.tensorOperands(first, second))
        return failure();
      Torch::ValueTensorType outType;
      if (binder.tensorResultType(outType))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenPowTensorTensorOp>(
          binder.op, outType, first, second);
      return success();
    });
|
|
patterns.onOp(
    "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Identity is rewritten to aten.clone of its single operand, with a
      // `none` memory format.
      Value input;
      if (binder.tensorOperand(input))
        return failure();
      Torch::ValueTensorType outType;
      if (binder.tensorResultType(outType))
        return failure();

      Location loc = binder.getLoc();
      Value memoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
      rewriter.replaceOpWithNewOp<Torch::AtenCloneOp>(
          binder.op, outType, input, /*memory_format=*/memoryFormat);
      return success();
    });
|
|
patterns.onOp(
    "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers the variadic onnx.Mean:
      //   * one operand   -> the op is replaced by that operand;
      //   * N > 1 operands -> a chain of aten.add.Tensor ops (alpha = 1)
      //     followed by aten.div.Scalar by N.
      // The match is rejected when there are no operands or when an operand
      // or the result type cannot be bound.
      Location loc = binder.getLoc();
      int64_t numOperands = binder.op->getNumOperands();
      if (numOperands < 1) {
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected at least one operand");
      }

      // Bind everything before creating any op so that a failed match leaves
      // no stray constants behind.
      Torch::ValueTensorType resultType;
      SmallVector<Value> valList;
      if (binder.tensorOperands(valList, numOperands) ||
          binder.tensorResultType(resultType))
        return failure();

      // The mean of a single tensor is the tensor itself.
      if (numOperands == 1) {
        rewriter.replaceOp(binder.op, valList[0]);
        return success();
      }

      Value constOne = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      // Only the last add is given the declared result type. Every earlier
      // partial sum gets the least-static tensor type, because under
      // broadcasting a partial sum does not have to match the shape of the
      // final result.
      auto partialSumType =
          Torch::ValueTensorType::getWithLeastStaticInformation(
              binder.op->getContext());
      Value sum = valList[0];
      for (int64_t i = 1; i < numOperands; ++i) {
        const bool isLastAdd = (i == numOperands - 1);
        Torch::ValueTensorType sumType =
            isLastAdd ? resultType : partialSumType;
        sum = rewriter.create<Torch::AtenAddTensorOp>(loc, sumType, sum,
                                                      valList[i], constOne);
      }

      // Divide the accumulated sum by the number of operands.
      Value numOperandsConstant = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands));
      rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
          binder.op, resultType, sum, numOperandsConstant);
      return success();
    });
|
|
patterns.onOp(
    "IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.IsInf to aten.isinf. The `detect_negative` and
      // `detect_positive` integer attributes are bound with 1 passed as the
      // fallback value; when one of them is 0 the input is pre-processed so
      // that infinities of that sign no longer reach aten.isinf.
      Value input;
      if (binder.tensorOperand(input))
        return failure();
      int64_t detectNegative;
      if (binder.s64IntegerAttr(detectNegative, "detect_negative", 1))
        return failure();
      int64_t detectPositive;
      if (binder.s64IntegerAttr(detectPositive, "detect_positive", 1))
        return failure();
      Torch::ValueTensorType outType;
      if (binder.tensorResultType(outType))
        return failure();

      Location loc = binder.getLoc();
      // Every intermediate op below is created with the type of the value it
      // consumes, so the type of the bound operand is used throughout.
      auto inputType = dyn_cast<Torch::ValueTensorType>(input.getType());

      Value filtered = input;
      if (detectNegative == 0) {
        // relu replaces all negative infs with 0.
        filtered = rewriter.create<Torch::AtenReluOp>(loc, inputType, filtered);
      }
      if (detectPositive == 0) {
        // Negate first so positive infs become negative infs, then relu
        // replaces them with 0.
        Value negated =
            rewriter.create<Torch::AtenNegOp>(loc, inputType, filtered);
        filtered = rewriter.create<Torch::AtenReluOp>(loc, inputType, negated);
      }

      rewriter.replaceOpWithNewOp<Torch::AtenIsinfOp>(binder.op, outType,
                                                      filtered);
      return success();
    });
|
|
patterns.onOp(
    "IsNaN", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Direct one-to-one rewrite: the single tensor operand is handed to
      // aten.isnan with the bound result type.
      Value input;
      if (binder.tensorOperand(input))
        return failure();
      Torch::ValueTensorType outType;
      if (binder.tensorResultType(outType))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenIsnanOp>(binder.op, outType,
                                                      input);
      return success();
    });
|
|
patterns.onOp(
    "PRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Binds the two tensor operands (input, then slope) and the result
      // type, and forwards them unchanged to aten.prelu.
      Value input, slopeTensor;
      if (binder.tensorOperands(input, slopeTensor))
        return failure();
      Torch::ValueTensorType outType;
      if (binder.tensorResultType(outType))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenPreluOp>(binder.op, outType,
                                                      input, slopeTensor);
      return success();
    });
|
|
}
|