2023-11-22 13:02:55 +08:00
|
|
|
//===------------------------------------------------------------*- C++ -*-===//
|
|
|
|
//
|
|
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
2024-02-06 08:09:41 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
|
2024-01-19 08:33:10 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-16 01:36:18 +08:00
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
2023-11-22 13:02:55 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::onnx_c;
|
|
|
|
|
|
|
|
// Simple rewrites for the default domain.
|
|
|
|
// See: https://onnx.ai/onnx/operators/
|
|
|
|
// For operators that are effectively version invariant, we register with
|
|
|
|
// sinceVersion==1. We interpret this to include the following spec
|
|
|
|
// diffs that are irrelevant to this level of lowering:
|
|
|
|
// * Supported element types.
|
|
|
|
// * Limited broadcasting to full broadcasting support.
|
|
|
|
//
|
|
|
|
// There are a lot of spec revisions that basically generalized elementwise
|
|
|
|
// to be more normal and a direct translation vs a special case. This
|
|
|
|
// results in a lot of ONNX test cases that all reduce to the exact same
|
|
|
|
// thing here, so we simplify.
|
2024-01-16 03:26:46 +08:00
|
|
|
|
|
|
|
// utilities
|
|
|
|
// Templatized function to get an item op of a type
|
|
|
|
namespace {
|
|
|
|
template <typename T>
|
|
|
|
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
|
|
|
|
Value &ofItem) {
|
|
|
|
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
|
|
|
|
rewriter.getType<T>(), ofItem);
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
2023-11-22 13:02:55 +08:00
|
|
|
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
2023-12-15 00:53:47 +08:00
|
|
|
OnnxCustomOpConversionPattern &patterns) {
|
2024-01-30 01:59:33 +08:00
|
|
|
patterns.onOp(
    "QuantizeLinear", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX QuantizeLinear(x, y_scale, y_zero_point) to
      //   aten.int_repr(aten.quantize_per_tensor(x, scale, zero_point, qdtype))
      // where qdtype is the Torch quantized dtype matching the integer result
      // dtype. Only per-tensor quantization is handled: y_scale must be a
      // single-element tensor (rank 0, or every dimension equal to 1), the
      // same per-tensor test used by the QLinear* lowerings.
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      if (binder.tensorOperands(operands, 3) ||
          binder.tensorResultType(resultType))
        return failure();

      Value operand = operands[0];
      Value scale = operands[1];
      Value zeropoint = operands[2];

      auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
      if (!scaleTy || !scaleTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op, "requires known rank");
      if (!resultType.hasDtype())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known result dtype");

      // Anything other than a single-element scale is per-axis quantization,
      // which this lowering does not support; report it instead of failing
      // silently.
      bool isPerTensor = llvm::all_of(scaleTy.getSizes(),
                                      [](int64_t d) { return d == 1; });
      if (!isPerTensor)
        return rewriter.notifyMatchFailure(
            binder.op, "not supported for non per-tensor quantization");

      // Map the integer result dtype to its Torch quantized counterpart.
      Type qTy = resultType.getDtype();
      if (qTy.isUnsignedInteger(8)) {
        qTy = rewriter.getType<Torch::QUInt8Type>();
      } else if (qTy.isSignedInteger(8)) {
        qTy = rewriter.getType<Torch::QInt8Type>();
      } else if (qTy.isSignedInteger(32)) {
        qTy = rewriter.getType<Torch::QInt32Type>();
      } else {
        return rewriter.notifyMatchFailure(binder.op,
                                           "unsupported result dtype");
      }

      auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), qTy);
      auto torchqTy = Torch::getScalarTypeForType(qTy);

      Value tyConst = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                  static_cast<int64_t>(torchqTy)));

      // quantize_per_tensor is given scale / zero point as Torch scalars,
      // so extract the single element of each tensor.
      scale = rewriter.create<Torch::AtenItemOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
      zeropoint = rewriter.create<Torch::AtenItemOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);

      auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
          binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
      // int_repr turns the quantized tensor back into a plain integer tensor
      // of the ONNX result type.
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        quantize);
      return success();
    });
|
2024-02-06 08:09:41 +08:00
|
|
|
patterns.onOp(
    "QLinearConv", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX QLinearConv(x, x_scale, x_zp, w, w_scale, w_zp, y_scale,
      // y_zp[, B]) in four steps:
      //   1. wrap x and w as per-tensor quantized Torch tensors,
      //   2. emit a generic `onnx.Conv` operator on them (plus the optional
      //      bias B), typed as a qint32 accumulator,
      //   3. re-wrap the accumulator with scale x_scale * w_scale and zero
      //      point 0, then dequantize it to f32,
      //   4. requantize with y_scale / y_zp and return the integer repr.
      // Only per-tensor quantization parameters are supported.
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      // The bias operand is optional, so accept either 8 or 9 operands.
      if ((binder.tensorOperands(operands, 8) &&
           binder.tensorOperands(operands, 9)) ||
          binder.tensorResultType(resultType))
        return failure();
      Value a = operands[0];
      Value aScale = operands[1];
      Value aZp = operands[2];
      Value b = operands[3];
      Value bScale = operands[4];
      Value bZp = operands[5];
      Value cScale = operands[6];
      Value cZp = operands[7];
      Value c = operands.size() == 9 ? operands[8] : nullptr;

      // A quantization parameter is per-tensor when its shape is known and
      // every dimension is 1 (this includes rank 0). Unranked parameters are
      // rejected rather than hitting the getSizes() assertion.
      auto check = [](Value v) {
        auto vTy = v.getType().cast<Torch::ValueTensorType>();
        if (!vTy.hasSizes())
          return false;
        return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
      };
      // Every scale AND zero point must be per-tensor (previously cScale was
      // checked twice and cZp never).
      if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
          !check(cScale) || !check(cZp))
        return rewriter.notifyMatchFailure(
            binder.op, "not supported for non per-tensor quantization");

      // Extract a single-element parameter as a Torch int (integer dtype)
      // or Torch float scalar.
      auto extract = [&rewriter, &binder](Value v) {
        auto vTy = v.getType().cast<Torch::ValueTensorType>();
        Type extractTy = rewriter.getType<Torch::FloatType>();
        if (isa<IntegerType>(vTy.getDtype()))
          extractTy = rewriter.getType<Torch::IntType>();

        return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
                                                  v);
      };

      aZp = extract(aZp);
      bZp = extract(bZp);
      cZp = extract(cZp);
      aScale = extract(aScale);
      bScale = extract(bScale);
      cScale = extract(cScale);

      // Reinterpret an integer tensor as a per-tensor quantized tensor.
      auto make = [&rewriter, &binder](Value v, Value scale,
                                       Value zp) -> Value {
        auto ty = v.getType().cast<Torch::ValueTensorType>();
        auto newTy = getQTorchTypeFromTorchIntType(ty);
        return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
            binder.getLoc(), newTy, v, scale, zp);
      };

      a = make(a, aScale, aZp);
      b = make(b, bScale, bZp);

      // TODO(suderman): insert convolution operator.
      llvm::SmallVector<Value> newOperands = {a, b};
      if (c)
        newOperands.push_back(c);

      // The convolution result is typed as a qint32 accumulator.
      auto cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(),
          rewriter.getType<Torch::QInt32Type>());

      // Forward every attribute of the QLinearConv op except "name", which
      // is overridden so the generic operator is an "onnx.Conv".
      llvm::SmallVector<NamedAttribute> newAttributes;
      newAttributes.push_back(
          rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv")));
      for (auto namedAttr : binder.op->getAttrDictionary()) {
        if (namedAttr.getName().getValue().compare("name") == 0)
          continue;
        newAttributes.push_back(namedAttr);
      }

      c = rewriter
              .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
                                         newAttributes,
                                         binder.op->getRegions().size())
              .getResult(0);

      // The accumulator's scale is x_scale * w_scale with zero point 0.
      Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
          bScale);
      Value outZp = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          binder.getLoc(), cTy, c, outScale, outZp);
      cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), rewriter.getF32Type());

      c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                       c);

      // Requantize into the quantized counterpart of the ONNX result dtype.
      cTy = dyn_cast<Torch::ValueTensorType>(
          getQTorchTypeFromTorchIntType(resultType));
      Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(
              rewriter.getIntegerType(64),
              static_cast<int64_t>(
                  Torch::getScalarTypeForType(cTy.getDtype()))));
      c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
          binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        c);
      return success();
    });
