2023-11-22 13:02:55 +08:00
|
|
|
|
//===------------------------------------------------------------*- C++ -*-===//
|
|
|
|
|
//
|
|
|
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
2024-02-06 08:09:41 +08:00
|
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
|
2024-01-19 08:33:10 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-16 01:36:18 +08:00
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
2023-11-22 13:02:55 +08:00
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::torch;
|
|
|
|
|
using namespace mlir::torch::onnx_c;
|
|
|
|
|
|
|
|
|
|
// Simple rewrites for the default domain.
|
|
|
|
|
// See: https://onnx.ai/onnx/operators/
|
|
|
|
|
// For operators that are effectively version invariant, we register with
|
|
|
|
|
// sinceVersion==1. We interpret this to include the following spec
|
|
|
|
|
// diffs that are irrelevant to this level of lowering:
|
|
|
|
|
// * Supported element types.
|
|
|
|
|
// * Limited broadcasting to full broadcasting support.
|
|
|
|
|
//
|
|
|
|
|
// There are a lot of spec revisions that basically generalized elementwise
|
|
|
|
|
// to be more normal and a direct translation vs a special case. This
|
|
|
|
|
// results in a lot of ONNX test cases that all reduce to the exact same
|
|
|
|
|
// thing here, so we simplify.
|
2024-01-16 03:26:46 +08:00
|
|
|
|
|
|
|
|
|
// utilities
|
|
|
|
|
namespace {
|
2024-04-16 14:54:46 +08:00
|
|
|
|
// In case the ReduceSum Op was not the first operation performed on the data,
|
|
|
|
|
// we provide the original operand through storeResult, which will be modified
|
|
|
|
|
// if the result will be passed onto another operation, and will be used for
|
|
|
|
|
// noop_with_empty_axes handling before that.
|
|
|
|
|
LogicalResult reducedSumImpl(OpBinder binder,
|
|
|
|
|
ConversionPatternRewriter &rewriter, Value data,
|
|
|
|
|
Torch::ValueTensorType resultType,
|
|
|
|
|
Value &storeResult, int64_t keepDims,
|
|
|
|
|
int64_t noop_with_empty_axes,
|
|
|
|
|
bool isIntermediateOp) {
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> axesList;
|
|
|
|
|
Value axesVal;
|
|
|
|
|
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
2024-04-16 14:54:46 +08:00
|
|
|
|
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "unimplemented: expected input and result to have shapes");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
|
|
|
|
|
SmallVector<int64_t> inputShape{inputType.getSizes()};
|
|
|
|
|
SmallVector<int64_t> resultShape{resultType.getSizes()};
|
|
|
|
|
// if the shapes are equal, none of the dims is reduced
|
|
|
|
|
if (llvm::equal(inputShape, resultShape)) {
|
|
|
|
|
// simply fill in the op and return
|
|
|
|
|
rewriter.replaceOp(binder.op, data);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
if (areAllElementsDistinct(inputShape)) {
|
|
|
|
|
// The check for the input shape elements to be distinct is added
|
|
|
|
|
// for the cases like:
|
|
|
|
|
// Input: [3, 2, 2] -> Output: [3, 2]
|
|
|
|
|
// For the above case, from the input and output shape it can't be
|
|
|
|
|
// inferred whether the dim:1 is reduced or dim:2. To avoid these
|
|
|
|
|
// type of cases, the check has been placed.
|
|
|
|
|
SmallVector<int64_t> reduceDims;
|
|
|
|
|
unsigned resultShapeCounter = 0;
|
|
|
|
|
for (unsigned i = 0; i < inputShape.size(); i++) {
|
|
|
|
|
if (resultShapeCounter < resultShape.size() &&
|
|
|
|
|
inputShape[i] == resultShape[resultShapeCounter]) {
|
|
|
|
|
resultShapeCounter++;
|
|
|
|
|
} else {
|
|
|
|
|
reduceDims.push_back(i);
|
|
|
|
|
if (resultShapeCounter < resultShape.size() &&
|
|
|
|
|
resultShape[resultShapeCounter] == 1)
|
|
|
|
|
resultShapeCounter++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto i : reduceDims) {
|
|
|
|
|
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (axesList.empty()) {
|
|
|
|
|
Torch::BaseTensorType axesType =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
cast<Torch::BaseTensorType>(axesVal.getType());
|
2024-04-16 14:54:46 +08:00
|
|
|
|
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
|
|
|
|
auto axesShape = axesTy.getSizes();
|
|
|
|
|
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getI64IntegerAttr(0));
|
|
|
|
|
SmallVector<int64_t> selectSizes{1};
|
|
|
|
|
auto selType = rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
|
selectSizes, axesType.getOptionalDtype());
|
|
|
|
|
int64_t numAxes = axesShape[0];
|
|
|
|
|
for (int64_t i = 0; i < numAxes; ++i) {
|
|
|
|
|
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getI64IntegerAttr(i));
|
|
|
|
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
|
|
|
|
binder.getLoc(), selType, axesVal, zero, iv);
|
|
|
|
|
Value dim = rewriter.create<Torch::AtenItemOp>(
|
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
|
|
|
|
axesList.push_back(dim);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> axesInts;
|
|
|
|
|
if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
|
|
|
|
|
for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
|
|
|
|
|
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getI64IntegerAttr(axesInts[i]));
|
|
|
|
|
axesList.push_back(iv);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Do not include absolute value in the noop
|
|
|
|
|
if (axesList.empty() && noop_with_empty_axes) {
|
|
|
|
|
rewriter.replaceOp(binder.op, storeResult);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
binder.getLoc(),
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
|
|
|
|
axesList);
|
|
|
|
|
Value keepDimBool =
|
|
|
|
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
|
|
|
|
|
Value dType = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
|
|
|
|
// If we are using the ReducedSum as an intermediate op to be passed into
|
|
|
|
|
// another operation, we might not want to replace the Op. So we create a new
|
|
|
|
|
// Op and store the result in a variable.
|
|
|
|
|
if (!isIntermediateOp) {
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
|
|
|
|
|
binder.op, resultType, data, dimValueList, keepDimBool,
|
|
|
|
|
/*dtype=*/dType);
|
|
|
|
|
} else {
|
|
|
|
|
storeResult = rewriter.create<Torch::AtenSumDimIntListOp>(
|
|
|
|
|
binder.getLoc(), resultType, data, dimValueList, keepDimBool,
|
|
|
|
|
/*dtype=*/dType);
|
|
|
|
|
}
|
|
|
|
|
return success();
|
|
|
|
|
}
|
2024-01-16 03:26:46 +08:00
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-11-22 13:02:55 +08:00
|
|
|
|
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
2023-12-15 00:53:47 +08:00
|
|
|
|
OnnxCustomOpConversionPattern &patterns) {
|
2024-01-30 01:59:33 +08:00
|
|
|
|
patterns.onOp(
    "QuantizeLinear", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX `QuantizeLinear` with a rank-0 scale to
      // `Torch::AtenQuantizePerTensorOp` followed by `Torch::AtenIntReprOp`.
      // Any other scale rank is rejected.
      llvm::SmallVector<Value> inputs;
      if (binder.tensorOperands(inputs, 3))
        return failure();
      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      Value input = inputs[0];
      Value scaleVal = inputs[1];
      Value zeroPointVal = inputs[2];

      auto scaleTensorTy =
          dyn_cast<Torch::ValueTensorType>(scaleVal.getType());
      if (!scaleTensorTy || !scaleTensorTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op, "requires known rank");
      if (!outTy.hasDtype())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known result dtype");

      // Only the per-tensor form (scalar scale) is handled.
      if (!scaleTensorTy.getSizes().empty())
        return failure();

      // Pick the torch quantized dtype matching the integer result dtype.
      Type storageTy = outTy.getDtype();
      Type quantTy;
      if (storageTy.isUnsignedInteger(8))
        quantTy = rewriter.getType<Torch::QUInt8Type>();
      else if (storageTy.isSignedInteger(8))
        quantTy = rewriter.getType<Torch::QInt8Type>();
      else if (storageTy.isSignedInteger(32))
        quantTy = rewriter.getType<Torch::QInt32Type>();
      else
        return rewriter.notifyMatchFailure(binder.op,
                                           "unsupported result dtype");

      auto quantTensorTy = rewriter.getType<Torch::ValueTensorType>(
          outTy.getOptionalSizes(), quantTy);
      auto scalarTy = Torch::getScalarTypeForType(quantTy);

      Location loc = binder.getLoc();
      Value dtypeConst = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                  static_cast<int64_t>(scalarTy)));

      // Turn the scale / zero-point tensors into scalars.
      Value scaleScalar = rewriter.create<Torch::AtenItemOp>(
          loc, rewriter.getType<Torch::FloatType>(), scaleVal);
      Value zeroPointScalar = rewriter.create<Torch::AtenItemOp>(
          loc, rewriter.getType<Torch::IntType>(), zeroPointVal);

      auto quantized = rewriter.create<Torch::AtenQuantizePerTensorOp>(
          loc, quantTensorTy, input, scaleScalar, zeroPointScalar,
          dtypeConst);
      // The ONNX result is the integer representation of the quantized
      // tensor.
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, outTy,
                                                        quantized);
      return success();
    });
|
2024-02-06 08:09:41 +08:00
|
|
|
|
patterns.onOp(
    "QLinearConv", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX `QLinearConv` (per-tensor quantization only):
      //   1. wrap `a` and `b` as per-tensor quantized tensors,
      //   2. emit a `Torch::OperatorOp` named "onnx.Conv" on them, carrying
      //      over the attributes of the original op,
      //   3. interpret the result as qint32 with scale aScale * bScale and
      //      zero point 0, dequantize it to f32,
      //   4. re-quantize with the output scale / zero point and take the
      //      integer representation.
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      // The ninth operand is optional, so accept either 8 or 9 operands.
      if ((binder.tensorOperands(operands, 8) &&
           binder.tensorOperands(operands, 9)) ||
          binder.tensorResultType(resultType))
        return failure();
      Value a = operands[0];
      Value aScale = operands[1];
      Value aZp = operands[2];
      Value b = operands[3];
      Value bScale = operands[4];
      Value bZp = operands[5];
      Value cScale = operands[6];
      Value cZp = operands[7];
      Value c = operands.size() == 9 ? operands[8] : nullptr;

      // A scale / zero point is per-tensor when all of its dims are 1.
      auto check = [](Value v) {
        auto vTy = cast<Torch::ValueTensorType>(v.getType());
        return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
      };
      // Every scale and every zero point has to be validated, including the
      // output zero point `cZp`.
      if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
          !check(cScale) || !check(cZp))
        return rewriter.notifyMatchFailure(
            binder.op, "not supported for non per-tensor quantization");

      // Reads the single element of `v` as a torch int or float scalar,
      // depending on the tensor's dtype.
      auto extract = [&rewriter, &binder](Value v) {
        auto vTy = cast<Torch::ValueTensorType>(v.getType());
        Type extractTy = rewriter.getType<Torch::FloatType>();
        if (isa<IntegerType>(vTy.getDtype()))
          extractTy = rewriter.getType<Torch::IntType>();

        return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
                                                  v);
      };

      aZp = extract(aZp);
      bZp = extract(bZp);
      cZp = extract(cZp);
      aScale = extract(aScale);
      bScale = extract(bScale);
      cScale = extract(cScale);

      // Re-types an integer tensor as the matching quantized tensor.
      auto make = [&rewriter, &binder](Value v, Value scale,
                                       Value zp) -> Value {
        auto ty = cast<Torch::ValueTensorType>(v.getType());
        auto newTy = getQTorchTypeFromTorchIntType(ty);
        return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
            binder.getLoc(), newTy, v, scale, zp);
      };

      a = make(a, aScale, aZp);
      b = make(b, bScale, bZp);

      // TODO(suderman): insert convolution operator.
      llvm::SmallVector<Value> newOperands = {a, b};
      if (c)
        newOperands.push_back(c);

      auto cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(),
          rewriter.getType<Torch::QInt32Type>());

      // Forward all attributes of the original op, replacing only its name.
      llvm::SmallVector<NamedAttribute> newAttributes;
      newAttributes.push_back(
          rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv")));
      for (auto namedAttr : binder.op->getAttrDictionary()) {
        if (namedAttr.getName().getValue() == "name")
          continue;
        newAttributes.push_back(namedAttr);
      }

      c = rewriter
              .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
                                         newAttributes,
                                         binder.op->getRegions().size())
              .getResult(0);

      Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
          bScale);
      Value outZp = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          binder.getLoc(), cTy, c, outScale, outZp);
      cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), rewriter.getF32Type());

      c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                       c);
      cTy = dyn_cast<Torch::ValueTensorType>(
          getQTorchTypeFromTorchIntType(resultType));
      Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(
              rewriter.getIntegerType(64),
              static_cast<int64_t>(
                  Torch::getScalarTypeForType(cTy.getDtype()))));
      c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
          binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        c);
      return success();
    });
|
2024-01-25 04:28:48 +08:00
|
|
|
|
patterns.onOp(
    "QLinearMatMul", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX `QLinearMatMul` (per-tensor quantization only):
      //   1. wrap `a` and `b` as per-tensor quantized tensors,
      //   2. multiply them with `Torch::AtenMmOp` (rank-2 result) or
      //      `Torch::AtenBmmOp` (any other rank),
      //   3. interpret the product as qint32 with scale aScale * bScale and
      //      zero point 0, dequantize it to f32,
      //   4. re-quantize with the output scale / zero point and take the
      //      integer representation.
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      if (binder.tensorOperands(operands, 8) ||
          binder.tensorResultType(resultType))
        return failure();
      Value a = operands[0];
      Value aScale = operands[1];
      Value aZp = operands[2];
      Value b = operands[3];
      Value bScale = operands[4];
      Value bZp = operands[5];
      Value cScale = operands[6];
      Value cZp = operands[7];

      // A scale / zero point is per-tensor when all of its dims are 1.
      auto check = [](Value v) {
        auto vTy = cast<Torch::ValueTensorType>(v.getType());
        return llvm::all_of(vTy.getSizes(),
                            [](int64_t dim) { return dim == 1; });
      };
      // Every scale and every zero point has to be validated, including the
      // output zero point `cZp`.
      if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
          !check(cScale) || !check(cZp))
        return rewriter.notifyMatchFailure(
            binder.op, "not supported for non per-tensor quantization");

      Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          rewriter.getType<Torch::ListType>(
              rewriter.getType<Torch::IntType>()),
          ValueRange{});
      // Reads the single element of `v` as a torch int or float scalar,
      // depending on the tensor's dtype. Tensors that still carry dims are
      // first reshaped to rank 0.
      auto extract = [&rewriter, &binder, &emptyList](Value v) {
        auto vTy = cast<Torch::ValueTensorType>(v.getType());
        if (!vTy.getSizes().empty()) {
          vTy = rewriter.getType<Torch::ValueTensorType>(
              ArrayRef<int64_t>({}), vTy.getOptionalDtype());
          v = rewriter.create<Torch::AtenReshapeOp>(binder.getLoc(), vTy, v,
                                                    emptyList);
        }

        Type extractTy = rewriter.getType<Torch::FloatType>();
        if (isa<IntegerType>(vTy.getDtype()))
          extractTy = rewriter.getType<Torch::IntType>();

        return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
                                                  v);
      };

      aZp = extract(aZp);
      bZp = extract(bZp);
      cZp = extract(cZp);
      aScale = extract(aScale);
      bScale = extract(bScale);
      cScale = extract(cScale);

      // Re-types an integer tensor as the matching quantized tensor.
      auto make = [&rewriter, &binder](Value v, Value scale,
                                       Value zp) -> Value {
        auto ty = cast<Torch::ValueTensorType>(v.getType());
        auto newTy = getQTorchTypeFromTorchIntType(ty);
        return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
            binder.getLoc(), newTy, v, scale, zp);
      };

      a = make(a, aScale, aZp);
      b = make(b, bScale, bZp);

      // The matmul itself produces a signed 32-bit integer tensor.
      auto cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(),
          rewriter.getIntegerType(32, /*issigned=*/true));

      Value c;
      if (cTy.getSizes().size() == 2) {
        c = rewriter.create<Torch::AtenMmOp>(binder.getLoc(), cTy, a, b);
      } else {
        c = rewriter.create<Torch::AtenBmmOp>(binder.getLoc(), cTy, a, b);
      }

      cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(),
          rewriter.getType<Torch::QInt32Type>());

      Value mmScale = rewriter.create<Torch::AtenMulFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
          bScale);
      Value mmZp = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          binder.getLoc(), cTy, c, mmScale, mmZp);
      cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), rewriter.getF32Type());

      c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                       c);
      cTy = dyn_cast<Torch::ValueTensorType>(
          getQTorchTypeFromTorchIntType(resultType));
      Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(
              rewriter.getIntegerType(64),
              static_cast<int64_t>(
                  Torch::getScalarTypeForType(cTy.getDtype()))));
      c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
          binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        c);
      return success();
    });
|
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-16 01:36:18 +08:00
|
|
|
|
patterns.onOp("Reciprocal", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Direct one-to-one lowering onto
                // `Torch::AtenReciprocalOp`.
                Value input;
                if (binder.tensorOperand(input))
                  return failure();
                Torch::ValueTensorType outTy;
                if (binder.tensorResultType(outTy))
                  return failure();
                rewriter.replaceOpWithNewOp<Torch::AtenReciprocalOp>(
                    binder.op, outTy, input);
                return success();
              });
|
|
|
|
|
patterns.onOp(
    "Relu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Direct one-to-one lowering onto `Torch::AtenReluOp`.
      Value input;
      if (binder.tensorOperand(input))
        return failure();
      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenReluOp>(binder.op, outTy, input);
      return success();
    });
|
|
|
|
|
patterns.onOp("Round", 11,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Direct one-to-one lowering onto `Torch::AtenRoundOp`.
                Value input;
                if (binder.tensorOperand(input))
                  return failure();
                Torch::ValueTensorType outTy;
                if (binder.tensorResultType(outTy))
                  return failure();
                rewriter.replaceOpWithNewOp<Torch::AtenRoundOp>(
                    binder.op, outTy, input);
                return success();
              });
|
|
|
|
|
patterns.onOp(
    "ScatterElements", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX `ScatterElements` to `Torch::AtenScatterSrcOp` when the
      // `reduction` attribute is "none" (the default), and to
      // `Torch::AtenScatterReduceOp` otherwise. "max" and "min" reductions
      // are rejected.
      Torch::ValueTensorType resultType;
      SmallVector<Value> valList;
      int64_t axis;
      std::string reduction;
      // Exactly three tensor operands (data, indices, updates) are indexed
      // below, so require that count instead of accepting whatever the op
      // happens to carry.
      if (binder.tensorOperands(valList, 3) ||
          binder.s64IntegerAttr(axis, "axis", 0) ||
          binder.customOpNameStringAttr(reduction, "reduction", "none") ||
          binder.tensorResultType(resultType))
        return failure();

      Value data = valList[0];
      Value indices = valList[1];
      Value updates = valList[2];

      // ONNX allows negative axis.
      if (axis < 0) {
        // Normalizing a negative axis needs the rank of `data`.
        auto dataTy = dyn_cast<Torch::ValueTensorType>(data.getType());
        if (!dataTy || !dataTy.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "negative axis requires data of known rank");
        axis += dataTy.getSizes().size();
      }

      Value constAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));

      if (reduction == "none") {
        rewriter.replaceOpWithNewOp<Torch::AtenScatterSrcOp>(
            binder.op, resultType, data, constAxis, indices, updates);
        return success();
      }

      // TODO: Implement max and min cases
      if (reduction == "mul") {
        // The torch op is given "multiply" in place of ONNX "mul".
        reduction = "multiply";
      } else if (reduction == "max" || reduction == "min") {
        return rewriter.notifyMatchFailure(
            binder.op, "max/min reduction unsupported for scatter elements");
      }

      Value cstStrReduction =
          rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);

      rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
          binder.op, resultType, data, constAxis, indices, updates,
          cstStrReduction);
      return success();
    });
|
|
|
|
|
patterns.onOp(
    "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Direct one-to-one lowering onto `Torch::AtenSigmoidOp`.
      Value input;
      if (binder.tensorOperand(input))
        return failure();
      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenSigmoidOp>(binder.op, outTy,
                                                        input);
      return success();
    });
|
|
|
|
|
patterns.onOp("Sin", 7,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Direct one-to-one lowering onto `Torch::AtenSinOp`.
                Value input;
                if (binder.tensorOperand(input))
                  return failure();
                Torch::ValueTensorType outTy;
                if (binder.tensorResultType(outTy))
                  return failure();
                rewriter.replaceOpWithNewOp<Torch::AtenSinOp>(binder.op, outTy,
                                                              input);
                return success();
              });
|
|
|
|
|
patterns.onOp("Tanh", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Direct one-to-one lowering onto `Torch::AtenTanhOp`.
                Value input;
                if (binder.tensorOperand(input))
                  return failure();
                Torch::ValueTensorType outTy;
                if (binder.tensorResultType(outTy))
                  return failure();
                rewriter.replaceOpWithNewOp<Torch::AtenTanhOp>(binder.op,
                                                               outTy, input);
                return success();
              });
|
|
|
|
|
patterns.onOp("Sqrt", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Direct one-to-one lowering onto `Torch::AtenSqrtOp`.
                Value input;
                if (binder.tensorOperand(input))
                  return failure();
                Torch::ValueTensorType outTy;
                if (binder.tensorResultType(outTy))
                  return failure();
                rewriter.replaceOpWithNewOp<Torch::AtenSqrtOp>(binder.op,
                                                               outTy, input);
                return success();
              });
|
|
|
|
|
patterns.onOp(
    "Sub", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers the binary ONNX `Sub` onto `Torch::AtenSubTensorOp`, passing
      // the integer constant 1 as the op's trailing scalar operand.
      Value lhs;
      Value rhs;
      if (binder.tensorOperands(lhs, rhs))
        return failure();
      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();

      Value one = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
      rewriter.replaceOpWithNewOp<Torch::AtenSubTensorOp>(binder.op, outTy,
                                                          lhs, rhs, one);
      return success();
    });
|
|
|
|
|
patterns.onOp(
    "Sum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers the variadic ONNX `Sum`:
      //   * one operand  -> the op is replaced by that operand,
      //   * two operands -> a single `Torch::AtenAddTensorOp`,
      //   * more         -> a left-to-right chain of adds, where intermediate
      //                     adds are typed with the computed broadcast shape
      //                     and the last add carries the result type.
      int64_t numOperands = binder.op->getNumOperands();
      // Every path below reads at least the first operand.
      if (numOperands == 0)
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected at least one operand");

      if (numOperands == 1) {
        Value x;
        if (binder.tensorOperand(x))
          return failure();
        rewriter.replaceOp(binder.op, x);
        return success();
      }

      Torch::ValueTensorType resultType;
      SmallVector<Value> valList;
      if (binder.tensorOperands(valList, numOperands) ||
          binder.tensorResultType(resultType))
        return failure();
      Value const1 = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
      // Short circuit to binary add
      if (numOperands == 2) {
        rewriter.replaceOpWithNewOp<Torch::AtenAddTensorOp>(
            binder.op, resultType, valList[0], valList[1], const1);
        return success();
      }
      // When binder.op->getNumOperands() > 2
      Value curr = rewriter.create<Torch::AtenAddTensorOp>(
          binder.getLoc(), resultType, valList[0], valList[1], const1);
      for (int64_t i = 2; i < numOperands; i++) {
        if (i == numOperands - 1) {
          // Final add: its type is the type of the whole op.
          curr = rewriter.create<Torch::AtenAddTensorOp>(
              binder.getLoc(), resultType, curr, valList[i], const1);
        } else {
          // Intermediate add: type it with the broadcast of the running sum
          // and the next operand.
          SmallVector<int64_t> resultBroadcastShapeInt;
          SmallVector<Value> resultBroadcastShapeValue;
          Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
                                       valList[i], resultBroadcastShapeInt,
                                       resultBroadcastShapeValue);
          auto baseType = Torch::ValueTensorType::get(
              binder.op->getContext(), resultBroadcastShapeInt,
              resultType.getOptionalDtype());
          curr = rewriter.create<Torch::AtenAddTensorOp>(
              binder.getLoc(), baseType, curr, valList[i], const1);
        }
      }
      rewriter.replaceOp(binder.op, curr);
      return success();
    });
|
|
|
|
|
patterns.onOp("Where", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Lowers ONNX `Where` onto `Torch::AtenWhereSelfOp`.
                Torch::ValueTensorType resultType;
                SmallVector<Value> valList;
                // Three operands (condition, x, y) are indexed below, so
                // require exactly that many instead of accepting whatever
                // operand count the op happens to have.
                if (binder.tensorOperands(valList, 3) ||
                    binder.tensorResultType(resultType))
                  return failure();
                Value condition = valList[0];
                Value x = valList[1];
                Value y = valList[2];
                rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
                    binder.op, resultType, condition, x, y);
                return success();
              });
|
|
|
|
|
patterns.onOp(
    "Xor", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers the binary ONNX `Xor` onto `Torch::AtenLogicalXorOp`.
      Value lhs;
      Value rhs;
      if (binder.tensorOperands(lhs, rhs))
        return failure();
      Torch::ValueTensorType outTy;
      if (binder.tensorResultType(outTy))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenLogicalXorOp>(binder.op, outTy,
                                                           lhs, rhs);
      return success();
    });