|
2024-01-25 04:28:48 +08:00
|
|
|
patterns.onOp(
    "QLinearMatMul", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX QLinearMatMul(a, a_scale, a_zp, b, b_scale, b_zp,
      // y_scale, y_zp) in four steps:
      //   1. wrap a and b as per-tensor quantized tensors,
      //   2. multiply them with aten.mm (rank 2) or aten.bmm (rank 3) into a
      //      qint32 accumulator with scale a_scale * b_scale and zero point 0,
      //   3. dequantize the accumulator to f32,
      //   4. requantize with y_scale / y_zp and return the integer repr.
      // Only per-tensor quantization parameters are supported.
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      if (binder.tensorOperands(operands, 8) ||
          binder.tensorResultType(resultType))
        return failure();
      Value a = operands[0];
      Value aScale = operands[1];
      Value aZp = operands[2];
      Value b = operands[3];
      Value bScale = operands[4];
      Value bZp = operands[5];
      Value cScale = operands[6];
      Value cZp = operands[7];

      // A parameter is per-tensor when its shape is known and every
      // dimension is 1. Unranked parameters are rejected rather than
      // hitting the getSizes() assertion.
      auto check = [](Value v) {
        auto vTy = v.getType().cast<Torch::ValueTensorType>();
        if (!vTy.hasSizes())
          return false;
        for (auto dim : vTy.getSizes())
          if (dim != 1)
            return false;
        return true;
      };
      // Every scale AND zero point must be per-tensor (previously cScale was
      // checked twice and cZp never).
      if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
          !check(cScale) || !check(cZp))
        return rewriter.notifyMatchFailure(
            binder.op, "not supported for non per-tensor quantization");

      // aten.mm takes rank-2 operands and aten.bmm rank-3 ones. Reject other
      // or unknown ranks up front instead of asserting or emitting an
      // invalid bmm.
      if (!resultType.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known result rank");
      const size_t resultRank = resultType.getSizes().size();
      if (resultRank != 2 && resultRank != 3)
        return rewriter.notifyMatchFailure(
            binder.op, "only rank-2 and rank-3 matmul is supported");

      Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          rewriter.getType<Torch::ListType>(
              rewriter.getType<Torch::IntType>()),
          ValueRange{});
      // Reshape a single-element parameter to rank 0 (when it is not already
      // rank 0), then extract it as a Torch int (integer dtype) or float.
      auto extract = [&rewriter, &binder, &emptyList](Value v) {
        auto vTy = v.getType().cast<Torch::ValueTensorType>();
        if (!vTy.getSizes().empty()) {
          vTy = rewriter.getType<Torch::ValueTensorType>(
              ArrayRef<int64_t>({}), vTy.getOptionalDtype());
          v = rewriter.create<Torch::AtenReshapeOp>(binder.getLoc(), vTy, v,
                                                    emptyList);
        }

        Type extractTy = rewriter.getType<Torch::FloatType>();
        if (isa<IntegerType>(vTy.getDtype()))
          extractTy = rewriter.getType<Torch::IntType>();

        return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
                                                  v);
      };

      aZp = extract(aZp);
      bZp = extract(bZp);
      cZp = extract(cZp);
      aScale = extract(aScale);
      bScale = extract(bScale);
      cScale = extract(cScale);

      // Reinterpret an integer tensor as a per-tensor quantized tensor.
      auto make = [&rewriter, &binder](Value v, Value scale,
                                       Value zp) -> Value {
        auto ty = v.getType().cast<Torch::ValueTensorType>();
        auto newTy = getQTorchTypeFromTorchIntType(ty);
        return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
            binder.getLoc(), newTy, v, scale, zp);
      };

      a = make(a, aScale, aZp);
      b = make(b, bScale, bZp);

      auto cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(),
          rewriter.getIntegerType(32, /*issigned=*/true));

      Value c;
      if (resultRank == 2) {
        c = rewriter.create<Torch::AtenMmOp>(binder.getLoc(), cTy, a, b);
      } else {
        c = rewriter.create<Torch::AtenBmmOp>(binder.getLoc(), cTy, a, b);
      }

      cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(),
          rewriter.getType<Torch::QInt32Type>());

      // The accumulator's scale is a_scale * b_scale with zero point 0.
      Value mmScale = rewriter.create<Torch::AtenMulFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
          bScale);
      Value mmZp = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
          binder.getLoc(), cTy, c, mmScale, mmZp);
      cTy = rewriter.getType<Torch::ValueTensorType>(
          resultType.getOptionalSizes(), rewriter.getF32Type());

      c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                       c);

      // Requantize into the quantized counterpart of the ONNX result dtype.
      cTy = dyn_cast<Torch::ValueTensorType>(
          getQTorchTypeFromTorchIntType(resultType));
      Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(
              rewriter.getIntegerType(64),
              static_cast<int64_t>(
                  Torch::getScalarTypeForType(cTy.getDtype()))));
      c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
          binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
      rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
                                                        c);
      return success();
    });
|
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-16 01:36:18 +08:00
|
|
|
patterns.onOp(
    "Reciprocal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Reciprocal is lowered one-to-one onto aten.reciprocal.
      Value input;
      Torch::ValueTensorType outType;
      const bool bindFailed =
          binder.tensorOperand(input) || binder.tensorResultType(outType);
      if (bindFailed)
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenReciprocalOp>(binder.op, outType,
                                                           input);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Relu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Relu is lowered one-to-one onto aten.relu.
      Value input;
      Torch::ValueTensorType outType;
      if (binder.tensorOperand(input))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenReluOp>(binder.op, outType,
                                                     input);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Round", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Round (registered from opset 11) is lowered onto aten.round.
      Value src;
      Torch::ValueTensorType dstType;
      const bool bindFailed =
          binder.tensorOperand(src) || binder.tensorResultType(dstType);
      if (bindFailed)
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenRoundOp>(binder.op, dstType, src);
      return success();
    });