|
|
|
|
|
patterns.onOp(
    "Squeeze", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.Squeeze.
      //  * Operand 0 is the data tensor; the optional operand 1 holds the
      //    axes to squeeze.
      //  * Equal input/result rank: the op is replaced by its input.
      //  * No axes operand: lowered to Torch::AtenSqueezeOp (squeeze every
      //    unit dimension).
      //  * Otherwise: lowered to Torch::PrimsSqueezeOp with an explicit list
      //    of dims, derived statically from the shapes when possible and
      //    read from the axes tensor otherwise.
      Torch::ValueTensorType resultType;
      SmallVector<Value> inputOperands;
      if (binder.tensorOperands(inputOperands, binder.op->getNumOperands()) ||
          binder.tensorResultType(resultType))
        return failure();
      // `inputOperands[0]` is read unconditionally below.
      if (inputOperands.empty())
        return rewriter.notifyMatchFailure(
            binder.op, "expected at least the data operand");

      Value data = inputOperands[0];
      // NOTE(review): assumes binder.tensorOperands only succeeds for value
      // tensor operands, so this cast cannot fail -- confirm against OpBinder.
      auto inputType = cast<Torch::ValueTensorType>(data.getType());
      if (!inputType.hasSizes() || !resultType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented: expected input and result to have shapes");

      int64_t inputRank = inputType.getSizes().size();
      int64_t resultRank = resultType.getSizes().size();
      int64_t rankDiff = inputRank - resultRank;
      // Squeezing can only remove dimensions.
      if (rankDiff < 0)
        return rewriter.notifyMatchFailure(
            binder.op, "expected result rank to not exceed input rank");
      if (rankDiff == 0) {
        // In this case, no dimension is squeezed. Hence just replace the op
        // with input.
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      if (inputOperands.size() == 1) {
        // Case: `axes` value is not present which means squeeze all the
        // dimensions with shape value 1.
        rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op,
                                                          resultType, data);
        return success();
      }

      SmallVector<Value> dimList;
      if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
        // If the input shape and result shape are statically known then the
        // list of dims to be squeezed can be derived from those shapes. As a
        // result, we don't have to wait for the dim values to be known at
        // runtime which is also expected by the downstream pipeline.
        SmallVector<int64_t> inputShape(inputType.getSizes());
        SmallVector<int64_t> resultShape(resultType.getSizes());
        // Walk the input dims in order; every input dim that does not line
        // up with the next unmatched result dim is treated as squeezed.
        int64_t resultShapeCounter = 0;
        for (int64_t i = 0; i < inputRank; i++) {
          if (resultShapeCounter < resultRank &&
              inputShape[i] == resultShape[resultShapeCounter]) {
            resultShapeCounter++;
          } else {
            dimList.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
          }
        }
      }

      if (dimList.empty()) {
        // Shapes are not fully static: read `rankDiff` entries out of the
        // axes tensor at runtime (select along dim 0, then take the item).
        Value axes = inputOperands[1];
        Torch::BaseTensorType axesType =
            cast<Torch::BaseTensorType>(axes.getType());
        SmallVector<int64_t> selectSizes{1};
        Type selectResultType = axesType.getWithSizesAndDtype(
            selectSizes, axesType.getOptionalDtype());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        for (int64_t i = 0; i < rankDiff; i++) {
          // Go through the axes list and get each dim in the list
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(), selectResultType, axes, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
          dimList.push_back(dim);
        }
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          rewriter.getType<Torch::ListType>(
              rewriter.getType<Torch::IntType>()),
          dimList);
      rewriter.replaceOpWithNewOp<Torch::PrimsSqueezeOp>(
          binder.op, resultType, data, dimValueList);
      return success();
    });
|
|
|
|
|
patterns.onOp(
    "Unsqueeze", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Unlike squeeze where we are able to lower to Torch::PrimsSqueezeOp,
      // pytorch does not support torch.unsqueeze to insert multiple new dims.
      // discussion can be found here:
      // https://github.com/pytorch/pytorch/issues/9410
      // So, for now, we unroll into multiple unsqueezes.
      Location loc = binder.getLoc();
      Torch::ValueTensorType resultType;
      Value data, axes;
      if (binder.tensorOperands(data, axes) ||
          binder.tensorResultType(resultType))
        return failure();
      // NOTE(review): assumes binder.tensorOperands only succeeds for value
      // tensor operands, so this cast cannot fail -- confirm against OpBinder.
      auto inputType = cast<Torch::ValueTensorType>(data.getType());
      if (!inputType.hasSizes() || !resultType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented: expected input and result to have shapes");

      int64_t inputRank = inputType.getSizes().size();
      int64_t resultRank = resultType.getSizes().size();
      int64_t rankDiff = resultRank - inputRank;
      if (rankDiff == 0) {
        // In this case, no dimension is unsqueezed. Hence just replace the op
        // with input.
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      SmallVector<int64_t> unsqueezeDims;
      SmallVector<int64_t> inputShape(inputType.getSizes());
      if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
        // If the input shape and result shape are statically known then the
        // list of dims to be unsqueezed can be derived from those shapes. As
        // a result, we don't have to wait for the dim values to be known at
        // runtime which is also expected by the downstream pipeline.
        SmallVector<int64_t> resultShape(resultType.getSizes());
        // Walk the result dims in order; every result dim that does not line
        // up with the next unmatched input dim is treated as inserted. The
        // resulting list is ascending by construction.
        int64_t inputShapeCounter = 0;
        for (int64_t i = 0; i < resultRank; i++) {
          if (inputShapeCounter < inputRank &&
              inputShape[inputShapeCounter] == resultShape[i]) {
            inputShapeCounter++;
          } else {
            unsqueezeDims.push_back(i);
          }
        }
      } else {
        SmallVector<int64_t> unsqueezeDimsInts;
        if (!matchPattern(axes, m_OnnxListOfConstantInts(unsqueezeDimsInts)))
          return rewriter.notifyMatchFailure(
              binder.op, "only support constant int axes values");

        for (int64_t dim : unsqueezeDimsInts) {
          // Negative axes count from the end of the *result* shape.
          int64_t normalizedDim = dim < 0 ? dim + resultRank : dim;
          if (normalizedDim < 0 || normalizedDim >= resultRank)
            return rewriter.notifyMatchFailure(
                binder.op, "axes value is out of range for the result rank");
          unsqueezeDims.push_back(normalizedDim);
        }
        // If we don't sort, unsqueezing first on 4 and then on 0 would fail
        // for shape = {x,y,z}, and axes [4,0]
        llvm::sort(unsqueezeDims.begin(), unsqueezeDims.end());
      }

      Value result = data;
      SmallVector<int64_t> unsqueezeShape = inputShape;
      for (auto dim : unsqueezeDims) {
        // Inserting past the end of the running shape would be out of
        // bounds; this can only happen when the axes are inconsistent with
        // the input/result ranks.
        if (dim > static_cast<int64_t>(unsqueezeShape.size()))
          return rewriter.notifyMatchFailure(
              binder.op, "axes are inconsistent with input and result ranks");
        unsqueezeShape.insert(unsqueezeShape.begin() + dim, 1);
        Type unsqueezeType = resultType.getWithSizesAndDtype(
            unsqueezeShape, resultType.getOptionalDtype());
        Value cstDim = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(dim));
        result = rewriter.create<Torch::AtenUnsqueezeOp>(loc, unsqueezeType,
                                                         result, cstDim);
      }
      rewriter.replaceOp(binder.op, result);
      return success();
    });
|
|
|
|
|
patterns.onOp(
    "Softmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.Softmax to Torch::AtenSoftmaxIntOp along `axis`
      // (attribute, default -1) with no dtype override.
      Torch::ValueTensorType resultType;
      Value input;
      int64_t axis;
      if (binder.tensorOperand(input) ||
          binder.s64IntegerAttr(axis, "axis", -1) ||
          binder.tensorResultType(resultType))
        return failure();

      // ONNX allows negative axis.
      if (axis < 0) {
        auto inputType = cast<Torch::ValueTensorType>(input.getType());
        // The rank is needed to turn a negative axis into a positive one;
        // getSizes() must not be called on a type without sizes.
        if (!inputType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "unimplemented: expected input to have a known rank "
                         "to normalize a negative axis");
        axis += static_cast<int64_t>(inputType.getSizes().size());
      }

      Value constAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));

      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

      rewriter.replaceOpWithNewOp<Torch::AtenSoftmaxIntOp>(
          binder.op, resultType, input, constAxis, /*dtype=*/noneVal);
      return success();
    });
|
2023-12-15 00:53:47 +08:00
|
|
|
|
|
|
|
|
|
patterns.onOp(
    "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // y = gamma * (alpha * e^x - alpha) for x < 0, y = gamma * x otherwise.
      // (Both expressions evaluate to 0 at x == 0, so the strict comparison
      // used below agrees with the `x <= 0` form of the definition.)
      Torch::ValueTensorType resultType;
      float alpha, gamma;
      Value operand;
      // Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default
      // alpha and gamma values.
      if (binder.tensorOperand(operand) ||
          binder.f32FloatAttr(alpha, "alpha", 1.67326) ||
          binder.f32FloatAttr(gamma, "gamma", 1.0507) ||
          binder.tensorResultType(resultType))
        return failure();

      // Free-function cast, matching the rest of this file.
      Torch::ValueTensorType inputType =
          cast<Torch::ValueTensorType>(operand.getType());

      Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), alpha));

      Value vScale = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

      Value cstOne = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      // Zero tensor shaped like the operand; used as the comparison baseline.
      Value zeroTensor = rewriter.create<Torch::AtenZerosLikeOp>(
          binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone,
          cstNone, cstNone);
      // Negative branch: gamma * (alpha * exp(x) - alpha).
      Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
                                                    resultType, operand);
      Value expMulAlpha = rewriter.create<Torch::AtenMulScalarOp>(
          binder.getLoc(), resultType, exp, vAlpha);
      Value expMulAlphaSubAlpha = rewriter.create<Torch::AtenSubScalarOp>(
          binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne);
      Value neg = rewriter.create<Torch::AtenMulScalarOp>(
          binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale);
      // Positive branch: gamma * x.
      Value pos = rewriter.create<Torch::AtenMulScalarOp>(
          binder.getLoc(), resultType, operand, vScale);
      // Boolean (i1) mask with the operand's sizes selecting between the two.
      Type compareType = inputType.getWithSizesAndDtype(
          inputType.getOptionalSizes(), rewriter.getI1Type());
      Value xLessThanZero = rewriter.create<Torch::AtenLtTensorOp>(
          binder.getLoc(), compareType, operand, zeroTensor);

      rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
          binder.op, resultType, xLessThanZero, neg, pos);
      return success();
    });
|
2024-04-16 14:54:46 +08:00
|
|
|
|
patterns.onOp(
    "ReduceL1", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ReduceL1 is expressed as Abs(x) followed by the shared sum-reduction
      // helper `reducedSumImpl`.
      Torch::ValueTensorType resultType;
      Value operand;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(operand, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      // Element-wise absolute value, same type as the operand.
      Value absOperand = rewriter.create<Torch::AtenAbsOp>(
          binder.getLoc(), operand.getType(), operand);

      return reducedSumImpl(binder, rewriter, absOperand, resultType,
                            /*storeValue=*/operand, keepDims,
                            noop_with_empty_axes, false);
    });
|
2024-04-23 14:03:05 +08:00
|
|
|
|
patterns.onOp(
    "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      int64_t keepDims, noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(operand, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      // A ReduceL2 op is equivalent to the following sequence of operations:
      // Mul(x, x) -> ReduceSum -> CastF32 -> Sqrt -> CastLike(resultType)
      Value squareOfOperand = rewriter.create<Torch::AtenMulTensorOp>(
          binder.getLoc(), operand.getType(), operand, operand);

      // NOTE(review): `operand` is passed as the store value and is read
      // back below as the reduced sum -- presumably reducedSumImpl writes its
      // result into it when the last argument is true; confirm against the
      // helper's definition.
      auto reducedSum =
          reducedSumImpl(binder, rewriter, squareOfOperand, resultType,
                         operand, keepDims, noop_with_empty_axes, true);
      if (failed(reducedSum))
        return rewriter.notifyMatchFailure(
            binder.op,
            "Failed to perform sum operation on square of operand");

      Value castDType = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 6));

      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value constFalse =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

      // Perform an AtenToDtype op on the squared sum of the operand, stored
      // now in operand itself.
      // `cast` (not `dyn_cast`): the result is dereferenced unconditionally.
      auto size = cast<Torch::ValueTensorType>(operand.getType())
                      .getOptionalSizes();
      auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
          size, rewriter.getF32Type());
      Value operandCast = rewriter.create<Torch::AtenToDtypeOp>(
          binder.getLoc(), f32ResultType, operand, castDType,
          /*non_blocking=*/constFalse, /*copy=*/constFalse,
          /*memory_format=*/noneVal);

      Value operandSqrt = rewriter.create<Torch::AtenSqrtOp>(
          binder.getLoc(), f32ResultType, operandCast);

      // Cast the f32 square root back to the dtype of the declared result.
      Value resultDtype = Torch::getDtypeIntValueForType(
          rewriter, binder.getLoc(), resultType.getDtype());
      rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
          binder.op, resultType, operandSqrt, resultDtype,
          /*non_blocking=*/constFalse, /*copy=*/constFalse,
          /*memory_format=*/noneVal);
      return success();
    });
|
2024-04-16 14:54:46 +08:00
|
|
|
|
patterns.onOp(
    "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ReduceSum forwards directly to the shared sum-reduction helper
      // `reducedSumImpl`, using the input itself as the store value.
      Torch::ValueTensorType resultType;
      Value data;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      return reducedSumImpl(binder, rewriter, data, resultType,
                            /*storeValue=*/data, keepDims,
                            noop_with_empty_axes, false);
    });