|
|
|
|
patterns.onOp(
    "ScatterElements", 18,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX ScatterElements(data, indices, updates) to
      // aten.scatter.src when reduction is "none", otherwise to
      // aten.scatter.reduce. "mul" is renamed to "multiply"; other
      // reductions (e.g. "add") are forwarded unchanged. "max" / "min" are
      // not supported yet.
      Torch::ValueTensorType resultType;
      SmallVector<Value> valList;
      int64_t axis;
      std::string reduction;
      // ScatterElements always has exactly (data, indices, updates); binding
      // exactly three operands keeps the valList accesses below in bounds.
      if (binder.tensorOperands(valList, 3) ||
          binder.s64IntegerAttr(axis, "axis", 0) ||
          binder.customOpNameStringAttr(reduction, "reduction", "none") ||
          binder.tensorResultType(resultType))
        return failure();

      Value data = valList[0];
      Value indices = valList[1];
      Value updates = valList[2];

      // ONNX allows negative axis; normalizing it requires the data rank.
      if (axis < 0) {
        auto dataTy = cast<Torch::ValueTensorType>(data.getType());
        if (!dataTy.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "negative axis requires known data rank");
        axis += dataTy.getSizes().size();
      }

      Value constAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));

      if (reduction == "none") {
        rewriter.replaceOpWithNewOp<Torch::AtenScatterSrcOp>(
            binder.op, resultType, data, constAxis, indices, updates);
        return success();
      }

      // TODO: Implement max and min cases
      if (reduction == "mul") {
        reduction = "multiply";
      } else if (reduction == "max" || reduction == "min") {
        return rewriter.notifyMatchFailure(
            binder.op, "max/min reduction unsupported for scatter elements");
      }

      Value cstStrReduction =
          rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);

      rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
          binder.op, resultType, data, constAxis, indices, updates,
          cstStrReduction);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Sigmoid is lowered one-to-one onto aten.sigmoid.
      Value input;
      Torch::ValueTensorType outType;
      if (binder.tensorOperand(input))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenSigmoidOp>(binder.op, outType,
                                                        input);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Sin", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Sin (registered from opset 7) is lowered onto aten.sin.
      Torch::ValueTensorType outType;
      Value x;
      const bool bindFailed =
          binder.tensorOperand(x) || binder.tensorResultType(outType);
      if (bindFailed)
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenSinOp>(binder.op, outType, x);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Tanh", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Tanh is lowered one-to-one onto aten.tanh.
      Value x;
      Torch::ValueTensorType outType;
      if (binder.tensorOperand(x))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenTanhOp>(binder.op, outType, x);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Sqrt", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Sqrt is lowered one-to-one onto aten.sqrt.
      Torch::ValueTensorType outType;
      Value radicand;
      const bool bindFailed = binder.tensorOperand(radicand) ||
                              binder.tensorResultType(outType);
      if (bindFailed)
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenSqrtOp>(binder.op, outType,
                                                     radicand);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Sub", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Sub(A, B) is lowered to aten.sub.Tensor(A, B, alpha = 1).
      Value lhs;
      Value rhs;
      Torch::ValueTensorType outType;
      if (binder.tensorOperands(lhs, rhs))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();
      auto loc = binder.getLoc();
      Value alpha = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
      rewriter.replaceOpWithNewOp<Torch::AtenSubTensorOp>(binder.op, outType,
                                                          lhs, rhs, alpha);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Sum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Sum: variadic elementwise addition with broadcasting, lowered
      // to a left fold of aten.add.Tensor (alpha = 1).
      if (binder.op->getNumOperands() == 1) {
        // A single input is forwarded unchanged.
        Torch::ValueTensorType resultType;
        Value x;
        if (binder.tensorOperand(x) || binder.tensorResultType(resultType))
          return failure();
        rewriter.replaceOp(binder.op, x);
        return success();
      }
      Torch::ValueTensorType resultType;
      SmallVector<Value> valList;
      int64_t numOperands = binder.op->getNumOperands();
      // Guard the valList[0] access below.
      if (numOperands == 0)
        return rewriter.notifyMatchFailure(binder.op,
                                           "expects at least one operand");
      if (binder.tensorOperands(valList, numOperands) ||
          binder.tensorResultType(resultType))
        return failure();
      Value const1 = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      // Left fold: ((v0 + v1) + v2) + ... . Only the last add is known to
      // produce the ONNX result type. Every intermediate partial sum is typed
      // with the broadcast shape of its own two operands, because it can be
      // smaller than the final result (e.g. Sum(a[3], b[3], c[2,3]) has an
      // intermediate of shape [3]). With two operands this is a single add
      // typed as the result.
      Value curr = valList[0];
      for (int64_t i = 1; i < numOperands; ++i) {
        if (i == numOperands - 1) {
          curr = rewriter.create<Torch::AtenAddTensorOp>(
              binder.getLoc(), resultType, curr, valList[i], const1);
          continue;
        }
        SmallVector<int64_t> resultBroadcastShapeInt;
        SmallVector<Value> resultBroadcastShapeValue;
        Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
                                     valList[i], resultBroadcastShapeInt,
                                     resultBroadcastShapeValue);
        auto baseType = Torch::ValueTensorType::get(
            binder.op->getContext(), resultBroadcastShapeInt,
            resultType.getOptionalDtype());
        curr = rewriter.create<Torch::AtenAddTensorOp>(
            binder.getLoc(), baseType, curr, valList[i], const1);
      }
      rewriter.replaceOp(binder.op, curr);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Where", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Where(condition, X, Y) is lowered onto aten.where.self.
      SmallVector<Value> inputs;
      Torch::ValueTensorType outType;
      const int64_t inputCount = binder.op->getNumOperands();
      if (binder.tensorOperands(inputs, inputCount))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();
      Value cond = inputs[0];
      Value onTrue = inputs[1];
      Value onFalse = inputs[2];
      rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
          binder.op, outType, cond, onTrue, onFalse);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Xor", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Xor is lowered one-to-one onto aten.logical_xor.
      Value lhs;
      Value rhs;
      Torch::ValueTensorType outType;
      if (binder.tensorOperands(lhs, rhs))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();
      rewriter.replaceOpWithNewOp<Torch::AtenLogicalXorOp>(binder.op, outType,
                                                           lhs, rhs);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX Squeeze(data, axes). A rank-0 axes operand lowers to
      // aten.squeeze. Otherwise each axis is extracted, normalized against
      // the input rank, and passed to prims.squeeze as a list.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      if (binder.tensorOperands(data, axes) ||
          binder.tensorResultType(resultType))
        return failure();
      Torch::BaseTensorType axesType =
          axes.getType().cast<Torch::BaseTensorType>();
      auto axesTensorTy = dyn_cast<Torch::ValueTensorType>(axes.getType());
      if (!axesTensorTy || !axesTensorTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known axes rank");
      auto sizes = axesTensorTy.getSizes();
      if (sizes.size() == 0) {
        rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op,
                                                          resultType, data);
        return success();
      }
      // The axes are extracted one at a time below, so their count must be
      // static; a dynamic count would silently produce an empty dim list.
      if (sizes[0] == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "requires a statically sized axes tensor");
      auto dataTy = cast<Torch::ValueTensorType>(data.getType());
      if (!dataTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known input rank");

      SmallVector<Value> dimList;
      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = axesType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      // Negative squeeze axes count from the back of the input.
      int64_t adjustmentInt = dataTy.getSizes().size();
      Value adjustment = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                  adjustmentInt));
      for (int64_t i = 0; i < sizes[0]; i++) {
        // Go through the axes list and get each dim in the list
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
        // deal with neg axis: if (axis < 0) axis += rank
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, adjustment);
        Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
            binder.getLoc(), dim, finalOffset);
        dimList.push_back(finalDim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      rewriter.replaceOpWithNewOp<Torch::PrimsSqueezeOp>(
          binder.op, resultType, data, dimValueList);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Unsqueeze", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Unlike squeeze where we are able to lower to Torch::PrimsSqueezeOp,
      // pytorch does not support torch.unsqueeze to insert multiple new dims.
      // discussion can be found here:
      // https://github.com/pytorch/pytorch/issues/9410
      // So, for now, we unroll into multiple unsqueezes.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      if (binder.tensorOperands(data, axes) ||
          binder.tensorResultType(resultType))
        return failure();
      Torch::BaseTensorType axesType =
          axes.getType().cast<Torch::BaseTensorType>();
      auto axesTensorTy = dyn_cast<Torch::ValueTensorType>(axes.getType());
      if (!axesTensorTy || !axesTensorTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known axes rank");
      auto sizes = axesTensorTy.getSizes();
      if (sizes.size() == 0) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }
      // The axes are unrolled one at a time, so their count must be static.
      // Previously a dynamic or zero count left `result` null.
      int64_t numAxes = sizes[0];
      if (numAxes == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "requires a statically sized axes tensor");
      if (numAxes == 0) {
        // Nothing to insert.
        rewriter.replaceOp(binder.op, data);
        return success();
      }
      auto dataTy = cast<Torch::ValueTensorType>(data.getType());
      if (!dataTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "requires known input rank");

      SmallVector<Value> dimList;
      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = axesType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      // ONNX resolves negative Unsqueeze axes against the rank of the
      // *output* (input rank + number of inserted dims). For example,
      // data [3, 4] with axes [-1] must give [3, 4, 1], so -1 maps to 2,
      // not to 1.
      int64_t adjustmentInt =
          static_cast<int64_t>(dataTy.getSizes().size()) + numAxes;
      Value adjustment = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                  adjustmentInt));
      for (int64_t i = 0; i < numAxes; i++) {
        // Go through the axes list and get each dim in the list
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
        // deal with neg axis: if (axis < 0) axis += output rank
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, adjustment);
        Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
            binder.getLoc(), dim, finalOffset);
        dimList.push_back(finalDim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      Value cstFalse =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value updatedAxes = rewriter.create<Torch::AtenTensorOp>(
          binder.getLoc(),
          axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()),
          dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse);
      // Sort the normalized dims so each unsqueeze inserts at a position that
      // already exists. For example, with data.sizes = [2, 3, 4] and
      // dims = [4, 0], inserting at 4 first would be out of range because
      // data still has rank 3.
      auto sortIndicesType = axesType.getWithSizesAndDtype(
          axesType.getOptionalSizes(),
          IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed));
      auto sortOpResult = rewriter.create<Torch::AtenSortOp>(
          binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero,
          cstFalse);
      Value result;
      auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
          binder.op->getContext());
      // Go through the updated, sorted axes. Do unsqueeze for each dim.
      for (int64_t i = 0; i < numAxes; i++) {
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, sortOpResult->getResult(0),
            zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
        if (numAxes == 1) {
          result = rewriter.create<Torch::AtenUnsqueezeOp>(
              binder.getLoc(), resultType, data, dim);
        } else if (i == 0) {
          result = rewriter.create<Torch::AtenUnsqueezeOp>(
              binder.getLoc(), baseType, data, dim);
        } else if (i == numAxes - 1) {
          result = rewriter.create<Torch::AtenUnsqueezeOp>(
              binder.getLoc(), resultType, result, dim);
        } else {
          result = rewriter.create<Torch::AtenUnsqueezeOp>(
              binder.getLoc(), baseType, result, dim);
        }
      }
      rewriter.replaceOp(binder.op, result);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Softmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX Softmax (opset 13+, default axis -1) onto
      // aten.softmax.int with no dtype override.
      Torch::ValueTensorType resultType;
      Value input;
      int64_t axis;
      if (binder.tensorOperand(input) ||
          binder.s64IntegerAttr(axis, "axis", -1) ||
          binder.tensorResultType(resultType))
        return failure();

      // ONNX allows negative axis; normalizing it requires the input rank.
      if (axis < 0) {
        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
        if (!inputTy.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "negative axis requires known input rank");
        axis += inputTy.getSizes().size();
      }

      Value constAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));

      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

      rewriter.replaceOpWithNewOp<Torch::AtenSoftmaxIntOp>(
          binder.op, resultType, input, constAxis, /*dtype=*/noneVal);
      return success();
    });
|
2023-12-15 00:53:47 +08:00
|
|
|
|
|
|
|
patterns.onOp(
    "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Selu computes
      //   y = gamma * x                       for x > 0
      //   y = gamma * alpha * (exp(x) - 1)    for x <= 0
      // which is torch.aten.elu with alpha = alpha, scale = gamma and
      // input_scale = 1.
      Torch::ValueTensorType resultType;
      float alpha, gamma;
      Value operand;
      // Both attributes are optional in the ONNX spec; when absent they take
      // the standard SELU constants.
      if (binder.tensorOperand(operand) ||
          binder.f32FloatAttr(alpha, "alpha", 1.67326319217681884765625f) ||
          binder.f32FloatAttr(gamma, "gamma", 1.05070102214813232421875f) ||
          binder.tensorResultType(resultType))
        return failure();

      Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), alpha));

      Value vScale = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

      Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

      rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
          binder.op, resultType, operand, vAlpha, vScale, vInputScale);
      return success();
    });
|
2023-12-19 04:37:31 +08:00
|
|
|
patterns.onOp(
    "ReduceSum", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX ReduceSum (opset >= 13, where `axes` is an optional second
      // input) onto torch.aten.sum.dim_IntList.
      Torch::ValueTensorType resultType;
      Value data;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      // `axes` may be omitted. When present it must be a 1-D tensor of static
      // length so each entry can be extracted as a scalar; a dynamic length
      // would otherwise silently produce an empty (reduce-all) dim list.
      Value axes;
      int64_t numAxes = 0;
      if (!binder.tensorOperandAtIndex(axes, 1)) {
        auto axesTy = cast<Torch::BaseTensorType>(axes.getType());
        if (!axesTy.hasSizes() || axesTy.getSizes().size() != 1 ||
            axesTy.getSizes()[0] == Torch::kUnknownSize)
          return rewriter.notifyMatchFailure(
              binder.op, "expected axes to be a 1-D tensor of static size");
        numAxes = axesTy.getSizes()[0];
      }

      // Empty or absent axes: either forward the input unchanged or reduce
      // over every dimension, as selected by noop_with_empty_axes.
      if (numAxes == 0 && noop_with_empty_axes != 0) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value keepDimBool = rewriter.create<Torch::ConstantBoolOp>(
          binder.getLoc(), keepDims != 0);

      if (numAxes == 0) {
        rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
            binder.op, resultType, data, /*dim=*/noneVal,
            /*keepdim=*/keepDimBool, /*dtype=*/noneVal);
        return success();
      }

      // numAxes > 0 implies the axes operand is present.
      auto axesType = cast<Torch::BaseTensorType>(axes.getType());
      SmallVector<int64_t> selectSizes{1};
      Type selectResultType = axesType.getWithSizesAndDtype(
          selectSizes, axesType.getOptionalDtype());

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
      // Rank queried at runtime so that inputs of unknown rank also lower.
      Value rank =
          rewriter.create<Torch::AtenDimOp>(binder.getLoc(), torchIntTy, data);

      // Convert the axes tensor into a torch int list, normalizing negative
      // entries: if (axis < 0) axis += rank.
      SmallVector<Value> dimList;
      for (int64_t i = 0; i < numAxes; ++i) {
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                       torchIntTy, extract);
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rank);
        Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
            binder.getLoc(), dim, finalOffset);
        dimList.push_back(finalDim);
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(), Torch::ListType::get(torchIntTy), dimList);
      rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
          binder.op, resultType, data, dimValueList, keepDimBool,
          /*dtype=*/noneVal);
      return success();
    });
|
|
|
|
patterns.onOp(
    "ReduceMean", 1,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX ReduceMean onto torch.aten.mean.dim. Reduction axes may
      // come from the optional second input and/or the `axes` attribute.
      Torch::ValueTensorType resultType;
      Value data;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
      SmallVector<Value> axesList;

      // Axes given as an input tensor: it must be 1-D with a static length so
      // that every entry can be extracted as a scalar.
      Value axesVal;
      if (!binder.tensorOperandAtIndex(axesVal, 1)) {
        auto axesType = cast<Torch::BaseTensorType>(axesVal.getType());
        if (!axesType.hasSizes() || axesType.getSizes().size() != 1 ||
            axesType.getSizes()[0] == Torch::kUnknownSize)
          return failure();

        SmallVector<int64_t> selectSizes{1};
        auto selType = rewriter.getType<Torch::ValueTensorType>(
            selectSizes, axesType.getOptionalDtype());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
        int64_t numAxes = axesType.getSizes()[0];
        for (int64_t i = 0; i < numAxes; ++i) {
          Value iv = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(), selType, axesVal, zero, iv);
          Value dim = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                         torchIntTy, extract);
          axesList.push_back(dim);
        }
      }

      // Axes given as an attribute.
      SmallVector<int64_t> axesInts;
      if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) {
        for (int64_t axis : axesInts) {
          Value iv = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(axis));
          axesList.push_back(iv);
        }
      }

      // No axes and noop requested: forward the input unchanged.
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // Normalize negative axes (axis += rank when axis < 0). The rank is
      // queried at runtime so that inputs of unknown rank do not assert.
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
      Value rank =
          rewriter.create<Torch::AtenDimOp>(binder.getLoc(), torchIntTy, data);
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rank);
        axis = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axis,
                                                    finalOffset);
      }

      // An empty list here means "reduce every dim" for aten.mean.dim.
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
      Value keepDimBool = rewriter.create<Torch::ConstantBoolOp>(
          binder.getLoc(), keepDims != 0);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
          binder.op, resultType, data, dimValueList, keepDimBool,
          /*dtype=*/noneVal);
      return success();
    });
|
2024-03-07 08:48:21 +08:00
|
|
|
patterns.onOp(
    "ReduceMax", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX ReduceMax onto torch.aten.amax, which accepts a list of
      // dims. Axes come from the `axes` attribute and/or the optional second
      // input.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      auto dataTy = cast<Torch::BaseTensorType>(data.getType());
      if (!dataTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected input of known rank");
      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

      // If any input dim is 0, every output element reduces over an empty set
      // (or the output is itself empty). Materialize the result directly,
      // filled with the identity of max: -inf for floats and the smallest
      // representable value for integers.
      if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
          (llvm::any_of(dataTy.getSizes(),
                        [](int64_t d) { return d == Torch::kUnknownSize; }) ||
           keepDims)) {
        if (!resultType.hasSizes())
          return rewriter.notifyMatchFailure(binder.op,
                                             "expected result of known rank");
        auto dty = dataTy.getDtype();
        Value scalar;
        if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
          auto negInf =
              APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
          scalar = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getFloatAttr(rewriter.getF64Type(),
                                    negInf.convertToDouble()));
        }

        if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
          unsigned bitWidth = intTy.getIntOrFloatBitWidth();
          auto mn = intTy.isSigned() ? APInt::getSignedMinValue(bitWidth)
                                     : APInt::getMinValue(bitWidth);
          scalar = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy,
              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                      mn.getSExtValue()));
        }

        if (!scalar)
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported element type for empty reduction");

        // Static result dims become constants; dynamic ones are read from the
        // input at the same position.
        llvm::SmallVector<Value> fillDims;
        for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
          auto staticDim = resultType.getSizes()[i];
          if (staticDim != Torch::kUnknownSize) {
            fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), torchIntTy,
                rewriter.getI64IntegerAttr(staticDim)));
            continue;
          }

          Value iv = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
          fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
              binder.getLoc(), torchIntTy, data, iv));
        }

        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
        rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
            binder.op, resultType, fillDimsList, scalar, none, none, none,
            none);
        return success();
      }

      // Validate the optional axes operand before emitting any IR: it must be
      // a 1-D tensor of static length, otherwise its entries cannot be
      // extracted and the lowering would silently reduce everything.
      bool hasAxesOperand = !binder.tensorOperandAtIndex(axes, 1);
      int64_t numAxesOperand = 0;
      if (hasAxesOperand) {
        auto axesTy = cast<Torch::BaseTensorType>(axes.getType());
        if (!axesTy.hasSizes() || axesTy.getSizes().size() != 1 ||
            axesTy.getSizes()[0] == Torch::kUnknownSize)
          return rewriter.notifyMatchFailure(
              binder.op, "expected axes to be a 1-D tensor of static size");
        numAxesOperand = axesTy.getSizes()[0];
      }

      // Earlier versions of the operation carried the axes as an attribute.
      SmallVector<Value> axesList;
      llvm::SmallVector<int64_t> axesAttr;
      if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
        for (int64_t axis : axesAttr) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(axis)));
        }
      }

      // Extract the axes values from the axes operand.
      if (hasAxesOperand) {
        auto axesType = cast<Torch::BaseTensorType>(axes.getType());
        SmallVector<int64_t> selectSizes{1};
        Type selectResultType = axesType.getWithSizesAndDtype(
            selectSizes, axesType.getOptionalDtype());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
        for (int64_t i = 0; i < numAxesOperand; ++i) {
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(), selectResultType, axes, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                         torchIntTy, extract);
          axesList.push_back(dim);
        }
      }

      // Handle the noop case.
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // No axes and not a noop: reduce over every dimension.
      if (axesList.empty()) {
        int64_t numDims = dataTy.getSizes().size();
        for (int64_t i = 0; i < numDims; ++i) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)));
        }
      }

      // Normalize negative axes: axis += rank when axis < 0.
      Value rankVal =
          rewriter.create<Torch::AtenDimOp>(binder.getLoc(), torchIntTy, data);
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rankVal);
        axis = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axis,
                                                    finalOffset);
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
      Value keepDimBool = rewriter.create<Torch::ConstantBoolOp>(
          binder.getLoc(), keepDims != 0);
      rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
          binder.op, resultType, data, dimValueList, keepDimBool);
      return success();
    });
|
|
|
|
|
2023-12-19 04:37:31 +08:00
|
|
|
patterns.onOp(
    "ReduceMin", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers ONNX ReduceMin onto torch.aten.amin, which accepts a list of
      // dims. Axes come from the `axes` attribute and/or the optional second
      // input.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      auto dataTy = cast<Torch::BaseTensorType>(data.getType());
      if (!dataTy.hasSizes())
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected input of known rank");
      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

      // If any input dim is 0, every output element reduces over an empty set
      // (or the output is itself empty). Materialize the result directly,
      // filled with the identity of min: +inf for floats and the largest
      // representable value for integers.
      if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
          (llvm::any_of(dataTy.getSizes(),
                        [](int64_t d) { return d == Torch::kUnknownSize; }) ||
           keepDims)) {
        if (!resultType.hasSizes())
          return rewriter.notifyMatchFailure(binder.op,
                                             "expected result of known rank");
        auto dty = dataTy.getDtype();
        Value scalar;
        if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
          auto inf = APFloat::getInf(fpTy.getFloatSemantics());
          scalar = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
              rewriter.getFloatAttr(rewriter.getF64Type(),
                                    inf.convertToDouble()));
        }

        if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
          unsigned bitWidth = intTy.getIntOrFloatBitWidth();
          // Unsigned/signless maxima are all-ones bit patterns and must be
          // zero-extended; sign-extending them would yield -1 (e.g. ui8).
          int64_t mx =
              intTy.isSigned()
                  ? APInt::getSignedMaxValue(bitWidth).getSExtValue()
                  : static_cast<int64_t>(
                        APInt::getMaxValue(bitWidth).getZExtValue());
          scalar = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy,
              rewriter.getIntegerAttr(rewriter.getIntegerType(64), mx));
        }

        if (!scalar)
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported element type for empty reduction");

        // Static result dims become constants; dynamic ones are read from the
        // input at the same position.
        llvm::SmallVector<Value> fillDims;
        for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
          auto staticDim = resultType.getSizes()[i];
          if (staticDim != Torch::kUnknownSize) {
            fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), torchIntTy,
                rewriter.getI64IntegerAttr(staticDim)));
            continue;
          }

          Value iv = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
          fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
              binder.getLoc(), torchIntTy, data, iv));
        }

        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
        rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
            binder.op, resultType, fillDimsList, scalar, none, none, none,
            none);
        return success();
      }

      // Validate the optional axes operand before emitting any IR: it must be
      // a 1-D tensor of static length, otherwise its entries cannot be
      // extracted and the lowering would silently reduce everything.
      bool hasAxesOperand = !binder.tensorOperandAtIndex(axes, 1);
      int64_t numAxesOperand = 0;
      if (hasAxesOperand) {
        auto axesTy = cast<Torch::BaseTensorType>(axes.getType());
        if (!axesTy.hasSizes() || axesTy.getSizes().size() != 1 ||
            axesTy.getSizes()[0] == Torch::kUnknownSize)
          return rewriter.notifyMatchFailure(
              binder.op, "expected axes to be a 1-D tensor of static size");
        numAxesOperand = axesTy.getSizes()[0];
      }

      // Earlier versions of the operation carried the axes as an attribute.
      SmallVector<Value> axesList;
      llvm::SmallVector<int64_t> axesAttr;
      if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
        for (int64_t axis : axesAttr) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(axis)));
        }
      }

      // Extract the axes values from the axes operand.
      if (hasAxesOperand) {
        auto axesType = cast<Torch::BaseTensorType>(axes.getType());
        SmallVector<int64_t> selectSizes{1};
        Type selectResultType = axesType.getWithSizesAndDtype(
            selectSizes, axesType.getOptionalDtype());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
        for (int64_t i = 0; i < numAxesOperand; ++i) {
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(), selectResultType, axes, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                         torchIntTy, extract);
          axesList.push_back(dim);
        }
      }

      // Handle the noop case.
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // No axes and not a noop: reduce over every dimension.
      if (axesList.empty()) {
        int64_t numDims = dataTy.getSizes().size();
        for (int64_t i = 0; i < numDims; ++i) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)));
        }
      }

      // Normalize negative axes: axis += rank when axis < 0.
      Value rankVal =
          rewriter.create<Torch::AtenDimOp>(binder.getLoc(), torchIntTy, data);
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(0));
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rankVal);
        axis = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axis,
                                                    finalOffset);
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
      Value keepDimBool = rewriter.create<Torch::ConstantBoolOp>(
          binder.getLoc(), keepDims != 0);
      rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
          binder.op, resultType, data, dimValueList, keepDimBool);
      return success();
    });