|
2024-04-26 07:37:57 +08:00
|
|
|
|
patterns.onOp(
    "ReduceLogSum", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ReduceLogSum is expressed as ReduceSum followed by an element-wise
      // Log of the reduced value.
      Torch::ValueTensorType resultType;
      Value data;
      int64_t keepDims, noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      // NOTE(review): `data` is passed as the store value and is consumed by
      // the Log below -- presumably reducedSumImpl writes the reduced sum
      // into it when the last argument is true; confirm against the helper's
      // definition.
      auto reducedSumBool =
          reducedSumImpl(binder, rewriter, data, resultType,
                         /*storeValue=*/data, keepDims,
                         noop_with_empty_axes, true);

      if (failed(reducedSumBool))
        return rewriter.notifyMatchFailure(
            binder.op, "Failed to perform sum operation on operand");

      rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
                                                    data);
      return success();
    });
|
2023-12-19 04:37:31 +08:00
|
|
|
|
patterns.onOp(
    "ReduceMean", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.ReduceMean to Torch::AtenMeanDimOp. The axes may come
      // from an optional second operand and/or from an `axes` attribute;
      // both sources are appended to the same list.
      Torch::ValueTensorType resultType;
      Value data;
      int64_t keepDims, noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      SmallVector<Value> axesList;

      Value axesVal;
      if (!binder.tensorOperandAtIndex(axesVal, 1)) {
        // NOTE(review): assumes binder.tensorOperandAtIndex only succeeds
        // for value tensor operands, so the casts in this branch cannot fail
        // -- confirm against OpBinder.
        auto inputType = cast<Torch::ValueTensorType>(data.getType());
        if (!inputType.hasSizes() || !resultType.hasSizes()) {
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: expected input and result to have shapes");
        }

        // If the input shape and result shape are statically known then the
        // list of dims to be reduced can be derived from those shapes. As a
        // result, we don't have to wait for the dim values to be known at
        // runtime which is also expected by the downstream pipeline.
        if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) {
          SmallVector<int64_t> inputShape{inputType.getSizes()};
          SmallVector<int64_t> resultShape{resultType.getSizes()};
          if (llvm::equal(inputShape, resultShape)) {
            // Case: none of the dimension is reduced.
            rewriter.replaceOp(binder.op, data);
            return success();
          }
          if (areAllElementsDistinct(inputShape)) {
            // The check for the input shape elements to be distinct is added
            // for the cases like:
            // Input: [3, 2, 2] -> Output: [3, 2]
            // For the above case, from the input and output shape it can't be
            // inferred whether the dim:1 is reduced or dim:2. To avoid these
            // type of cases, the check has been placed.
            SmallVector<int64_t> reduceDims;
            unsigned resultShapeCounter = 0;
            for (unsigned i = 0; i < inputShape.size(); i++) {
              if (resultShapeCounter < resultShape.size() &&
                  inputShape[i] == resultShape[resultShapeCounter]) {
                resultShapeCounter++;
              } else {
                reduceDims.push_back(i);
                // A reduced dim that is kept shows up as a unit dim in the
                // result; step over it.
                if (resultShapeCounter < resultShape.size() &&
                    resultShape[resultShapeCounter] == 1)
                  resultShapeCounter++;
              }
            }
            for (auto i : reduceDims) {
              axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(), rewriter.getI64IntegerAttr(i)));
            }
          }
        }

        if (axesList.empty()) {
          // Static inference was not possible: read each axis out of the
          // axes tensor (select along dim 0, then take the item).
          auto axesType = cast<Torch::ValueTensorType>(axesVal.getType());
          if (!axesType.hasSizes())
            return rewriter.notifyMatchFailure(
                binder.op, "unimplemented: expected axes to have a shape");
          auto axesShape = axesType.getSizes();
          if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
            return rewriter.notifyMatchFailure(
                binder.op, "unimplemented: expected axes to be a rank-1 "
                           "tensor with a static size");

          Value zero = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              rewriter.getI64IntegerAttr(0));
          SmallVector<int64_t> selectSizes{1};
          auto selType = rewriter.getType<Torch::ValueTensorType>(
              selectSizes, axesType.getOptionalDtype());
          int64_t numAxes = axesShape[0];
          for (int64_t i = 0; i < numAxes; ++i) {
            Value iv = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(),
                rewriter.getI64IntegerAttr(i));
            Value extract = rewriter.create<Torch::AtenSelectIntOp>(
                binder.getLoc(), selType, axesVal, zero, iv);
            Value dim = rewriter.create<Torch::AtenItemOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
            axesList.push_back(dim);
          }
        }
      }

      // Axes supplied as an attribute are appended as constants.
      SmallVector<int64_t> axesInts;
      if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
        for (int64_t i = 0, s = axesInts.size(); i < s; ++i) {
          Value iv = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              rewriter.getI64IntegerAttr(axesInts[i]));
          axesList.push_back(iv);
        }
      }

      // deal with case when axes is empty
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          axesList);
      Value keepDimBool =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
          binder.op, resultType, data, dimValueList, keepDimBool,
          /*dtype=*/noneVal);
      return success();
    });
|
2024-03-07 08:48:21 +08:00
|
|
|
|
patterns.onOp(
    "ReduceMax", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // AtenAmaxOp allows us to pass a list of dims
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      auto dataTy = cast<Torch::BaseTensorType>(data.getType());
      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

      // If any of the input dims are 0 the reduction runs over an empty set.
      // The maximum of an empty set is the identity of `max`, i.e. the lower
      // limit of the element type: -inf for floats, the minimum value for
      // integers. The result is materialized directly as a filled tensor.
      if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
          (llvm::any_of(dataTy.getSizes(),
                        [](int64_t d) { return d == Torch::kUnknownSize; }) ||
           keepDims)) {
        auto dty = dataTy.getDtype();
        Value scalar;
        if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
          auto negInf =
              APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
          scalar = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getFloatAttr(rewriter.getF64Type(),
                                    negInf.convertToDouble()));
        }

        if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
          auto mn =
              intTy.isSigned()
                  ? APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                  : APInt::getMinValue(intTy.getIntOrFloatBitWidth());
          scalar = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy,
              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                      mn.getSExtValue()));
        }

        // Neither a float nor an integer element type: there is no fill
        // value to use, so bail out instead of building an op from a null
        // Value.
        if (!scalar)
          return rewriter.notifyMatchFailure(
              binder.op,
              "unimplemented: unsupported element type for empty reduction");

        // Build the result shape: static result dims become constants,
        // unknown ones are queried from `data` at the same index.
        llvm::SmallVector<Value> fillDims;
        for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
          auto staticDim = resultType.getSizes()[i];
          if (staticDim != Torch::kUnknownSize) {
            fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), torchIntTy,
                rewriter.getI64IntegerAttr(staticDim)));
            continue;
          }

          Value iv = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
          fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
              binder.getLoc(), torchIntTy, data, iv));
        }

        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
        rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
            binder.op, resultType, fillDimsList, scalar, none, none, none,
            none);
        return success();
      }

      // Previous version of the operation had the axes as an attribute:
      SmallVector<Value> axesList;
      llvm::SmallVector<int64_t> axesAttr;
      if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
        for (int i = 0, s = axesAttr.size(); i < s; ++i) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy,
              rewriter.getI64IntegerAttr(axesAttr[i])));
        }
      }

      // Extract the axes values from the axes operand:
      if (!binder.tensorOperandAtIndex(axes, 1)) {
        Torch::BaseTensorType axesType =
            cast<Torch::BaseTensorType>(axes.getType());
        SmallVector<int64_t> selectSizes{1};
        Type selectResultType = axesType.getWithSizesAndDtype(
            selectSizes, axesType.getOptionalDtype());
        auto sizes = axesType.getSizes();

        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

        // Extract the value of each axes:
        for (int i = 0; i < sizes[0]; i++) {
          // Go through the axes list and get each dim in the list
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(), selectResultType, axes, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
          axesList.push_back(dim);
        }
      }

      // Handle the noop case:
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // Deal with case when no axes arg is passed but not a noop:
      // reduce over every dimension of the input.
      if (axesList.empty()) {
        int64_t numDims = dataTy.getSizes().size();
        for (int i = 0; i < numDims; i++) {
          Value curr = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
          axesList.push_back(curr);
        }
      }

      // Handle negative axis: axis += (axis < 0 ? 1 : 0) * rank, computed
      // with Torch integer ops because an axis may be a runtime value.
      Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
                                                        torchIntTy, data);
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getI64IntegerAttr(0));
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rankVal);
        axis = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axis,
                                                    finalOffset);
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
      Value keepDimBool =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
      rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
          binder.op, resultType, data, dimValueList, keepDimBool);
      return success();
    });
|
|
|
|
|
|
2023-12-19 04:37:31 +08:00
|
|
|
|
patterns.onOp(
    "ReduceMin", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.ReduceMin to torch.aten.amin, which accepts the list of
      // dims to reduce over. The axes may be given either through the `axes`
      // attribute (older opsets) or through the optional second operand.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      auto loc = binder.getLoc();
      auto dataTy = cast<Torch::BaseTensorType>(data.getType());
      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

      // Helper producing a `torch.constant.int` holding `v`.
      auto makeInt = [&](int64_t v) -> Value {
        return rewriter.create<Torch::ConstantIntOp>(
            loc, torchIntTy, rewriter.getI64IntegerAttr(v));
      };

      // Everything below inspects the static sizes of `data`; bail out
      // instead of asserting inside getSizes() when they are not available.
      if (!dataTy.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "expected the data operand to have a known rank");
      llvm::ArrayRef<int64_t> dataSizes = dataTy.getSizes();

      // If any of the input dims are 0 we set to the upper limit:
      bool hasEmptyDim =
          llvm::any_of(dataSizes, [](int64_t d) { return d == 0; });
      bool hasUnknownDim = llvm::any_of(
          dataSizes, [](int64_t d) { return d == Torch::kUnknownSize; });
      if (hasEmptyDim && (hasUnknownDim || keepDims)) {
        if (!dataTy.hasDtype())
          return rewriter.notifyMatchFailure(
              binder.op, "expected the data operand to have a known dtype");

        // The fill value is +inf for floats and the largest representable
        // value for integers. Any other dtype is rejected rather than
        // passing a null fill value to aten.full.
        Type dty = dataTy.getDtype();
        Value scalar;
        if (auto fpTy = dyn_cast<FloatType>(dty)) {
          auto inf = APFloat::getInf(fpTy.getFloatSemantics());
          scalar = rewriter.create<Torch::ConstantFloatOp>(
              loc, rewriter.getType<Torch::FloatType>(),
              rewriter.getFloatAttr(rewriter.getF64Type(),
                                    inf.convertToDouble()));
        } else if (auto intTy = dyn_cast<IntegerType>(dty)) {
          auto mx =
              intTy.isSigned()
                  ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                  : APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
          scalar = makeInt(mx.getSExtValue());
        } else {
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported element type for an empty reduction");
        }

        // Build the result shape: static result dims become constants,
        // dynamic ones are queried at runtime.
        // NOTE(review): the runtime query indexes `data` with the *result*
        // dim index; this only lines up when both ranks agree (keepdims) -
        // confirm for the keepdims == 0 case.
        llvm::SmallVector<Value> fillDims;
        for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
          auto staticDim = resultType.getSizes()[i];
          if (staticDim != Torch::kUnknownSize) {
            fillDims.push_back(makeInt(staticDim));
            continue;
          }

          Value iv = makeInt(i);
          fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
              loc, torchIntTy, data, iv));
        }

        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
            loc, Torch::ListType::get(torchIntTy), fillDims);
        rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
            binder.op, resultType, fillDimsList, scalar, none, none, none,
            none);
        return success();
      }

      // Validate the optional axes operand before any IR is created. Its
      // length must be known statically, because one select/item pair is
      // emitted per entry; an unknown length would otherwise be silently
      // treated as "no axes given".
      const bool hasAxesOperand = !binder.tensorOperandAtIndex(axes, 1);
      int64_t numAxesFromOperand = 0;
      if (hasAxesOperand) {
        auto axesType = cast<Torch::BaseTensorType>(axes.getType());
        if (!axesType.hasSizes() || axesType.getSizes().size() != 1 ||
            axesType.getSizes()[0] == Torch::kUnknownSize)
          return rewriter.notifyMatchFailure(
              binder.op,
              "expected the axes operand to be a statically sized 1-D tensor");
        numAxesFromOperand = axesType.getSizes()[0];
      }

      // Previous version of the operation had the axes as an attribute:
      SmallVector<Value> axesList;
      llvm::SmallVector<int64_t> axesAttr;
      if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
        for (int64_t axis : axesAttr)
          axesList.push_back(makeInt(axis));
      }

      // Extract the axes values from the axes operand:
      if (hasAxesOperand) {
        auto axesType = cast<Torch::BaseTensorType>(axes.getType());
        SmallVector<int64_t> selectSizes{1};
        Type selectResultType = axesType.getWithSizesAndDtype(
            selectSizes, axesType.getOptionalDtype());
        Value zero = makeInt(0);

        // Go through the axes tensor and turn each entry into a scalar int.
        for (int64_t i = 0; i < numAxesFromOperand; ++i) {
          Value selectIndex = makeInt(i);
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              loc, selectResultType, axes, zero, selectIndex);
          Value dim =
              rewriter.create<Torch::AtenItemOp>(loc, torchIntTy, extract);
          axesList.push_back(dim);
        }
      }

      // Handle the noop case:
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // Deal with case when no axes arg is passed but not a noop: reduce
      // over every dimension of the input.
      if (axesList.empty()) {
        int64_t numDims = dataSizes.size();
        for (int64_t i = 0; i < numDims; ++i)
          axesList.push_back(makeInt(i));
      }

      // Handle negative axis: axis += (axis < 0) * rank.
      Value rankVal = rewriter.create<Torch::AtenDimOp>(loc, torchIntTy, data);
      Value zero = makeInt(0);
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(loc, axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(loc, isNegative);
        Value finalOffset =
            rewriter.create<Torch::AtenMulIntOp>(loc, isNegative, rankVal);
        axis = rewriter.create<Torch::AtenAddIntOp>(loc, axis, finalOffset);
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          loc, Torch::ListType::get(torchIntTy), axesList);
      Value keepDimBool =
          rewriter.create<Torch::ConstantBoolOp>(loc, keepDims);
      rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
          binder.op, resultType, data, dimValueList, keepDimBool);
      return success();
    });