|
2023-12-16 03:37:49 +08:00
|
|
|
|
|
|
|
patterns.onOp(
    "Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // The dims of the operand, returned as a tensor, map one-to-one onto
      // torch.aten._shape_as_tensor.
      // NOTE(review): ONNX Shape-15 adds `start`/`end` attributes, which are
      // not read here -- confirm they never reach this pattern.
      Value input;
      Torch::ValueTensorType shapeTy;
      if (failed(binder.tensorOperand(input)) ||
          failed(binder.tensorResultType(shapeTy)))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(binder.op,
                                                               shapeTy, input);
      return success();
    });
|
2023-12-16 13:23:51 +08:00
|
|
|
|
|
|
|
patterns.onOp(
    "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Elementwise hyperbolic sine, lowered directly to torch.aten.sinh.
      Value x;
      Torch::ValueTensorType outTy;
      if (failed(binder.tensorOperand(x)) ||
          failed(binder.tensorResultType(outTy)))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(binder.op, outTy, x);
      return success();
    });
|
2023-12-21 02:09:39 +08:00
|
|
|
|
2023-12-28 09:53:07 +08:00
|
|
|
// split with fixed-size parts
|
|
|
|
// Arguments:
|
|
|
|
// - input: the tensor to split
|
|
|
|
// Attributes:
|
|
|
|
// - axis: the axis along which to split the input
|
|
|
|
// - num_outputs: the number of outputs to produce
|
|
|
|
// Outputs:
|
|
|
|
// - outputs: the produced outputs. Variadic with num_outputs elements.
|
|
|
|
// Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
|
|
|
|
// tensors
|
|
|
|
// so we need to unpack the list
|
|
|
|
patterns.onOp(
|
|
|
|
"Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Value self;
|
|
|
|
int64_t axis;
|
|
|
|
int64_t num_outputs;
|
|
|
|
if (binder.tensorOperand(self))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Not converting to AtenSplitTensorOp due to input "
|
|
|
|
"tensor mismatch");
|
|
|
|
if (binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
return rewriter.notifyMatchFailure(binder.op,
|
|
|
|
"Failed to get axis attribute");
|
|
|
|
if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Failed to get num_outputs attribute");
|
|
|
|
|
|
|
|
auto result0Ty =
|
|
|
|
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto selfTy = self.getType().cast<Torch::ValueTensorType>();
|
|
|
|
|
|
|
|
int64_t dim = axis;
|
|
|
|
if (dim < 0)
|
|
|
|
dim += selfTy.getSizes().size();
|
|
|
|
|
|
|
|
// set intermediate shape to the shape of the first result
|
|
|
|
// if the results are of different shapes
|
|
|
|
// set the splitted axis to variable shape
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
|
|
|
for (auto result : binder.op->getResultTypes()) {
|
|
|
|
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
|
|
|
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
|
|
|
}
|
|
|
|
|
|
|
|
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
|
|
|
|
|
|
|
|
Value splitSize = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), num_outputs));
|
|
|
|
|
|
|
|
// TODO: Attempting to use the shape expected by the ONNX mlir as ground
|
|
|
|
// truth. For now just use dynamic shapes.
|
|
|
|
auto resultOuterType =
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
/*std::optional<llvm::ArrayRef<int64_t>>=*/intermediateShape,
|
|
|
|
result0Ty.getOptionalDtype()));
|
|
|
|
Torch::AtenSplitTensorOp new_op =
|
|
|
|
rewriter.create<Torch::AtenSplitTensorOp>(
|
|
|
|
binder.getLoc(), resultOuterType, self, splitSize, dimValue);
|
|
|
|
|
|
|
|
// the onnx op is variadic with multiple results, but AtenSplitWithSizes
|
|
|
|
// outputs a list so we need to unpack the list
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListUnpackOp>(
|
|
|
|
binder.op, binder.op->getResults().getType(), new_op.getResult());
|
|
|
|
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
|
|
|
|
// split with variable parts
|
|
|
|
// Arguments:
|
|
|
|
// - input: the tensor to split
|
|
|
|
// - split: the sizes of the splits to be produced
|
|
|
|
// Attributes:
|
|
|
|
// - axis: the axis along which to split the input
|
|
|
|
// - num_outputs: the number of outputs to produce
|
|
|
|
// Outputs:
|
|
|
|
// - outputs: the produced outputs. Variadic with num_outputs elements.
|
|
|
|
// Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
|
|
|
|
// tensors
|
|
|
|
// so we need to unpack the list
|
|
|
|
patterns.onOp(
|
|
|
|
"Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Value self;
|
|
|
|
Value split;
|
|
|
|
int64_t axis;
|
|
|
|
int64_t num_outputs;
|
|
|
|
if (binder.tensorOperandAtIndex(self, 0) ||
|
|
|
|
binder.tensorOperandAtIndex(split, 1))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Not converting to AtenSplitWithSizesOp due to input "
|
|
|
|
"tensor mismatch");
|
|
|
|
if (binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
return rewriter.notifyMatchFailure(binder.op,
|
|
|
|
"Failed to get axis attribute");
|
|
|
|
if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Failed to get num_outputs attribute");
|
|
|
|
|
|
|
|
auto result0Ty =
|
|
|
|
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto selfTy =
|
|
|
|
cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
|
|
|
|
|
|
|
|
int64_t dim = axis;
|
|
|
|
if (dim < 0)
|
|
|
|
dim += selfTy.getSizes().size();
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
|
|
|
for (auto result : binder.op->getResultTypes()) {
|
|
|
|
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
|
|
|
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
|
|
|
}
|
|
|
|
|
|
|
|
Torch::PrimTolistOp splitToList = rewriter.create<Torch::PrimTolistOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::IntType>()), split);
|
|
|
|
|
|
|
|
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
|
|
|
|
|
|
|
|
// TODO: Attempting to use the shape expected by the ONNX mlir as ground
|
|
|
|
// truth. For now just use dynamic shapes.
|
|
|
|
auto resultOuterType =
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
/*std::optional<llvm::ArrayRef<int64_t>>=*/intermediateShape,
|
|
|
|
result0Ty.getOptionalDtype()));
|
|
|
|
Torch::AtenSplitWithSizesOp new_op =
|
|
|
|
rewriter.create<Torch::AtenSplitWithSizesOp>(
|
|
|
|
binder.getLoc(), resultOuterType, self,
|
|
|
|
splitToList.getResult(0), dimValue);
|
|
|
|
|
|
|
|
// the onnx op is variadic with multiple results, but AtenSplitWithSizes
|
|
|
|
// outputs a list so we need to unpack the list
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListUnpackOp>(
|
|
|
|
binder.op, binder.op->getResults().getType(), new_op.getResult());
|
|
|
|
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
|
2023-12-21 02:09:39 +08:00
|
|
|
patterns.onOp(
    "Tan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Elementwise tangent, lowered directly to torch.aten.tan.
      Value x;
      Torch::ValueTensorType outTy;
      if (failed(binder.tensorOperand(x)) ||
          failed(binder.tensorResultType(outTy)))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenTanOp>(binder.op, outTy, x);
      return success();
    });
|
|
|
|
|
2023-12-16 07:30:05 +08:00
|
|
|
// Transpose (opset 13): permute the operand's dimensions according to the
// `perm` attribute (default: reverse all dimensions). The permutation is
// realized as a sequence of pairwise torch.aten.transpose.int ops.
patterns.onOp(
    "Transpose", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      auto loc = binder.getLoc();
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType))
        return failure();
      auto operandType = operand.getType().cast<Torch::ValueTensorType>();
      TensorType tensorType = operandType.toBuiltinTensor();
      if (!tensorType || !tensorType.hasRank())
        return failure();

      // Default permutation is to reverse orders:
      int64_t rank = tensorType.getRank();
      llvm::SmallVector<int64_t> reverse(rank);
      for (int64_t i = 0; i < rank; ++i) {
        reverse[i] = rank - i - 1;
      }

      llvm::SmallVector<int64_t> permutations;
      if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse)))
        return rewriter.notifyMatchFailure(binder.op,
                                           "Failed to obtain permutations");

      if (static_cast<int64_t>(permutations.size()) != rank)
        return rewriter.notifyMatchFailure(
            binder.op, "Permutation length does not match operand rank");

      // Normalize negative entries, then make sure `perm` really is a
      // permutation of [0, rank). Out-of-range or repeated entries would
      // otherwise index `shape` out of bounds below and leave the transpose
      // search loop without a valid target position.
      llvm::SmallVector<bool> seen(rank, false);
      for (int64_t &dim : permutations) {
        dim = dim < 0 ? dim + rank : dim;
        if (dim < 0 || dim >= rank || seen[dim])
          return rewriter.notifyMatchFailure(
              binder.op, "Expected perm to be a permutation of the operand "
                         "dimensions");
        seen[dim] = true;
      }

      llvm::SmallVector<int64_t> shape(tensorType.getShape());
      // current[i] is the original dimension currently sitting at position i.
      llvm::SmallVector<int64_t> current(rank);
      for (int64_t i = 0; i < rank; ++i) {
        current[i] = i;
      }

      // We need to override to the destination if known: result dimension i
      // is operand dimension permutations[i].
      if (resultType.hasSizes()) {
        ArrayRef<int64_t> resultSizes = resultType.getSizes();
        if (static_cast<int64_t>(resultSizes.size()) != rank)
          return rewriter.notifyMatchFailure(
              binder.op, "Result rank does not match operand rank");
        for (int64_t i = 0; i < rank; ++i) {
          shape[permutations[i]] = resultSizes[i];
        }
      }

      // Convert dynamic shape dimension:
      for (int64_t &dim : shape) {
        if (dim == ShapedType::kDynamic)
          dim = Torch::kUnknownSize;
      }

      // Bring the required dimension into each position in turn by swapping
      // it with whatever currently occupies that position.
      for (int64_t i = 0; i < rank; ++i) {
        if (current[i] == permutations[i])
          continue;

        // `perm` was validated above, so the search always finds a target.
        int64_t target = i + 1;
        for (; target < rank; ++target) {
          if (current[target] == permutations[i])
            break;
        }

        std::swap(shape[i], shape[target]);
        std::swap(current[i], current[target]);

        Value dim0 = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));

        Value dim1 = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), target));

        operand = rewriter.create<Torch::AtenTransposeIntOp>(
            loc,
            Torch::ValueTensorType::get(tensorType.getContext(), shape,
                                        operandType.getOptionalDtype()),
            operand, dim0, dim1);
      }

      rewriter.replaceOp(binder.op, operand);
      return success();
    });
|
2024-01-04 11:41:10 +08:00
|
|
|
patterns.onOp(
|
|
|
|
"Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Torch::ValueTensorType resultTorchType;
|
|
|
|
Value operand, starts, ends;
|
|
|
|
// Handle if axes are not provided
|
2023-12-27 02:20:13 +08:00
|
|
|
|
2024-01-04 11:41:10 +08:00
|
|
|
if (binder.tensorOperandAtIndex(operand, 0) ||
|
|
|
|
binder.tensorOperandAtIndex(starts, 1) ||
|
|
|
|
binder.tensorOperandAtIndex(ends, 2) ||
|
|
|
|
binder.tensorResultType(resultTorchType)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
|
|
|
|
auto context = rewriter.getContext();
|
|
|
|
auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto operandTy =
|
|
|
|
operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
|
|
|
|
|
|
|
if (!operandTy)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op,
|
|
|
|
"Expected tensor operator argument to be a ranked tensor type");
|
|
|
|
|
|
|
|
auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto startsTy =
|
|
|
|
startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
|
|
|
int startSize = startsTy.getDimSize(0);
|
|
|
|
|
|
|
|
auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto endsTy =
|
|
|
|
endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
|
|
|
int endSize = endsTy.getDimSize(0);
|
|
|
|
auto resultTy =
|
|
|
|
resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
|
|
|
if (!resultTy)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Expected result type to be a ranked tensor type");
|
|
|
|
|
|
|
|
Location loc = binder.getLoc();
|
|
|
|
|
|
|
|
// Binding `axes` from its arguments or through a default value
|
|
|
|
Value axes;
|
|
|
|
if (binder.getNumOperands() >= 4) {
|
|
|
|
if (binder.tensorOperandAtIndex(axes, 3)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Binding `steps` from its arguments or through a default value
|
|
|
|
Value steps;
|
|
|
|
if (binder.getNumOperands() >= 5) {
|
|
|
|
if (binder.tensorOperandAtIndex(steps, 4)) {
|
|
|
|
return failure();
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// The default `steps` value is a 1d tensor filled with ones with a
|
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the
ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:
> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to
> pad at both the beginning as well as the end of the axis (see
> https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
> the size of this attribute should be twice the rank of the input tensor.
> However, our TorchOnnxToTorch bails out early since it incorrectly
> compares the pads attribute with the rank (not twice the rank) of the
> input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.
> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see
> https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the
existing
> logic for lowering the operator into the Torch dialect assumes that it
> is mandatory.
>
> This patch makes the attribute optional and constructs a default value
> (a list of zeros the size of the input tensor) if the attribute was not
> specified.
> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX Spec for the Slice operator allows the `starts` and `ends`
> inputs to have fewer indices that the dimensions of the `data` tensor
> (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
> expects these inputs to be as many as the `data` tensor's dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs
> are only as long as the `axes` input, but since the `axes` input is
> optional, the default type for the `axes` input has to match the type
> for the `starts` and `ends` inputs. Moreover, the number of indices in
> the `steps` input also has to match those in the `axes` inputs (instad
> of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that
> they match the ONNX spec.
2024-02-08 13:19:27 +08:00
|
|
|
// size equal to the size of `starts` and `ends`.
|
2024-01-04 11:41:10 +08:00
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the
ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:
> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to
> pad at both the beginning as well as the end of the axis (see
> https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
> the size of this attribute should be twice the rank of the input tensor.
> However, our TorchOnnxToTorch bails out early since it incorrectly
> compares the pads attribute with the rank (not twice the rank) of the
> input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.
> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see
> https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the
existing
> logic for lowering the operator into the Torch dialect assumes that it
> is mandatory.
>
> This patch makes the attribute optional and constructs a default value
> (a list of zeros the size of the input tensor) if the attribute was not
> specified.
> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX Spec for the Slice operator allows the `starts` and `ends`
> inputs to have fewer indices that the dimensions of the `data` tensor
> (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
> expects these inputs to be as many as the `data` tensor's dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs
> are only as long as the `axes` input, but since the `axes` input is
> optional, the default type for the `axes` input has to match the type
> for the `starts` and `ends` inputs. Moreover, the number of indices in
> the `steps` input also has to match those in the `axes` inputs (instad
> of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that
> they match the ONNX spec.
2024-02-08 13:19:27 +08:00
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
|
2024-01-04 11:41:10 +08:00
|
|
|
Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
loc,
|
|
|
|
Torch::ListType::get(
|
|
|
|
Torch::IntType::get(binder.op->getContext())),
|
|
|
|
sizeStepInput);
|
|
|
|
steps = rewriter.create<Torch::AtenOnesOp>(
|
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the
ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:
> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to
> pad at both the beginning as well as the end of the axis (see
> https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so
> the size of this attribute should be twice the rank of the input tensor.
> However, our TorchOnnxToTorch bails out early since it incorrectly
> compares the pads attribute with the rank (not twice the rank) of the
> input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.
> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see
> https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the
existing
> logic for lowering the operator into the Torch dialect assumes that it
> is mandatory.
>
> This patch makes the attribute optional and constructs a default value
> (a list of zeros the size of the input tensor) if the attribute was not
> specified.
> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX Spec for the Slice operator allows the `starts` and `ends`
> inputs to have fewer indices that the dimensions of the `data` tensor
> (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code
> expects these inputs to be as many as the `data` tensor's dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs
> are only as long as the `axes` input, but since the `axes` input is
> optional, the default type for the `axes` input has to match the type
> for the `starts` and `ends` inputs. Moreover, the number of indices in
> the `steps` input also has to match those in the `axes` inputs (instad
> of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that
> they match the ONNX spec.
2024-02-08 13:19:27 +08:00
|
|
|
loc, startsTorchTy, sizeStepsInput, none, none, none, none);
|
2024-01-04 11:41:10 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
|
|
|
|
startSize == endSize))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Expected the rank of starts and ends tensors to be 1 "
|
|
|
|
"and their dimensions to match");
|
|
|
|
|
2024-02-20 02:26:29 +08:00
|
|
|
if (axes) {
|
|
|
|
auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto axesTy =
|
|
|
|
axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
|
|
|
int64_t numAxes = axesTy.getDimSize(0);
|
2024-01-04 11:41:10 +08:00
|
|
|
|
2024-02-20 02:26:29 +08:00
|
|
|
if (!(axesTy && numAxes == endSize))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Axes should be the same size of starts and ends");
|
|
|
|
}
|
2024-01-04 11:41:10 +08:00
|
|
|
|
|
|
|
auto stepsTy = steps.getType()
|
|
|
|
.cast<Torch::ValueTensorType>()
|
|
|
|
.toBuiltinTensor()
|
|
|
|
.dyn_cast<RankedTensorType>();
|
|
|
|
|
|
|
|
if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Steps should be the same size of starts and ends");
|
|
|
|
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
|
|
|
|
|
|
|
auto select = [&](Value v, Value k) -> Value {
|
|
|
|
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
|
|
|
loc,
|
|
|
|
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
|
|
|
|
ty.getOptionalDtype()),
|
|
|
|
v, zero, k);
|
|
|
|
Value item = rewriter.create<Torch::AtenItemOp>(
|
|
|
|
loc, rewriter.getType<Torch::IntType>(), sel);
|
|
|
|
return item;
|
|
|
|
};
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
|
|
|
|
for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
|
2024-02-17 05:35:25 +08:00
|
|
|
if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
|
2024-01-04 11:41:10 +08:00
|
|
|
intermediateShape[i] = -1;
|
2024-02-17 05:35:25 +08:00
|
|
|
if (intermediateShape[i] == ShapedType::kDynamic)
|
|
|
|
intermediateShape[i] = Torch::kUnknownSize;
|
2024-01-04 11:41:10 +08:00
|
|
|
}
|
|
|
|
auto intermediateType = Torch::ValueTensorType::get(
|
|
|
|
context, intermediateShape, resultTorchType.getOptionalDtype());
|
2024-02-20 02:26:29 +08:00
|
|
|
for (int i = 0; i < endSize; ++i) {
|
2024-01-04 11:41:10 +08:00
|
|
|
|
|
|
|
Value k = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
loc, rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
|
|
|
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
|
|
|
loc,
|
|
|
|
Torch::ValueTensorType::get(
|
|
|
|
context, ArrayRef<int64_t>{1},
|
|
|
|
rewriter.getIntegerType(64, /*signed*/ 1)),
|
|
|
|
k);
|
|
|
|
|
|
|
|
Value start = select(starts, kTensor);
|
|
|
|
Value end = select(ends, kTensor);
|
2024-02-20 02:26:29 +08:00
|
|
|
Value axis = axes ? select(axes, kTensor) : k;
|
2024-01-04 11:41:10 +08:00
|
|
|
Value step = select(steps, kTensor);
|
|
|
|
|
|
|
|
auto sliceType = intermediateType;
|
2024-02-20 02:26:29 +08:00
|
|
|
sliceType = i == (endSize - 1) ? resultTorchType : sliceType;
|
2024-01-04 11:41:10 +08:00
|
|
|
operand = rewriter.create<Torch::AtenSliceTensorOp>(
|
|
|
|
loc, sliceType, operand, axis, start, end, step);
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOp(binder.op, operand);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-27 02:20:13 +08:00
|
|
|
// Reshape (opset 5+): lower onnx.Reshape(data, shape) with the `allowzero`
// attribute onto torch.aten.reshape. A fully static result type supplies the
// target shape directly; otherwise the target is read element-wise out of
// the `shape` tensor at runtime.
patterns.onOp(
    "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value data;
      Value shape;
      int64_t allowzero;
      if (binder.tensorOperands(data, shape) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(allowzero, "allowzero", 0))
        return failure();

      Location loc = binder.getLoc();
      Torch::IntType intTy = rewriter.getType<Torch::IntType>();
      auto intListTy =
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext()));

      // If the result shape is static then we can create a result shape list
      // directly using the result shape values (integers).
      if (resultType.hasSizes() && resultType.areAllSizesKnown()) {
        SmallVector<Value> resultShape;
        for (int64_t dim : resultType.getSizes()) {
          resultShape.push_back(rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(dim)));
        }
        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
            loc, intListTy, resultShape);
        rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
            binder.op, resultType, data, resultShapeList);
        return success();
      }

      // `shape` must be a 1-D tensor of known length; otherwise we cannot
      // tell how many target dimensions to extract (an unknown length used to
      // silently produce an empty target shape).
      Torch::BaseTensorType shapeType =
          shape.getType().cast<Torch::BaseTensorType>();
      if (!shapeType.hasSizes() || shapeType.getSizes().size() != 1 ||
          shapeType.getSizes()[0] == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op,
            "Expected shape to be a 1-D tensor with a statically known length");
      int64_t numTargetDims = shapeType.getSizes()[0];

      // With allowzero == 0, a 0 in `shape` copies the matching input
      // dimension, so the input rank must be known in that mode.
      auto dataType = data.getType().cast<Torch::BaseTensorType>();
      if (allowzero == 0 && !dataType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "Expected data to have a known rank when allowzero=0");
      ArrayRef<int64_t> dataSizes;
      if (allowzero == 0)
        dataSizes = dataType.getSizes();
      int64_t inputDimsSize = dataSizes.size();

      SmallVector<int64_t> selectSizes{1};
      Type selectResultType = shapeType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, intTy, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

      // Convert shape (tensor) into a torch int list.
      SmallVector<Value> dimList;
      for (int64_t i = 0; i < numTargetDims; ++i) {
        // Go through the shape list and get each dim in the list.
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, intTy,
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, shape, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(loc, intTy, extract);

        if (allowzero == 0) {
          // Deal with zero values: replace them with the original input dim,
          // branch-free as dim + (dim == 0) * inputDim(i).
          Value isZero = rewriter.create<Torch::AtenEqIntOp>(loc, dim, zero);
          isZero = rewriter.create<Torch::AtenIntBoolOp>(loc, isZero);

          // A 0 cannot appear at an index past the input rank, so no
          // adjustment is needed there.
          Value adjustment = zero;
          if (i < inputDimsSize) {
            if (dataSizes[i] == Torch::kUnknownSize) {
              // Dynamic input dimension: query it at runtime instead of
              // baking in the -1 "unknown" sentinel.
              adjustment = rewriter.create<Torch::AtenSizeIntOp>(
                  loc, intTy, data, selectIndex);
            } else {
              adjustment = rewriter.create<Torch::ConstantIntOp>(
                  loc, intTy,
                  rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                          dataSizes[i]));
            }
          }
          Value finalOffset =
              rewriter.create<Torch::AtenMulIntOp>(loc, isZero, adjustment);
          dim = rewriter.create<Torch::AtenAddIntOp>(loc, dim, finalOffset);
        }
        dimList.push_back(dim);
      }

      Value dimValueList =
          rewriter.create<Torch::PrimListConstructOp>(loc, intListTy, dimList);
      rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(binder.op, resultType,
                                                        data, dimValueList);
      return success();
    });