|
2023-12-16 03:37:49 +08:00
|
|
|
|
|
|
|
|
|
patterns.onOp(
    "Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // onnx.Shape is rewritten one-to-one into
      // torch.aten._shape_as_tensor on the single tensor operand.
      Value input;
      Torch::ValueTensorType outTy;
      if (binder.tensorOperand(input))
        return failure();
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(binder.op,
                                                               outTy, input);
      return success();
    });
|
2023-12-16 13:23:51 +08:00
|
|
|
|
|
2024-04-30 15:44:41 +08:00
|
|
|
|
patterns.onOp(
    "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      Torch::ValueTensorType outTy;
      if (binder.tensorOperand(input) || binder.tensorResultType(outTy))
        return failure();

      // sinh(x) is expanded as (exp(x) - exp(-x)) / 2, with every
      // intermediate tensor typed like the final result.
      auto loc = binder.getLoc();
      Value expPos = rewriter.create<Torch::AtenExpOp>(loc, outTy, input);
      Value negated = rewriter.create<Torch::AtenNegOp>(loc, outTy, input);
      Value expNeg = rewriter.create<Torch::AtenExpOp>(loc, outTy, negated);

      // aten.sub.Tensor takes an `alpha` multiplier; 1 means a plain
      // subtraction.
      Value alpha = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      Value difference = rewriter.create<Torch::AtenSubTensorOp>(
          loc, outTy, expPos, expNeg, alpha);

      Value two = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(2));
      rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(binder.op, outTy,
                                                          difference, two);
      return success();
    });
|
2023-12-21 02:09:39 +08:00
|
|
|
|
|
2023-12-28 09:53:07 +08:00
|
|
|
|
// split with fixed-size parts
|
|
|
|
|
// Arguments:
|
|
|
|
|
// - input: the tensor to split
|
|
|
|
|
// Attributes:
|
|
|
|
|
// - axis: the axis along which to split the input
|
|
|
|
|
// - num_outputs: the number of outputs to produce
|
|
|
|
|
// Outputs:
|
|
|
|
|
// - outputs: the produced outputs. Variadic with num_outputs elements.
|
|
|
|
|
// Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
|
|
|
|
|
// tensors
|
|
|
|
|
// so we need to unpack the list
|
|
|
|
|
patterns.onOp(
|
|
|
|
|
"Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
|
Value self;
|
|
|
|
|
int64_t axis;
|
2024-04-22 00:31:56 +08:00
|
|
|
|
int64_t numOutputs;
|
2023-12-28 09:53:07 +08:00
|
|
|
|
if (binder.tensorOperand(self))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Not converting to AtenSplitTensorOp due to input "
|
|
|
|
|
"tensor mismatch");
|
|
|
|
|
if (binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
|
return rewriter.notifyMatchFailure(binder.op,
|
|
|
|
|
"Failed to get axis attribute");
|
2024-04-22 00:31:56 +08:00
|
|
|
|
if (binder.s64IntegerAttr(numOutputs, "num_outputs", 2))
|
2023-12-28 09:53:07 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Failed to get num_outputs attribute");
|
|
|
|
|
|
2024-04-22 00:31:56 +08:00
|
|
|
|
auto loc = binder.getLoc();
|
2023-12-28 09:53:07 +08:00
|
|
|
|
auto result0Ty =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
|
|
|
|
|
auto resultNTy = cast<Torch::ValueTensorType>(
|
|
|
|
|
binder.op->getResults().back().getType());
|
|
|
|
|
auto selfTy = cast<Torch::ValueTensorType>(self.getType());
|
2023-12-28 09:53:07 +08:00
|
|
|
|
|
|
|
|
|
int64_t dim = axis;
|
|
|
|
|
if (dim < 0)
|
|
|
|
|
dim += selfTy.getSizes().size();
|
|
|
|
|
|
2024-04-22 00:31:56 +08:00
|
|
|
|
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getI64IntegerAttr(dim));
|
|
|
|
|
|
|
|
|
|
Value vNumOutputs = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getI64IntegerAttr(numOutputs));
|
|
|
|
|
|
|
|
|
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
|
|
|
|
|
Value vDimSize = rewriter.create<Torch::AtenSizeIntOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(), self, dimValue);
|
|
|
|
|
|
|
|
|
|
Value addNumOutputs =
|
|
|
|
|
rewriter.create<Torch::AtenAddIntOp>(loc, vDimSize, vNumOutputs);
|
|
|
|
|
Value subOne =
|
|
|
|
|
rewriter.create<Torch::AtenSubIntOp>(loc, addNumOutputs, one);
|
|
|
|
|
Value splitSize =
|
|
|
|
|
rewriter.create<Torch::AtenFloordivIntOp>(loc, subOne, vNumOutputs);
|
|
|
|
|
|
|
|
|
|
llvm::SmallVector<Value> outputs;
|
|
|
|
|
Value step = one;
|
|
|
|
|
Value start = zero;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < numOutputs - 1; ++i) {
|
|
|
|
|
Value end =
|
|
|
|
|
rewriter.create<Torch::AtenAddIntOp>(loc, start, splitSize);
|
|
|
|
|
Value slice = rewriter.create<Torch::AtenSliceTensorOp>(
|
|
|
|
|
loc, result0Ty, self, dimValue, start, end, step);
|
|
|
|
|
start = end;
|
|
|
|
|
outputs.push_back(slice);
|
2023-12-28 09:53:07 +08:00
|
|
|
|
}
|
|
|
|
|
|
2024-04-22 00:31:56 +08:00
|
|
|
|
Value end = vDimSize;
|
|
|
|
|
Value lastSlice = rewriter.create<Torch::AtenSliceTensorOp>(
|
|
|
|
|
loc, resultNTy, self, dimValue, start, end, step);
|
|
|
|
|
outputs.push_back(lastSlice);
|
2023-12-28 09:53:07 +08:00
|
|
|
|
|
2024-04-22 00:31:56 +08:00
|
|
|
|
rewriter.replaceOp(binder.op, outputs);
|
2023-12-28 09:53:07 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
// split with variable parts
|
|
|
|
|
// Arguments:
|
|
|
|
|
// - input: the tensor to split
|
|
|
|
|
// - split: the sizes of the splits to be produced
|
|
|
|
|
// Attributes:
|
|
|
|
|
// - axis: the axis along which to split the input
|
|
|
|
|
// - num_outputs: the number of outputs to produce
|
|
|
|
|
// Outputs:
|
|
|
|
|
// - outputs: the produced outputs. Variadic with num_outputs elements.
|
|
|
|
|
// Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
|
|
|
|
|
// tensors
|
|
|
|
|
// so we need to unpack the list
|
|
|
|
|
patterns.onOp(
|
|
|
|
|
"Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
|
Value self;
|
|
|
|
|
Value split;
|
|
|
|
|
int64_t axis;
|
|
|
|
|
int64_t num_outputs;
|
|
|
|
|
if (binder.tensorOperandAtIndex(self, 0) ||
|
|
|
|
|
binder.tensorOperandAtIndex(split, 1))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Not converting to AtenSplitWithSizesOp due to input "
|
|
|
|
|
"tensor mismatch");
|
|
|
|
|
if (binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
|
return rewriter.notifyMatchFailure(binder.op,
|
|
|
|
|
"Failed to get axis attribute");
|
|
|
|
|
if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Failed to get num_outputs attribute");
|
|
|
|
|
|
|
|
|
|
auto result0Ty =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
|
2023-12-28 09:53:07 +08:00
|
|
|
|
auto selfTy =
|
|
|
|
|
cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
|
|
|
|
|
|
|
|
|
|
int64_t dim = axis;
|
|
|
|
|
if (dim < 0)
|
|
|
|
|
dim += selfTy.getSizes().size();
|
|
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
|
|
|
|
for (auto result : binder.op->getResultTypes()) {
|
2024-04-11 21:47:35 +08:00
|
|
|
|
int64_t d = cast<Torch::ValueTensorType>(result).getSizes()[dim];
|
2023-12-28 09:53:07 +08:00
|
|
|
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Torch::PrimTolistOp splitToList = rewriter.create<Torch::PrimTolistOp>(
|
|
|
|
|
binder.getLoc(),
|
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::IntType>()), split);
|
|
|
|
|
|
|
|
|
|
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
|
|
|
|
|
|
|
|
|
|
// TODO: Attempting to use the shape expected by the ONNX mlir as ground
|
|
|
|
|
// truth. For now just use dynamic shapes.
|
|
|
|
|
auto resultOuterType =
|
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
|
/*std::optional<llvm::ArrayRef<int64_t>>=*/intermediateShape,
|
|
|
|
|
result0Ty.getOptionalDtype()));
|
|
|
|
|
Torch::AtenSplitWithSizesOp new_op =
|
|
|
|
|
rewriter.create<Torch::AtenSplitWithSizesOp>(
|
|
|
|
|
binder.getLoc(), resultOuterType, self,
|
|
|
|
|
splitToList.getResult(0), dimValue);
|
|
|
|
|
|
|
|
|
|
// the onnx op is variadic with multiple results, but AtenSplitWithSizes
|
|
|
|
|
// outputs a list so we need to unpack the list
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListUnpackOp>(
|
|
|
|
|
binder.op, binder.op->getResults().getType(), new_op.getResult());
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
});
|
|
|
|
|
|
2023-12-21 02:09:39 +08:00
|
|
|
|
patterns.onOp(
    "Tan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // onnx.Tan is rewritten one-to-one into torch.aten.tan.
      Value input;
      Torch::ValueTensorType outTy;
      if (binder.tensorOperand(input))
        return failure();
      if (binder.tensorResultType(outTy))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenTanOp>(binder.op, outTy, input);
      return success();
    });
|
|
|
|
|
|
2023-12-16 07:30:05 +08:00
|
|
|
|
patterns.onOp(
    "Transpose", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.Transpose into a chain of torch.aten.transpose.int ops,
      // each swapping two dimensions, until the requested permutation is
      // reached.
      auto loc = binder.getLoc();
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType))
        return failure();
      auto operandType = cast<Torch::ValueTensorType>(operand.getType());
      TensorType tensorType = operandType.toBuiltinTensor();
      if (!tensorType || !tensorType.hasRank())
        return failure();

      // Default permutation is to reverse orders:
      int64_t rank = tensorType.getRank();
      llvm::SmallVector<int64_t> reverse(rank);
      for (int64_t i = 0; i < rank; ++i) {
        reverse[i] = rank - i - 1;
      }

      llvm::SmallVector<int64_t> permutations;
      if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse)))
        return rewriter.notifyMatchFailure(binder.op,
                                           "Failed to obtain permutations");

      if (static_cast<int64_t>(permutations.size()) != rank)
        return rewriter.notifyMatchFailure(
            binder.op, "Permutation length does not match operand rank");

      // Normalize negative entries, then require a true permutation of
      // [0, rank). Out-of-range or repeated entries would otherwise index
      // `shape` out of bounds below and make the swap search run past the
      // last dimension.
      llvm::SmallVector<int64_t> seen(rank);
      for (auto &dim : permutations) {
        dim = dim < 0 ? dim + rank : dim;
        if (dim < 0 || dim >= rank || seen[dim])
          return rewriter.notifyMatchFailure(
              binder.op, "perm must be a permutation of the operand dims");
        seen[dim] = 1;
      }

      llvm::SmallVector<int64_t> shape(tensorType.getShape());
      llvm::SmallVector<int64_t> current(rank);
      for (int64_t i = 0; i < rank; ++i) {
        current[i] = i;
      }

      // We need to override to the destination if known:
      if (resultType.hasSizes()) {
        llvm::ArrayRef<int64_t> resultSizes = resultType.getSizes();
        if (static_cast<int64_t>(resultSizes.size()) != rank)
          return rewriter.notifyMatchFailure(
              binder.op, "Result rank does not match operand rank");
        for (int64_t i = 0; i < rank; ++i) {
          shape[permutations[i]] = resultSizes[i];
        }
      }

      // Convert dynamic shape dimension:
      for (int64_t &extent : shape) {
        if (extent == ShapedType::kDynamic)
          extent = Torch::kUnknownSize;
      }

      // Selection-sort style: for every position that does not yet hold the
      // requested source dimension, find where that dimension currently
      // lives and swap it into place with a single transpose.
      for (int64_t i = 0; i < rank; ++i) {
        if (current[i] == permutations[i])
          continue;

        // `permutations` is a validated permutation, so the wanted
        // dimension is always found at some position after `i`.
        int64_t target = i + 1;
        for (; target < rank; ++target) {
          if (current[target] == permutations[i])
            break;
        }

        std::swap(shape[i], shape[target]);
        std::swap(current[i], current[target]);

        Value dim0 = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));

        Value dim1 = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), target));

        operand = rewriter.create<Torch::AtenTransposeIntOp>(
            loc,
            Torch::ValueTensorType::get(tensorType.getContext(), shape,
                                        operandType.getOptionalDtype()),
            operand, dim0, dim1);
      }

      rewriter.replaceOp(binder.op, operand);
      return success();
    });