|
2024-03-05 03:07:03 +08:00
|
|
|
// ReduceProd (opset 13+): product over a set of axes. Axes come from the
// `axes` attribute (older opsets) and/or the optional second operand. Each
// axis is reduced with keepdim=true, then the result is reshaped to the
// declared result shape when keepdims == 0.
patterns.onOp(
    "ReduceProd", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ReduceProd allows us to pass a list of dims but AtenProdDimIn only
      // allow one dim as input.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperandAtIndex(data, 0) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      auto dataTy = cast<Torch::BaseTensorType>(data.getType());
      Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();

      if (!resultType.hasSizes() || !resultType.areAllSizesKnown() ||
          !dataTy.areAllSizesKnown())
        return rewriter.notifyMatchFailure(
            binder.op,
            "Expected the input and result type to have known sizes");

      int64_t rank = dataTy.getSizes().size();
      SmallVector<Value> axesList;
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(0));

      // Previous version of the operation had the axes as an attribute:
      llvm::SmallVector<int64_t> axesAttr;
      if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
        for (int64_t axisAttr : axesAttr) {
          axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), torchIntTy,
              rewriter.getI64IntegerAttr(axisAttr)));
        }
      }

      // Handle cases that axes are explicitly specified.
      // Extract the axes values from the axes operand.
      // This really shouldn't happen but it helps pass weird tests.
      // TODO: Derive the chosen axes from the data type and final result type
      // instead of using the dynamic axes at operand[1].
      if (!binder.tensorOperandAtIndex(axes, 1)) {
        Torch::BaseTensorType axesType =
            axes.getType().cast<Torch::BaseTensorType>();
        // The axes count must be statically known to unpack them; an
        // unknown length would otherwise be read as -1 (no axes at all).
        if (!axesType.hasSizes() || axesType.getSizes().size() != 1 ||
            axesType.getSizes()[0] == Torch::kUnknownSize)
          return rewriter.notifyMatchFailure(
              binder.op,
              "Expected axes to be a 1-D tensor with a statically known length");
        int64_t numAxesOperand = axesType.getSizes()[0];
        for (int64_t i = 0; i < numAxesOperand; i++) {
          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(i));
          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
              binder.getLoc(),
              axesType.getWithSizesAndDtype(llvm::SmallVector<int64_t>{1},
                                            axesType.getOptionalDtype()),
              axes, zero, selectIndex);
          Value dim = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                         torchIntTy, extract);
          axesList.push_back(dim);
        }
      }

      // Handle the noop case:
      // When axes is empty and noop_with_empty_axes is set to true, input
      // tensor will not be reduced, and the output tensor would be
      // equivalent to input tensor.
      if (axesList.empty() && noop_with_empty_axes) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }

      // Handle case when no axes arg is passed but not a noop:
      // Manually set positive axis to all dims.
      if (axesList.empty()) {
        for (int64_t i = 0; i < rank; i++) {
          Value dimValue = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(i));
          axesList.push_back(dimValue);
        }
      }

      // Handle negative axis: axis += (axis < 0) * rank, computed at runtime.
      Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
                                                        torchIntTy, data);
      for (Value &axis : axesList) {
        Value isNegative =
            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axis, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                           isNegative);
        Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
            binder.getLoc(), isNegative, rankVal);
        axis = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axis,
                                                    finalOffset);
      }

      // Handle multiple axes case:
      // ReduceProd on each dim, always set keepDimsBool == True to avoid
      // segfault.
      Value trueVal =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      SmallVector<int64_t> intermediateShape(rank, Torch::kUnknownSize);
      Value dataReduceProd = data;
      for (int i = 0, numAxes = axesList.size(); i < numAxes; i++) {
        auto axis = axesList[i];
        if (keepDims && i == numAxes - 1) {
          // The final keepdim reduction directly produces the result type.
          dataReduceProd = rewriter.create<Torch::AtenProdDimIntOp>(
              binder.getLoc(),
              dataTy.getWithSizesAndDtype(resultType.getSizes(),
                                          dataTy.getOptionalDtype()),
              dataReduceProd, axis, trueVal, noneVal);
          rewriter.replaceOp(binder.op, dataReduceProd);
          return success();
        }
        Type resultTyReduceProd = dataTy.getWithSizesAndDtype(
            ArrayRef(intermediateShape), dataTy.getOptionalDtype());
        dataReduceProd = rewriter.create<Torch::AtenProdDimIntOp>(
            binder.getLoc(), resultTyReduceProd, dataReduceProd, axis,
            trueVal, noneVal);
      }

      // Reached when keepdims == 0 (the keepdims == 1 case returned inside
      // the loop) or when there was nothing to reduce (rank-0 input). Every
      // reduction above kept its axis as size 1, so reshape to the declared
      // (fully static, checked above) result shape to drop those axes. The
      // previous "2 point" derivation indexed the result sizes out of bounds
      // and built a target list with the input rank instead of the result
      // rank.
      SmallVector<Value> dataReduceProdShape;
      for (int64_t dim : resultType.getSizes())
        dataReduceProdShape.push_back(rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(dim)));
      Value dataReduceProdShapeList =
          rewriter.create<Torch::PrimListConstructOp>(
              binder.getLoc(),
              rewriter.getType<Torch::ListType>(
                  rewriter.getType<Torch::IntType>()),
              dataReduceProdShape);
      rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
          binder.op, resultType, dataReduceProd, dataReduceProdShapeList);
      return success();
    });