|
2024-01-04 11:41:10 +08:00
|
|
|
|
patterns.onOp(
|
|
|
|
|
"Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
|
Torch::ValueTensorType resultTorchType;
|
|
|
|
|
Value operand, starts, ends;
|
|
|
|
|
// Handle if axes are not provided
|
2023-12-27 02:20:13 +08:00
|
|
|
|
|
2024-01-04 11:41:10 +08:00
|
|
|
|
if (binder.tensorOperandAtIndex(operand, 0) ||
|
|
|
|
|
binder.tensorOperandAtIndex(starts, 1) ||
|
|
|
|
|
binder.tensorOperandAtIndex(ends, 2) ||
|
|
|
|
|
binder.tensorResultType(resultTorchType)) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto context = rewriter.getContext();
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto operandTorchTy = cast<Torch::ValueTensorType>(operand.getType());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
auto operandTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
dyn_cast<RankedTensorType>(operandTorchTy.toBuiltinTensor());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
|
|
|
|
|
if (!operandTy)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op,
|
|
|
|
|
"Expected tensor operator argument to be a ranked tensor type");
|
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto startsTorchTy = cast<Torch::ValueTensorType>(starts.getType());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
auto startsTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
dyn_cast<RankedTensorType>(startsTorchTy.toBuiltinTensor());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
int startSize = startsTy.getDimSize(0);
|
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto endsTorchTy = cast<Torch::ValueTensorType>(ends.getType());
|
|
|
|
|
auto endsTy = dyn_cast<RankedTensorType>(endsTorchTy.toBuiltinTensor());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
int endSize = endsTy.getDimSize(0);
|
|
|
|
|
auto resultTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
dyn_cast<RankedTensorType>(resultTorchType.toBuiltinTensor());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
if (!resultTy)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Expected result type to be a ranked tensor type");
|
|
|
|
|
|
|
|
|
|
Location loc = binder.getLoc();
|
|
|
|
|
|
|
|
|
|
// Binding `axes` from its arguments or through a default value
|
|
|
|
|
Value axes;
|
|
|
|
|
if (binder.getNumOperands() >= 4) {
|
|
|
|
|
if (binder.tensorOperandAtIndex(axes, 3)) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Binding `steps` from its arguments or through a default value
|
|
|
|
|
Value steps;
|
|
|
|
|
if (binder.getNumOperands() >= 5) {
|
|
|
|
|
if (binder.tensorOperandAtIndex(steps, 4)) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// The default `steps` value is a 1d tensor filled with ones with a
|
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the
ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:
> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to
> pad at both the beginning as well as the end of the axis (see
> https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
> the size of this attribute should be twice the rank of the input tensor.
> However, our TorchOnnxToTorch bails out early since it incorrectly
> compares the pads attribute with the rank (not twice the rank) of the
> input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.
> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see
> https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the
existing
> logic for lowering the operator into the Torch dialect assumes that it
> is mandatory.
>
> This patch makes the attribute optional and constructs a default value
> (a list of zeros the size of the input tensor) if the attribute was not
> specified.
> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX Spec for the Slice operator allows the `starts` and `ends`
> inputs to have fewer indices that the dimensions of the `data` tensor
> (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
> expects these inputs to be as many as the `data` tensor's dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs
> are only as long as the `axes` input, but since the `axes` input is
> optional, the default type for the `axes` input has to match the type
> for the `starts` and `ends` inputs. Moreover, the number of indices in
> the `steps` input also has to match those in the `axes` inputs (instad
> of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that
> they match the ONNX spec.
2024-02-08 13:19:27 +08:00
|
|
|
|
// size equal to the size of `starts` and `ends`.
|
2024-01-04 11:41:10 +08:00
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the
ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:
> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to
> pad at both the beginning as well as the end of the axis (see
> https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
> the size of this attribute should be twice the rank of the input tensor.
> However, our TorchOnnxToTorch bails out early since it incorrectly
> compares the pads attribute with the rank (not twice the rank) of the
> input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.
> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see
> https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the
existing
> logic for lowering the operator into the Torch dialect assumes that it
> is mandatory.
>
> This patch makes the attribute optional and constructs a default value
> (a list of zeros the size of the input tensor) if the attribute was not
> specified.
> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX Spec for the Slice operator allows the `starts` and `ends`
> inputs to have fewer indices that the dimensions of the `data` tensor
> (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
> expects these inputs to be as many as the `data` tensor's dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs
> are only as long as the `axes` input, but since the `axes` input is
> optional, the default type for the `axes` input has to match the type
> for the `starts` and `ends` inputs. Moreover, the number of indices in
> the `steps` input also has to match those in the `axes` inputs (instad
> of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that
> they match the ONNX spec.
2024-02-08 13:19:27 +08:00
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
|
2024-01-04 11:41:10 +08:00
|
|
|
|
Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc,
|
|
|
|
|
Torch::ListType::get(
|
|
|
|
|
Torch::IntType::get(binder.op->getContext())),
|
|
|
|
|
sizeStepInput);
|
|
|
|
|
steps = rewriter.create<Torch::AtenOnesOp>(
|
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the
ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:
> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to
> pad at both the beginning as well as the end of the axis (see
> https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
> the size of this attribute should be twice the rank of the input tensor.
> However, our TorchOnnxToTorch bails out early since it incorrectly
> compares the pads attribute with the rank (not twice the rank) of the
> input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.
> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see
> https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the
existing
> logic for lowering the operator into the Torch dialect assumes that it
> is mandatory.
>
> This patch makes the attribute optional and constructs a default value
> (a list of zeros the size of the input tensor) if the attribute was not
> specified.
> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX Spec for the Slice operator allows the `starts` and `ends`
> inputs to have fewer indices that the dimensions of the `data` tensor
> (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
> expects these inputs to be as many as the `data` tensor's dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs
> are only as long as the `axes` input, but since the `axes` input is
> optional, the default type for the `axes` input has to match the type
> for the `starts` and `ends` inputs. Moreover, the number of indices in
> the `steps` input also has to match those in the `axes` inputs (instad
> of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that
> they match the ONNX spec.
2024-02-08 13:19:27 +08:00
|
|
|
|
loc, startsTorchTy, sizeStepsInput, none, none, none, none);
|
2024-01-04 11:41:10 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
|
|
|
|
|
startSize == endSize))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Expected the rank of starts and ends tensors to be 1 "
|
|
|
|
|
"and their dimensions to match");
|
|
|
|
|
|
2024-02-20 02:26:29 +08:00
|
|
|
|
if (axes) {
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto axesTorchTy = cast<Torch::ValueTensorType>(axes.getType());
|
2024-02-20 02:26:29 +08:00
|
|
|
|
auto axesTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
|
dyn_cast<RankedTensorType>(axesTorchTy.toBuiltinTensor());
|
2024-02-20 02:26:29 +08:00
|
|
|
|
int64_t numAxes = axesTy.getDimSize(0);
|
2024-01-04 11:41:10 +08:00
|
|
|
|
|
2024-02-20 02:26:29 +08:00
|
|
|
|
if (!(axesTy && numAxes == endSize))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Axes should be the same size of starts and ends");
|
|
|
|
|
}
|
2024-01-04 11:41:10 +08:00
|
|
|
|
|
|
|
|
|
auto stepsTy = steps.getType()
|
|
|
|
|
.cast<Torch::ValueTensorType>()
|
|
|
|
|
.toBuiltinTensor()
|
|
|
|
|
.dyn_cast<RankedTensorType>();
|
|
|
|
|
|
|
|
|
|
if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
binder.op, "Steps should be the same size of starts and ends");
|
|
|
|
|
|
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
|
|
|
|
|
|
|
|
|
auto select = [&](Value v, Value k) -> Value {
|
2024-04-28 05:00:56 +08:00
|
|
|
|
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
2024-01-04 11:41:10 +08:00
|
|
|
|
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
|
|
|
|
loc,
|
|
|
|
|
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
|
|
|
|
|
ty.getOptionalDtype()),
|
|
|
|
|
v, zero, k);
|
|
|
|
|
Value item = rewriter.create<Torch::AtenItemOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(), sel);
|
|
|
|
|
return item;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
|
|
|
|
|
for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
|
2024-02-17 05:35:25 +08:00
|
|
|
|
if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
|
2024-01-04 11:41:10 +08:00
|
|
|
|
intermediateShape[i] = -1;
|
2024-02-17 05:35:25 +08:00
|
|
|
|
if (intermediateShape[i] == ShapedType::kDynamic)
|
|
|
|
|
intermediateShape[i] = Torch::kUnknownSize;
|
2024-01-04 11:41:10 +08:00
|
|
|
|
}
|
|
|
|
|
auto intermediateType = Torch::ValueTensorType::get(
|
|
|
|
|
context, intermediateShape, resultTorchType.getOptionalDtype());
|
2024-02-20 02:26:29 +08:00
|
|
|
|
for (int i = 0; i < endSize; ++i) {
|
2024-01-04 11:41:10 +08:00
|
|
|
|
|
|
|
|
|
Value k = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
|
|
|
|
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
|
|
|
|
loc,
|
|
|
|
|
Torch::ValueTensorType::get(
|
|
|
|
|
context, ArrayRef<int64_t>{1},
|
|
|
|
|
rewriter.getIntegerType(64, /*signed*/ 1)),
|
|
|
|
|
k);
|
|
|
|
|
|
|
|
|
|
Value start = select(starts, kTensor);
|
|
|
|
|
Value end = select(ends, kTensor);
|
2024-02-20 02:26:29 +08:00
|
|
|
|
Value axis = axes ? select(axes, kTensor) : k;
|
2024-01-04 11:41:10 +08:00
|
|
|
|
Value step = select(steps, kTensor);
|
|
|
|
|
|
|
|
|
|
auto sliceType = intermediateType;
|
2024-02-20 02:26:29 +08:00
|
|
|
|
sliceType = i == (endSize - 1) ? resultTorchType : sliceType;
|
2024-01-04 11:41:10 +08:00
|
|
|
|
operand = rewriter.create<Torch::AtenSliceTensorOp>(
|
|
|
|
|
loc, sliceType, operand, axis, start, end, step);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(binder.op, operand);
|
|
|
|
|
return success();
|
|
|
|
|
});
|
2023-12-27 02:20:13 +08:00
|
|
|
|
// onnx.Reshape (opset 5+) -> torch.aten.reshape.
//
// Fast path: when the result type is fully static, the target shape is
// materialized directly from the result type as constant ints.
// General path: every element of the `shape` operand is extracted as a
// torch.int at runtime. With allowzero == 0, an extracted value of 0 is
// replaced by the size of the corresponding input dimension (only for indices
// that are within the rank of `data`).
patterns.onOp(
    "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value data;
      Value shape;
      int64_t allowzero;
      if (binder.tensorOperands(data, shape) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(allowzero, "allowzero", 0))
        return failure();

      auto loc = binder.getLoc();
      auto intListTy =
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext()));

      // If the result shape is static then we can create a result shape list
      // directly using the result shape values (integers).
      if (resultType.hasSizes() && resultType.areAllSizesKnown()) {
        SmallVector<Value> resultShape;
        for (int64_t dim : resultType.getSizes()) {
          resultShape.push_back(rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(dim)));
        }
        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
            loc, intListTy, resultShape);
        rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
            binder.op, resultType, data, resultShapeList);
        return success();
      }

      // The general path unrolls over the elements of `shape`, so its length
      // must be known at compile time. Reject anything else up front instead
      // of indexing into a missing/empty size list or silently emitting an
      // empty dim list for a dynamic length.
      auto shapeType = cast<Torch::BaseTensorType>(shape.getType());
      if (!shapeType.hasSizes() || shapeType.getSizes().size() != 1 ||
          shapeType.getSizes()[0] == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op,
            "expected the shape operand to be a rank-1 tensor of static size");
      int64_t numDims = shapeType.getSizes()[0];

      // The rank of `data` is only needed to resolve zero entries.
      auto dataType = cast<Torch::BaseTensorType>(data.getType());
      int64_t dataRank = 0;
      if (allowzero == 0) {
        if (!dataType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "expected the data operand to have a known rank");
        dataRank = dataType.getSizes().size();
      }

      SmallVector<Value> dimList;
      SmallVector<int64_t> selectSizes{1};
      Type selectResultType = shapeType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

      // Convert shape (tensor) into a torch int list, dealing with zero
      // values when allowzero == 0.
      for (int64_t i = 0; i < numDims; i++) {
        // Go through the shape list and get each dim in the list.
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, shape, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), extract);

        if (allowzero == 0) {
          // Deal with zero axis values: replace with the original dim value
          // of the input.
          Value isZero = rewriter.create<Torch::AtenEqIntOp>(loc, dim, zero);
          isZero = rewriter.create<Torch::AtenIntBoolOp>(loc, isZero);

          if (i < dataRank) {
            auto torchIntTy = rewriter.getType<Torch::IntType>();
            auto int64Ty = rewriter.getIntegerType(64, true);
            auto dimTy = rewriter.getType<Torch::ValueTensorType>(
                ArrayRef<int64_t>(), int64Ty);
            auto boolTy = rewriter.getType<Torch::ValueTensorType>(
                ArrayRef<int64_t>(), rewriter.getI1Type());
            Value iv = rewriter.create<Torch::ConstantIntOp>(
                loc, rewriter.getI64IntegerAttr(i));
            Value inDim = rewriter.create<Torch::AtenSizeIntOp>(
                loc, torchIntTy, data, iv);
            // The selection is done on 0-d tensors via aten.where.self and
            // converted back to a scalar afterwards.
            isZero = rewriter.create<Torch::PrimNumToTensorScalarOp>(
                loc, boolTy, isZero);
            inDim = rewriter.create<Torch::PrimNumToTensorScalarOp>(
                loc, dimTy, inDim);
            dim = rewriter.create<Torch::PrimNumToTensorScalarOp>(loc, dimTy,
                                                                  dim);
            Value finalDim = rewriter.create<Torch::AtenWhereSelfOp>(
                loc, dimTy, isZero, inDim, dim);
            dim = rewriter.create<Torch::AtenItemOp>(
                loc, rewriter.getType<Torch::IntType>(), finalDim);
          }
        }
        dimList.push_back(dim);
      }

      Value dimValueList =
          rewriter.create<Torch::PrimListConstructOp>(loc, intListTy, dimList);
      rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(binder.op, resultType,
                                                        data, dimValueList);
      return success();
    });