|
2024-01-16 03:26:46 +08:00
|
|
|
// Range (opset 11): onnx.Range(start, limit, delta) -> torch.aten.arange with
// start/end/step; scalar operands are unpacked to !torch.float or !torch.int.
patterns.onOp(
    "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX.Range(start, limit, delta) -- limit is exclusive

      Torch::ValueTensorType resultType;
      Value start, limit, delta;
      auto loc = binder.getLoc();
      Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
      if (binder.tensorOperandAtIndex(start, 0) ||
          binder.tensorOperandAtIndex(limit, 1) ||
          binder.tensorOperandAtIndex(delta, 2) ||
          binder.tensorResultType(resultType))
        return failure();

      // Convert a 0-dimensional/Scalar Tensor ([]) to Scalar Torch Numeric
      // Value torch.tensor(1.1) equivalent in ONNX to 1.1 as an example
      // type of start, limit, delta can be one of: double, float, int16,
      // int32, int64 Assuming start, limit and delta to be same type (could
      // they be different?)
      Torch::BaseTensorType startTensorType =
          start.getType().cast<Torch::BaseTensorType>();
      // getDtype() requires a dtype; bail out cleanly on dtype-less types.
      if (!startTensorType.hasDtype())
        return rewriter.notifyMatchFailure(
            binder.op, "Expected start to have a known dtype");
      Type startDtype = startTensorType.getDtype();
      bool isFloatDType = startDtype.isF64() || startDtype.isF32();
      bool isIntDType = startDtype.isInteger(16) ||
                        startDtype.isInteger(32) ||
                        startDtype.isInteger(64);
      if (!isFloatDType && !isIntDType) {
        return rewriter.notifyMatchFailure(
            binder.op, "Expected the start, limit, delta to be one of "
                       "double, float, int16, int32, int64");
      }
      Value scalarStart, scalarLimit, scalarDelta;
      if (isFloatDType) {
        scalarStart = getItemOp<Torch::FloatType>(binder, rewriter, start);
        scalarLimit = getItemOp<Torch::FloatType>(binder, rewriter, limit);
        scalarDelta = getItemOp<Torch::FloatType>(binder, rewriter, delta);
      } else {
        scalarStart = getItemOp<Torch::IntType>(binder, rewriter, start);
        scalarLimit = getItemOp<Torch::IntType>(binder, rewriter, limit);
        scalarDelta = getItemOp<Torch::IntType>(binder, rewriter, delta);
      }
      rewriter.replaceOpWithNewOp<Torch::AtenArangeStartStepOp>(
          binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none,
          none, none, none);
      return success();
    });
|
2024-03-07 09:01:05 +08:00
|
|
|
// Size (opset 1+): total element count of the input as a 0-d integer tensor.
// The count is the runtime product of aten.size.int over every dimension; a
// rank-0 input has exactly one element.
patterns.onOp(
    "Size", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      Torch::ValueTensorType outType;
      if (binder.tensorOperand(input) || binder.tensorResultType(outType))
        return failure();

      Location loc = binder.getLoc();
      Operation *sizeOp = binder.op;
      auto inputTy = cast<Torch::BaseTensorType>(input.getType());
      if (!inputTy.hasSizes())
        return rewriter.notifyMatchFailure(sizeOp, "input rank unknown");

      // Query every dimension extent at runtime.
      const int64_t numDims = inputTy.getSizes().size();
      llvm::SmallVector<Value> extents;
      extents.reserve(numDims);
      for (int64_t d = 0; d < numDims; ++d) {
        Value dimIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(d));
        extents.push_back(rewriter.create<Torch::AtenSizeIntOp>(
            loc, rewriter.getType<Torch::IntType>(), input, dimIndex));
      }

      Value falseVal = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);

      Value numElements;
      if (extents.empty()) {
        numElements = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(1));
      } else {
        numElements = extents.front();
        for (Value extent : ArrayRef<Value>(extents).drop_front())
          numElements =
              rewriter.create<Torch::AtenMulIntOp>(loc, numElements, extent);
      }

      rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
          sizeOp, outType, numElements, noneVal, noneVal, falseVal);
      return success();
    });
|
2024-01-25 01:26:21 +08:00
|
|
|
// Tile (opset 6+): onnx.Tile(input, repeats) -> torch.aten.tile, with the
// 1-D `repeats` tensor unpacked into a !torch.list<int>.
patterns.onOp(
    "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value operand;
      Value repeatDims;
      if (binder.tensorOperands(operand, repeatDims) ||
          binder.tensorResultType(resultType))
        return failure();

      // convert repeatDims tensor to list of ints
      Torch::BaseTensorType shapeType =
          repeatDims.getType().cast<Torch::BaseTensorType>();
      // `repeats` must be 1-D with a known length; an unknown length would
      // otherwise be read as -1 and yield an empty (wrong) repeats list.
      if (!shapeType.hasSizes() || shapeType.getSizes().size() != 1 ||
          shapeType.getSizes()[0] == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "Expected repeats to be a 1-D tensor with a statically "
                       "known length");
      int64_t numRepeats = shapeType.getSizes()[0];

      SmallVector<Value> dimList;
      SmallVector<int64_t> selectSizes{1};
      Type selectResultType = shapeType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      for (int64_t i = 0; i < numRepeats; i++) {
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, repeatDims, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
        dimList.push_back(dim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);

      rewriter.replaceOpWithNewOp<Torch::AtenTileOp>(binder.op, resultType,
                                                     operand, dimValueList);
      return success();
    });
|
2024-01-23 04:56:39 +08:00
|
|
|
patterns.onOp(
    "TopK", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX TopK(X, K) -> aten.topk(X, k, dim, largest, sorted), which
      // yields the (values, indices) result pair.
      // NOTE: the ONNX operator is spelled "TopK"; the previous "Topk" key
      // could never match an imported onnx.TopK op.
      Torch::ValueTensorType valuesType, indicesType;
      Value input, kTensor;
      int64_t axis;
      bool largest, sorted;
      if (binder.tensorOperandAtIndex(input, 0) ||
          binder.tensorOperandAtIndex(kTensor, 1) ||
          binder.s64IntegerAttr(axis, "axis", -1) ||
          binder.s64BoolAttr(largest, "largest", true) ||
          binder.s64BoolAttr(sorted, "sorted", true) ||
          binder.tensorResultTypeAtIndex(valuesType, 0) ||
          binder.tensorResultTypeAtIndex(indicesType, 1))
        return failure();

      // A known rank is required to normalize a negative axis.
      std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
      if (!maybeRank)
        return rewriter.notifyMatchFailure(binder.op,
                                           "Unimplemented: unranked tensor");
      unsigned rank = *maybeRank;
      axis = Torch::toPositiveDim(axis, rank);

      Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getI64IntegerAttr(axis));
      Value cstLargest =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), largest);
      Value cstSorted =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), sorted);
      // ONNX supplies K as a tensor operand (per the spec, a single-element
      // int64 tensor), but aten.topk takes `k` as a !torch.int scalar, so
      // extract it.
      Value kScalar = rewriter.create<Torch::AtenItemOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(), kTensor);

      rewriter.replaceOpWithNewOp<Torch::AtenTopkOp>(
          binder.op, valuesType, indicesType, input, kScalar, cstAxis,
          cstLargest, cstSorted);
      return success();
    });
|
2024-01-25 01:26:21 +08:00
|
|
|
patterns.onOp(
    "Sign", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Elementwise Sign lowers one-to-one onto aten.sign, using the result
      // type bound from the ONNX op.
      Value x;
      Torch::ValueTensorType signType;
      const bool bindFailed =
          binder.tensorOperand(x) || binder.tensorResultType(signType);
      if (bindFailed)
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenSignOp>(binder.op, signType, x);
      return success();
    });
|
2024-03-25 22:59:07 +08:00
|
|
|
patterns.onOp(
    "Softplus", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ONNX Softplus: y = ln(exp(x) + 1).
      Torch::ValueTensorType resultType;
      Value input;
      if (binder.tensorOperand(input) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }

      // Lower to aten.softplus(x, beta=1, threshold=20) instead of the
      // literal log1p(exp(x)): exp(x) overflows to +inf for large x (around
      // 88 in f32), which makes the result inf where the true value is ~x.
      // aten.softplus returns x itself once x * beta > threshold; with the
      // PyTorch default threshold of 20 the deviation from ln(exp(x) + 1) is
      // below exp(-20) (~2e-9).
      Value beta = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
      Value threshold = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 20.0));
      rewriter.replaceOpWithNewOp<Torch::AtenSoftplusOp>(
          binder.op, resultType, input, beta, threshold);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Trilu keeps the upper triangle (when `upper` is non-zero, via
      // aten.triu) or the lower triangle (when `upper` is 0, via aten.tril),
      // offset by an optional diagonal `k` taken from operand #1.
      Value self;
      int64_t upper;
      Torch::ValueTensorType resultType;
      if (binder.tensorOperandAtIndex(self, 0) ||
          binder.s64IntegerAttr(upper, "upper", 1) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }

      // Resolve the diagonal offset as a !torch.int: read it out of the `k`
      // tensor when one is bound at index 1, otherwise default to 0.
      Value kOperand;
      Value diagonal;
      if (!binder.tensorOperandAtIndex(kOperand, 1)) {
        diagonal = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), kOperand);
      } else {
        diagonal = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
      }

      if (!upper) {
        rewriter.replaceOpWithNewOp<Torch::AtenTrilOp>(binder.op, resultType,
                                                       self, diagonal);
        return success();
      }
      rewriter.replaceOpWithNewOp<Torch::AtenTriuOp>(binder.op, resultType,
                                                     self, diagonal);
      return success();
    });
|
|
|
|
patterns.onOp(
    "ThresholdedRelu", 10,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ThresholdedRelu: y = x if x > alpha else 0, which is
      // aten.threshold(x, threshold=alpha, value=0).
      Torch::ValueTensorType resultType;
      Value input;
      float alpha;
      // The result type must be bound too: previously it was never read, so
      // `resultType` stayed null and the replacement op was built without a
      // valid result type.
      if (binder.tensorOperand(input) ||
          binder.f32FloatAttr(alpha, "alpha", 1.0) ||
          binder.tensorResultType(resultType)) {
        return failure();
      }
      Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
      // Replacement value for elements that do not exceed alpha.
      Value value = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
      rewriter.replaceOpWithNewOp<Torch::AtenThresholdOp>(
          binder.op, resultType, input, cstAlpha, value);
      return success();
    });
|
2023-12-15 00:53:47 +08:00
|
|
|
}
|