|
2024-03-05 03:07:03 +08:00
|
|
|
|
// onnx.ReduceProd (opset 13+) -> a chain of torch.aten.prod.dim_int ops, one
// per reduced axis, each emitted with keepdim=true. For keepdims == 0 a
// trailing torch.aten.reshape produces the final result type.
//
// Axes are collected from the `axes` attribute (older form) and from the
// optional second operand. Requires statically known input and result sizes.
patterns.onOp(
    "ReduceProd", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ReduceProd allows us to pass a list of dims but AtenProdDimIn only
      // allow one dim as input.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      auto dataTy = cast<Torch::BaseTensorType>(data.getType());
      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

      if (!resultType.hasSizes() || !resultType.areAllSizesKnown() ||
          !dataTy.areAllSizesKnown())
        return rewriter.notifyMatchFailure(
            binder.op,
            "Expected the input and result type to have known sizes");

      int64_t rank = dataTy.getSizes().size();
      SmallVector<Value> axesList;
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(0));

      // Previous version of the operation had the axes as an attribute:
      llvm::SmallVector<int64_t> axesAttr;
      if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
        for (int i = 0, s = axesAttr.size(); i < s; ++i) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy,
              rewriter.getI64IntegerAttr(axesAttr[i])));
        }
      }

      // Handle cases that axes are explicitly specified.
      // Extract the axes values from the axes operand.
      // This really shouldn't happen but it helps pass weird tests.
      // TODO: Derive the chosen axes from the data type and final result type
      // instead of using the dynamic axes at operand[1].
      if (!binder.tensorOperandAtIndex(axes, 1)) {
        Torch::BaseTensorType axesType =
            cast<Torch::BaseTensorType>(axes.getType());
        auto sizes = axesType.getSizes();
        for (int i = 0; i < sizes[0]; i++) {
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(),
              axesType.getWithSizesAndDtype(llvm::SmallVector<int64_t>{1},
                                            axesType.getOptionalDtype()),
              axes, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                         torchIntTy, extract);
          axesList.push_back(dim);
        }
      }

      // Handle the noop case:
      // When axes is empty and noop_with_empty_axes is set to true, input
      // tensor will not be reduced, and the output tensor would be
      // equivalent to input tensor.
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // Handle case when no axes arg is passed but not a noop:
      // Manually set positive axis to all dims.
      if (axesList.empty()) {
        for (int i = 0; i < rank; i++) {
          Value dimValue = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(i));
          axesList.push_back(dimValue);
        }
      }

      // Handle negative axis: axis += (axis < 0) * rank, computed in IR since
      // axes taken from the operand are runtime values.
      Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
                                                        torchIntTy, data);
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rankVal);
        axis = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axis,
                                                    finalOffset);
      }

      // Handle multiple axes case:
      // ReduceProd on each dim, always set keepDimsBool == True to avoid
      // segfault.
      Value trueVal =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      SmallVector<int64_t> intermediateShape(rank, Torch::kUnknownSize);
      Value dataReduceProd = data;
      for (int i = 0, numAxes = axesList.size(); i < numAxes; i++) {
        auto axis = axesList[i];
        // With keepdims set, the last reduction already has the final result
        // shape, so it directly replaces the op.
        if (keepDims && i == numAxes - 1) {
          dataReduceProd = rewriter.create<Torch::AtenProdDimIntOp>(
              binder.getLoc(),
              dataTy.getWithSizesAndDtype(resultType.getSizes(),
                                          dataTy.getOptionalDtype()),
              dataReduceProd, axis, trueVal, noneVal);
          rewriter.replaceOp(binder.op, dataReduceProd);
          return success();
        }
        Type resultTyReduceProd = dataTy.getWithSizesAndDtype(
            ArrayRef(intermediateShape), dataTy.getOptionalDtype());
        dataReduceProd = rewriter.create<Torch::AtenProdDimIntOp>(
            binder.getLoc(), resultTyReduceProd, dataReduceProd, axis,
            trueVal, noneVal);
      }

      // Derived the final shape of the tensor after prod loop of each axis.
      SmallVector<int64_t> dataReduceProdSize;
      auto dataSize = dataTy.getSizes();
      auto resultTypeSizes = resultType.getSizes();
      if (!keepDims) {
        // Handle the keepDimsBool == False case:
        // 2 point algorithm to derive the static shape after prod loop.
        // `j` walks the result sizes and must stay in bounds: once every
        // result dim has been matched, the remaining input dims are treated
        // as reduced (size 1).
        size_t j = 0;
        for (int64_t i = 0; i < rank; i++) {
          if (j < resultTypeSizes.size() &&
              dataSize[i] == resultTypeSizes[j]) {
            dataReduceProdSize.push_back(resultTypeSizes[j]);
            j++;
            continue;
          }
          dataReduceProdSize.push_back(1);
        }
      }

      // Handle the keepDimsBool == False case:
      // Reshape the prod loop result to the final result shape.
      // NOTE(review): the shape list built here has one entry per *input*
      // dimension, while the reshape result type is `resultType`; for
      // keepdims == 0 those ranks presumably differ -- confirm whether the
      // list should be built from resultType's sizes instead.
      SmallVector<Value> dataReduceProdShape;
      for (auto dim : dataReduceProdSize)
        dataReduceProdShape.push_back(rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(dim)));
      Value dataReduceProdShapeList =
          rewriter.create<Torch::PrimListConstructOp>(
              binder.getLoc(),
              rewriter.getType<Torch::ListType>(
                  rewriter.getType<Torch::IntType>()),
              dataReduceProdShape);
      rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
          binder.op, resultType, dataReduceProd, dataReduceProdShapeList);
      return success();
    });
|
2024-01-16 03:26:46 +08:00
|
|
|
|
// onnx.Range(start, limit, delta) -> torch.aten.arange.start_step.
// `limit` is exclusive. The three operands are converted to torch scalars
// with getItemOp; the scalar kind (float vs. int) is chosen from the dtype of
// `start` only, i.e. all three operands are assumed to share one dtype.
patterns.onOp(
    "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value start, limit, delta;
      auto loc = binder.getLoc();
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
      if (binder.tensorOperandAtIndex(start, 0) ||
          binder.tensorOperandAtIndex(limit, 1) ||
          binder.tensorOperandAtIndex(delta, 2) ||
          binder.tensorResultType(resultType))
        return failure();

      // Supported element types: double, float, int16, int32, int64.
      Type startDtype =
          cast<Torch::BaseTensorType>(start.getType()).getDtype();
      const bool isFloatDType = startDtype.isF64() || startDtype.isF32();
      const bool isIntDType = startDtype.isInteger(16) ||
                              startDtype.isInteger(32) ||
                              startDtype.isInteger(64);
      if (!(isFloatDType || isIntDType)) {
        return rewriter.notifyMatchFailure(
            binder.op, "Expected the start, limit, delta to be one of "
                       "double, float, int16, int32, int64");
      }

      // Turns a 0-d tensor into the matching torch scalar
      // (e.g. torch.tensor(1.1) -> 1.1).
      auto toScalar = [&](Value tensor) -> Value {
        if (isFloatDType)
          return getItemOp<Torch::FloatType>(binder, rewriter, tensor);
        return getItemOp<Torch::IntType>(binder, rewriter, tensor);
      };
      Value scalarStart = toScalar(start);
      Value scalarLimit = toScalar(limit);
      Value scalarDelta = toScalar(delta);

      rewriter.replaceOpWithNewOp<Torch::AtenArangeStartStepOp>(
          binder.op, resultType, scalarStart, scalarLimit, scalarDelta,
          cstNone, cstNone, cstNone, cstNone);
      return success();
    });
|
2024-03-07 09:01:05 +08:00
|
|
|
|
// onnx.Size -> torch.aten.tensor.int holding the total element count.
// The count is the product of all torch.aten.size.int values of the operand;
// a rank-0 operand yields the constant 1. The operand rank must be known.
patterns.onOp(
    "Size", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType))
        return failure();

      auto loc = binder.getLoc();
      Operation *op = binder.op;
      auto operandTy = cast<Torch::BaseTensorType>(operand.getType());
      if (!operandTy.hasSizes())
        return rewriter.notifyMatchFailure(op, "input rank unknown");

      // Query every dimension size as a torch.int.
      const int64_t rank = operandTy.getSizes().size();
      llvm::SmallVector<Value> dimSizes;
      for (int64_t d = 0; d < rank; ++d) {
        Value dimIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(d));
        dimSizes.push_back(rewriter.create<Torch::AtenSizeIntOp>(
            loc, rewriter.getType<Torch::IntType>(), operand, dimIndex));
      }

      Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

      // Fold the sizes into a single product (1 for a scalar operand).
      Value numElements;
      if (dimSizes.empty()) {
        numElements = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(1));
      } else {
        numElements = dimSizes.front();
        for (Value dimSize : llvm::ArrayRef<Value>(dimSizes).drop_front())
          numElements =
              rewriter.create<Torch::AtenMulIntOp>(loc, numElements, dimSize);
      }

      rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
          op, resultType, numElements, none, none, cstFalse);
      return success();
    });
|
2024-01-25 01:26:21 +08:00
|
|
|
|
// onnx.Tile -> torch.aten.tile.
// The `repeats` operand (a tensor) is unpacked element by element into a
// torch int list, which requires its length to be statically known.
patterns.onOp(
    "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      Value repeatDims;
      if (binder.tensorOperands(operand, repeatDims) ||
          binder.tensorResultType(resultType))
        return failure();

      // The extraction loop below is unrolled at conversion time, so bail out
      // when the repeats tensor is not rank-1 with a static length rather
      // than indexing into a missing/empty size list or emitting an empty
      // list for a dynamic length.
      Torch::BaseTensorType shapeType =
          cast<Torch::BaseTensorType>(repeatDims.getType());
      if (!shapeType.hasSizes() || shapeType.getSizes().size() != 1 ||
          shapeType.getSizes()[0] == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "expected the repeats operand to be a rank-1 tensor "
                       "of static size");
      int64_t numRepeats = shapeType.getSizes()[0];

      // convert repeatDims tensor to list of ints
      SmallVector<Value> dimList;
      SmallVector<int64_t> selectSizes{1};
      Type selectResultType = shapeType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      for (int64_t i = 0; i < numRepeats; i++) {
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, repeatDims, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
        dimList.push_back(dim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);

      rewriter.replaceOpWithNewOp<Torch::AtenTileOp>(binder.op, resultType,
                                                     operand, dimValueList);
      return success();
    });
|
2024-01-23 04:56:39 +08:00
|
|
|
|
// onnx.TopK (opset 11+) -> torch.aten.topk.
// `K` arrives as a tensor operand and is converted to a torch.int via
// aten.item; `axis` is normalized to a non-negative dimension, which needs a
// ranked input.
patterns.onOp(
    "TopK", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType valuesTy, indicesTy;
      Value input, kTensor;
      int64_t axis;
      bool largest, sorted;
      if (binder.tensorOperandAtIndex(input, 0) ||
          binder.tensorOperandAtIndex(kTensor, 1) ||
          binder.s64IntegerAttr(axis, "axis", -1) ||
          binder.s64BoolAttr(largest, "largest", true) ||
          binder.s64BoolAttr(sorted, "sorted", true) ||
          binder.tensorResultTypeAtIndex(valuesTy, 0) ||
          binder.tensorResultTypeAtIndex(indicesTy, 1))
        return failure();

      std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
      if (!maybeRank.has_value())
        return rewriter.notifyMatchFailure(binder.op,
                                           "Unimplemented: unranked tensor");
      axis = Torch::toPositiveDim(axis, *maybeRank);

      auto loc = binder.getLoc();
      Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(axis));
      Value cstLargest = rewriter.create<Torch::ConstantBoolOp>(loc, largest);
      Value cstSorted = rewriter.create<Torch::ConstantBoolOp>(loc, sorted);
      Value kScalar = rewriter.create<Torch::AtenItemOp>(
          loc, rewriter.getType<Torch::IntType>(), kTensor);

      rewriter.replaceOpWithNewOp<Torch::AtenTopkOp>(
          binder.op, valuesTy, indicesTy, input, kScalar, cstAxis, cstLargest,
          cstSorted);
      return success();
    });
|
2024-01-25 01:26:21 +08:00
|
|
|
|
// onnx.Sign -> torch.aten.sign (direct one-to-one mapping).
patterns.onOp("Sign", 9,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                Torch::ValueTensorType outTy;
                Value x;
                if (binder.tensorOperand(x) ||
                    binder.tensorResultType(outTy)) {
                  return failure();
                }
                rewriter.replaceOpWithNewOp<Torch::AtenSignOp>(binder.op,
                                                               outTy, x);
                return success();
              });
|
2024-03-25 22:59:07 +08:00
|
|
|
|
// onnx.Softplus: out = ln(exp(x) + 1), emitted as
// torch.aten.log1p(torch.aten.exp(x)).
patterns.onOp(
    "Softplus", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType outTy;
      Value x;
      if (binder.tensorOperand(x) || binder.tensorResultType(outTy))
        return failure();

      auto loc = binder.getLoc();
      Value expX = rewriter.create<Torch::AtenExpOp>(loc, outTy, x);
      rewriter.replaceOpWithNewOp<Torch::AtenLog1pOp>(binder.op, outTy, expX);
      return success();
    });
|
|
|
|
|
// onnx.Trilu (opset 14+) -> torch.aten.triu when `upper` is non-zero,
// torch.aten.tril otherwise. The optional second operand `k` selects the
// diagonal; when it is absent, diagonal 0 is used.
patterns.onOp(
    "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType outTy;
      Value input;
      int64_t upper;
      if (binder.tensorOperandAtIndex(input, 0) ||
          binder.s64IntegerAttr(upper, "upper", 1) ||
          binder.tensorResultType(outTy)) {
        return failure();
      }

      auto loc = binder.getLoc();
      Value diagonal;
      if (!binder.tensorOperandAtIndex(diagonal, 1)) {
        // `k` was provided as a tensor: turn it into a torch.int scalar.
        diagonal = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), diagonal);
      } else {
        diagonal = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(0));
      }

      if (!upper) {
        rewriter.replaceOpWithNewOp<Torch::AtenTrilOp>(binder.op, outTy, input,
                                                       diagonal);
        return success();
      }
      rewriter.replaceOpWithNewOp<Torch::AtenTriuOp>(binder.op, outTy, input,
                                                     diagonal);
      return success();
    });
|
|
|
|
|
// onnx.ThresholdedRelu (opset 10+) -> torch.aten.threshold with
// threshold = `alpha` (default 1.0) and replacement value 0.0.
patterns.onOp("ThresholdedRelu", 10,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                Torch::ValueTensorType resultType;
                Value input;
                float alpha;
                // The result type must be bound as well; otherwise the
                // replacement op would be created with a null result type.
                if (binder.tensorOperand(input) ||
                    binder.f32FloatAttr(alpha, "alpha", 1.0) ||
                    binder.tensorResultType(resultType)) {
                  return failure();
                }
                Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
                    binder.getLoc(), rewriter.getType<Torch::FloatType>(),
                    rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
                Value value = rewriter.create<Torch::ConstantFloatOp>(
                    binder.getLoc(), rewriter.getType<Torch::FloatType>(),
                    rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
                rewriter.replaceOpWithNewOp<Torch::AtenThresholdOp>(
                    binder.op, resultType, input, cstAlpha, value);
                return success();
              });
|
2024-04-23 00:58:07 +08:00
|
|
|
|
// onnx.RandomNormal -> torch.aten.empty.memory_format (of the requested
// `shape` and `dtype`) followed by torch.aten.normal_functional with the
// given `mean` and `scale`. A `seed` attribute is not supported and makes the
// pattern fail to match.
patterns.onOp(
    "RandomNormal", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      SmallString<64> seedAttrName("torch.onnx.seed");
      if (binder.op->getAttr(seedAttrName)) {
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented: support not present for seed attribute");
      }

      Torch::ValueTensorType resultType;
      int64_t onnxDtype;
      float mean, scale;
      SmallVector<int64_t> shape;
      if (binder.s64IntegerAttr(onnxDtype, "dtype", 1) ||
          binder.f32FloatAttr(mean, "mean", 0.0) ||
          binder.f32FloatAttr(scale, "scale", 1.0) ||
          binder.s64IntegerArrayAttr(shape, "shape", {}) ||
          binder.tensorResultType(resultType))
        return failure();

      // Translate the ONNX dtype enum value into the torch dtype integer.
      std::optional<int64_t> torchDtype =
          onnxDtypeIntToTorchDtypeInt(onnxDtype);
      if (!torchDtype.has_value())
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented support for the given dtype conversion");

      Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(torchDtype.value()));
      Value shapeList = createConstantIntList(binder, rewriter, shape);
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

      // Uninitialized tensor that provides shape and dtype to the sampler.
      Value emptyTensor = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
          binder.op->getLoc(), resultType, shapeList,
          /*dtype=*/cstDtype,
          /*layout=*/cstNone,
          /*device=*/cstNone, /*pinMemory=*/cstNone,
          /*memoryFormat=*/cstNone);

      auto torchFloatTy = rewriter.getType<Torch::FloatType>();
      Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), torchFloatTy,
          rewriter.getFloatAttr(rewriter.getF64Type(), mean));
      Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), torchFloatTy,
          rewriter.getFloatAttr(rewriter.getF64Type(), scale));

      rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
          binder.op, resultType, emptyTensor, cstMean, cstStd,
          /*generator=*/cstNone);
      return success();
    });
|
|
|
|
|
// onnx.RandomNormalLike -> torch.aten.to.dtype (casts the input to the
// requested dtype) followed by torch.aten.normal_functional with the given
// `mean` and `scale`. A `seed` attribute is not supported and makes the
// pattern fail to match.
patterns.onOp(
    "RandomNormalLike", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      if (binder.op->getAttr("torch.onnx.seed"))
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented: support not present for seed attribute");

      Torch::ValueTensorType resultType;
      int64_t dtypeIntOnnx;
      float mean, scale;
      Value input;
      // NOTE(review): a missing `dtype` attribute falls back to ONNX dtype 1
      // here; the ONNX spec appears to default to the input's dtype instead
      // -- confirm.
      if (binder.tensorOperand(input) ||
          binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
          binder.f32FloatAttr(mean, "mean", 0.0) ||
          binder.f32FloatAttr(scale, "scale", 1.0) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }

      // Translate the ONNX dtype enum value into the torch dtype integer.
      std::optional<int64_t> dtypeIntTorch =
          onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
      if (!dtypeIntTorch.has_value()) {
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented support for the given dtype conversion");
      }
      Value constDtype = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value cstFalse =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
      input = rewriter.create<Torch::AtenToDtypeOp>(
          binder.op->getLoc(), resultType, input, constDtype,
          /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
          /*memory_format=*/cstNone);

      Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), mean));
      Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), scale));

      rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
          binder.op, resultType, input, cstMean, cstStd,
          /*generator=*/cstNone);
      return success();
    });
|
|
|
|
|
// Lowering for the ONNX "RandomUniform" op.
//
// Builds a tensor of the requested `shape` and `dtype` with
// Torch::AtenEmptyMemoryFormatOp and passes it to Torch::AtenUniformOp with
// the `low` / `high` attributes as bounds (no generator). Ops carrying a
// "torch.onnx.seed" attribute, or a dtype without a torch equivalent, are
// rejected with a match failure.
patterns.onOp(
    "RandomUniform", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Seeded generation is not implemented; bail out early.
      if (binder.op->getAttr("torch.onnx.seed"))
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented: support not present for seed attribute");

      Torch::ValueTensorType outTy;
      int64_t onnxDtype;
      float upper, lower;
      SmallVector<int64_t> dims;
      if (binder.s64IntegerAttr(onnxDtype, "dtype", 1) ||
          binder.f32FloatAttr(upper, "high", 1.0) ||
          binder.f32FloatAttr(lower, "low", 0.0) ||
          binder.s64IntegerArrayAttr(dims, "shape", {}) ||
          binder.tensorResultType(outTy))
        return failure();

      // Map the ONNX dtype enum value onto the torch dtype integer.
      std::optional<int64_t> torchDtype =
          onnxDtypeIntToTorchDtypeInt(onnxDtype);
      if (!torchDtype)
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented support for the given dtype conversion");

      auto loc = binder.getLoc();
      Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(*torchDtype));

      // Tensor of the requested shape/dtype that the uniform op is applied
      // to.
      Value dimsList = createConstantIntList(binder, rewriter, dims);
      Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
      Value empty = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
          loc, outTy, dimsList, /*dtype=*/dtypeVal, /*layout=*/none,
          /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);

      // Bounds are emitted as f64-typed torch float constants.
      auto torchFloatTy = rewriter.getType<Torch::FloatType>();
      Value highVal = rewriter.create<Torch::ConstantFloatOp>(
          loc, torchFloatTy,
          rewriter.getFloatAttr(rewriter.getF64Type(), upper));
      Value lowVal = rewriter.create<Torch::ConstantFloatOp>(
          loc, torchFloatTy,
          rewriter.getFloatAttr(rewriter.getF64Type(), lower));

      rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
          binder.op, outTy, empty, lowVal, highVal, /*generator=*/none);
      return success();
    });
|
|
|
|
|
// Lowering for the ONNX "RandomUniformLike" op.
//
// The input operand is passed through Torch::AtenToDtypeOp using the torch
// dtype converted from the `dtype` attribute, and the result is handed to
// Torch::AtenUniformOp with the `low` / `high` attributes as bounds (no
// generator).
//
// The match fails when:
//   * a "torch.onnx.seed" attribute is present (seeding is unimplemented),
//   * any operand / attribute / result-type binding fails, or
//   * the ONNX dtype integer has no torch dtype equivalent.
patterns.onOp(
    "RandomUniformLike", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Reject seeded ops explicitly rather than silently ignoring the seed.
      SmallString<64> name("torch.onnx.seed");
      auto seedAttr = binder.op->getAttr(name);
      if (seedAttr)
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented: support not present for seed attribute");

      Torch::ValueTensorType resultType;
      int64_t dtypeIntOnnx;
      float high, low;
      Value input;
      // NOTE(review): `dtype` falls back to 1 when the attribute is absent.
      // The ONNX operator docs describe an omitted dtype as "use the input
      // tensor's dtype" -- confirm whether this default is intentional.
      if (binder.tensorOperand(input) ||
          binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
          binder.f32FloatAttr(high, "high", 1.0) ||
          binder.f32FloatAttr(low, "low", 0.0) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }

      // Translate the ONNX dtype enum value into the torch dtype integer.
      std::optional<int64_t> dtypeIntTorch =
          onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
      if (!dtypeIntTorch.has_value()) {
        return rewriter.notifyMatchFailure(
            binder.op,
            "unimplemented support for the given dtype conversion");
      }
      Value constDtype = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

      // Convert the input to the requested dtype; the converted tensor is
      // what gets passed on as the `self` operand of the uniform op below.
      Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value cstFalse =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
      input = rewriter.create<Torch::AtenToDtypeOp>(
          binder.op->getLoc(), resultType, input, constDtype,
          /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
          /*memory_format=*/cstNone);

      // The f32 attributes are materialized as f64-typed torch float
      // constants.
      Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), high));
      Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), low));

      // Operand order is (self, low bound, high bound, generator).
      rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
          binder.op, resultType, input, cstLow, cstHigh,
          /*generator=*/cstNone);
      return success();
    });
|
2023-12-15 00:53:47 +08:00
|
|
|
|
}
|