2023-11-22 13:02:55 +08:00
|
|
|
//===------------------------------------------------------------*- C++ -*-===//
|
|
|
|
//
|
|
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
|
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-16 01:36:18 +08:00
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
2023-11-22 13:02:55 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::onnx_c;
|
|
|
|
|
|
|
|
// Simple rewrites for the default domain.
|
|
|
|
// See: https://onnx.ai/onnx/operators/
|
|
|
|
// For operators that are effectively version invariant, we register with
|
|
|
|
// sinceVersion==1. We interpret this to include the following spec
|
|
|
|
// diffs that are irrelevant to this level of lowering:
|
|
|
|
// * Supported element types.
|
|
|
|
// * The move from limited broadcasting to full broadcasting support.
|
|
|
|
//
|
|
|
|
// Many spec revisions generalized elementwise ops so that they translate
|
|
|
|
// directly to the corresponding Torch op instead of needing special cases.
|
|
|
|
// As a result, many ONNX test cases reduce to exactly the same lowering
|
|
|
|
// here, so we register each such op only once.
|
|
|
|
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
2023-12-15 00:53:47 +08:00
|
|
|
OnnxCustomOpConversionPattern &patterns) {
|
[MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601)
This commit adds the OnnxToTorch support for Reciprocal, Round,
ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor,
Squeeze, Unsqueeze ops.
For reviewers, the ops that weren't trivial and probably require extra
review are Sum, Squeeze, and Unsqueeze.
2023-12-16 01:36:18 +08:00
|
|
|
patterns.onOp(
    "Reciprocal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Bind the single tensor operand and the converted result type, then
      // replace the ONNX op with aten.reciprocal.
      Value self;
      Torch::ValueTensorType outType;
      const bool bindFailed =
          binder.tensorOperand(self) || binder.tensorResultType(outType);
      if (bindFailed)
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenReciprocalOp>(binder.op, outType,
                                                           self);
      return success();
    });
|
|
|
|
patterns.onOp("Relu", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // One tensor in, one tensor out: lowered 1:1 to aten.relu.
                Torch::ValueTensorType outType;
                Value input;
                if (binder.tensorOperand(input))
                  return failure();
                if (binder.tensorResultType(outType))
                  return failure();

                rewriter.replaceOpWithNewOp<Torch::AtenReluOp>(
                    binder.op, outType, input);
                return success();
              });
|
|
|
|
patterns.onOp(
    "Round", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Registered from opset 11; forwards its only operand to aten.round.
      Operation *onnxOp = binder.op;
      Value input;
      Torch::ValueTensorType outType;
      if (binder.tensorOperand(input) || binder.tensorResultType(outType))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenRoundOp>(onnxOp, outType, input);
      return success();
    });
|
|
|
|
patterns.onOp(
    "ScatterElements", 18,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ScatterElements(data, indices, updates) lowers to aten.scatter.src
      // when `reduction` is "none", and to aten.scatter.reduce for the
      // reductions that have an aten spelling ("add" and "mul" -> "multiply").
      Torch::ValueTensorType resultType;
      SmallVector<Value> valList;
      int64_t axis;
      std::string reduction;
      int64_t numOperands = binder.op->getNumOperands();
      // valList[0..2] are indexed unconditionally below, so require exactly
      // three operands up front instead of reading past the end of valList.
      if (numOperands != 3)
        return rewriter.notifyMatchFailure(
            binder.op, "expected exactly three operands for scatter elements");
      if (binder.tensorOperands(valList, numOperands) ||
          binder.s64IntegerAttr(axis, "axis", 0) ||
          binder.customOpNameStringAttr(reduction, "reduction", "none") ||
          binder.tensorResultType(resultType))
        return failure();

      Value data = valList[0];
      Value indices = valList[1];
      Value updates = valList[2];

      // Validate the reduction before emitting any IR.
      // TODO: Implement max and min cases
      if (reduction == "mul") {
        reduction = "multiply";
      } else if (reduction == "max" || reduction == "min") {
        return rewriter.notifyMatchFailure(
            binder.op, "max/min reduction unsupported for scatter elements");
      } else if (reduction != "none" && reduction != "add") {
        // Any other string would otherwise be forwarded verbatim as the
        // aten.scatter.reduce `reduce` argument.
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported reduction for scatter elements");
      }

      // ONNX allows negative axis. Normalizing it needs the rank of `data`,
      // which is only available when the tensor type carries sizes.
      if (axis < 0) {
        auto dataType = data.getType().cast<Torch::ValueTensorType>();
        if (!dataType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "negative axis requires data with a known rank");
        axis += dataType.getSizes().size();
      }

      Value constAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));

      if (reduction == "none") {
        rewriter.replaceOpWithNewOp<Torch::AtenScatterSrcOp>(
            binder.op, resultType, data, constAxis, indices, updates);
        return success();
      }

      Value cstStrReduction =
          rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
      rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
          binder.op, resultType, data, constAxis, indices, updates,
          cstStrReduction);
      return success();
    });
|
|
|
|
patterns.onOp("Sigmoid", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Unary op: the single operand maps directly to aten.sigmoid.
                Value input;
                Torch::ValueTensorType outType;
                const bool bound =
                    succeeded(binder.tensorOperand(input)) &&
                    succeeded(binder.tensorResultType(outType));
                if (!bound)
                  return failure();

                rewriter.replaceOpWithNewOp<Torch::AtenSigmoidOp>(
                    binder.op, outType, input);
                return success();
              });
|
|
|
|
patterns.onOp(
    "Sin", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Registered from opset 7; lowers the operand to aten.sin.
      Torch::ValueTensorType outType;
      Value angle;
      if (binder.tensorOperand(angle))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenSinOp>(binder.op, outType, angle);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Tanh", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Unary op lowered 1:1 to aten.tanh.
      Value input;
      Torch::ValueTensorType outType;
      const bool bindFailed =
          binder.tensorOperand(input) || binder.tensorResultType(outType);
      if (bindFailed)
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenTanhOp>(binder.op, outType,
                                                     input);
      return success();
    });
|
|
|
|
patterns.onOp("Sqrt", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Unary op lowered 1:1 to aten.sqrt.
                Operation *onnxOp = binder.op;
                Torch::ValueTensorType outType;
                Value radicand;
                if (binder.tensorOperand(radicand) ||
                    binder.tensorResultType(outType))
                  return failure();

                rewriter.replaceOpWithNewOp<Torch::AtenSqrtOp>(
                    onnxOp, outType, radicand);
                return success();
              });
|
|
|
|
patterns.onOp(
    "Sub", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // A - B lowers to aten.sub.Tensor(A, B, alpha = 1).
      Value lhs;
      Value rhs;
      Torch::ValueTensorType outType;
      if (binder.tensorOperands(lhs, rhs))
        return failure();
      if (binder.tensorResultType(outType))
        return failure();

      Location loc = binder.getLoc();
      Value alpha = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
      rewriter.replaceOpWithNewOp<Torch::AtenSubTensorOp>(binder.op, outType,
                                                          lhs, rhs, alpha);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Sum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Variadic elementwise Sum. A single input is forwarded unchanged;
      // otherwise the inputs are folded left to right with aten.add.Tensor
      // (alpha = 1).
      int64_t numOperands = binder.op->getNumOperands();
      // Guard against indexing valList[0] on an operand-less op.
      if (numOperands == 0)
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected at least one operand");

      Torch::ValueTensorType resultType;
      if (numOperands == 1) {
        Value x;
        if (binder.tensorOperand(x) || binder.tensorResultType(resultType))
          return failure();
        rewriter.replaceOp(binder.op, x);
        return success();
      }

      SmallVector<Value> valList;
      if (binder.tensorOperands(valList, numOperands) ||
          binder.tensorResultType(resultType))
        return failure();

      Value const1 = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      // Only the last add produces the op's result, so only it is given the
      // converted result type. Intermediate partial sums may not yet have the
      // final (broadcast) shape, so they carry no static shape information.
      auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
          binder.op->getContext());
      Value curr = valList[0];
      for (int64_t i = 1; i < numOperands; i++) {
        Torch::ValueTensorType stepType =
            (i == numOperands - 1) ? resultType : baseType;
        curr = rewriter.create<Torch::AtenAddTensorOp>(
            binder.getLoc(), stepType, curr, valList[i], const1);
      }
      rewriter.replaceOp(binder.op, curr);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Where", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Forwards (condition, X, Y), in that order, to aten.where.self.
      Torch::ValueTensorType outType;
      SmallVector<Value> operands;
      const int64_t operandCount = binder.op->getNumOperands();
      if (binder.tensorOperands(operands, operandCount) ||
          binder.tensorResultType(outType))
        return failure();

      Value cond = operands[0];
      Value onTrue = operands[1];
      Value onFalse = operands[2];
      rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
          binder.op, outType, cond, onTrue, onFalse);
      return success();
    });
|
|
|
|
patterns.onOp("Xor", 1,
              [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                // Binary op: both operands are forwarded, in order, to
                // aten.logical_xor.
                Value lhs;
                Value rhs;
                Torch::ValueTensorType outType;
                const bool bindFailed = binder.tensorOperands(lhs, rhs) ||
                                        binder.tensorResultType(outType);
                if (bindFailed)
                  return failure();

                rewriter.replaceOpWithNewOp<Torch::AtenLogicalXorOp>(
                    binder.op, outType, lhs, rhs);
                return success();
              });
|
|
|
|
patterns.onOp(
    "Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Squeeze-13 takes `axes` as an optional second tensor operand.
      // - No axes operand: every size-1 dimension is dropped (aten.squeeze).
      // - Otherwise: each listed axis is extracted as a scalar, a negative
      //   value is normalized by adding rank(data), and the resulting list
      //   is passed to prims.squeeze.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      if (binder.op->getNumOperands() == 1) {
        if (binder.tensorOperand(data) || binder.tensorResultType(resultType))
          return failure();
        rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op,
                                                          resultType, data);
        return success();
      }
      if (binder.tensorOperands(data, axes) ||
          binder.tensorResultType(resultType))
        return failure();

      Torch::BaseTensorType axesType =
          axes.getType().cast<Torch::BaseTensorType>();
      Torch::BaseTensorType dataType =
          data.getType().cast<Torch::BaseTensorType>();
      // getSizes() is only valid when the type carries sizes (hasSizes()).
      if (!axesType.hasSizes() || !dataType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "expected data and axes to have known ranks");
      ArrayRef<int64_t> sizes = axesType.getSizes();
      // A rank-0 axes tensor is treated like "no axes": squeeze all unit dims.
      if (sizes.size() == 0) {
        rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op,
                                                          resultType, data);
        return success();
      }
      // The axes are unrolled into per-element extractions below, so their
      // count must be static; a dynamic count would silently produce an empty
      // dim list (a no-op squeeze).
      int64_t numAxes = sizes[0];
      if (numAxes == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "expected axes to have a static number of elements");

      Location loc = binder.getLoc();
      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = axesType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      // Squeeze axes are resolved against the rank of the input.
      int64_t adjustmentInt = dataType.getSizes().size();
      Value adjustment = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), adjustmentInt));

      SmallVector<Value> dimList;
      for (int64_t i = 0; i < numAxes; i++) {
        // Go through the axes list and get each dim in the list
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), extract);
        // deal with neg axis: if (axis < 0) axis += rank
        Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(loc, isNegative);
        Value finalOffset =
            rewriter.create<Torch::AtenMulIntOp>(loc, isNegative, adjustment);
        Value finalDim =
            rewriter.create<Torch::AtenAddIntOp>(loc, dim, finalOffset);
        dimList.push_back(finalDim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          loc,
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      rewriter.replaceOpWithNewOp<Torch::PrimsSqueezeOp>(
          binder.op, resultType, data, dimValueList);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Unsqueeze", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Unlike squeeze where we are able to lower to Torch::PrimsSqueezeOp,
      // pytorch does not support torch.unsqueeze to insert multiple new dims.
      // discussion can be found here:
      // https://github.com/pytorch/pytorch/issues/9410
      // So, for now, we unroll into multiple unsqueezes.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      if (binder.tensorOperands(data, axes) ||
          binder.tensorResultType(resultType))
        return failure();

      Torch::BaseTensorType axesType =
          axes.getType().cast<Torch::BaseTensorType>();
      Torch::BaseTensorType dataType =
          data.getType().cast<Torch::BaseTensorType>();
      // getSizes() is only valid when the type carries sizes (hasSizes()).
      if (!axesType.hasSizes() || !dataType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "expected data and axes to have known ranks");
      ArrayRef<int64_t> sizes = axesType.getSizes();
      // Nothing to insert for a rank-0 or empty axes tensor: forward the
      // input. The unrolled loop below needs at least one axis to produce a
      // result value.
      if (sizes.size() == 0 || sizes[0] == 0) {
        rewriter.replaceOp(binder.op, data);
        return success();
      }
      int64_t numAxes = sizes[0];
      if (numAxes == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "expected axes to have a static number of elements");

      Location loc = binder.getLoc();
      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = axesType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      // ONNX resolves negative Unsqueeze axes against the rank of the
      // *output*, i.e. rank(data) + len(axes), not the rank of the input.
      int64_t adjustmentInt =
          static_cast<int64_t>(dataType.getSizes().size()) + numAxes;
      Value adjustment = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), adjustmentInt));

      SmallVector<Value> dimList;
      for (int64_t i = 0; i < numAxes; i++) {
        // Go through the axes list and get each dim in the list
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), extract);
        // deal with neg axis: if (axis < 0) axis += output rank
        Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(loc, isNegative);
        Value finalOffset =
            rewriter.create<Torch::AtenMulIntOp>(loc, isNegative, adjustment);
        Value finalDim =
            rewriter.create<Torch::AtenAddIntOp>(loc, dim, finalOffset);
        dimList.push_back(finalDim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          loc,
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
      Value updatedAxes = rewriter.create<Torch::AtenTensorOp>(
          loc, axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()),
          dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse);
      // Sort the list of dims, so we don't run into this situation:
      // data.sizes = [2, 3, 4]
      // dims = [4, 0]
      // index 4 will be invalid to add a singleton dimension because
      // data.sizes.size == 3 We have to work with sorted dims to avoid this
      // situation.
      auto sortIndicesType = axesType.getWithSizesAndDtype(
          axesType.getOptionalSizes(),
          IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed));
      auto sortOpResult = rewriter.create<Torch::AtenSortOp>(
          loc, axes.getType(), sortIndicesType, updatedAxes, zero, cstFalse);

      // Go through the updated, sorted axes. Do unsqueeze for each dim. Only
      // the last unsqueeze produces the final result type; intermediate
      // results carry no static shape information.
      auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation(
          binder.op->getContext());
      Value result = data;
      for (int64_t i = 0; i < numAxes; i++) {
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, sortOpResult->getResult(0), zero,
            selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), extract);
        Torch::ValueTensorType stepType =
            (i == numAxes - 1) ? resultType : baseType;
        result =
            rewriter.create<Torch::AtenUnsqueezeOp>(loc, stepType, result, dim);
      }
      rewriter.replaceOp(binder.op, result);
      return success();
    });
|
|
|
|
patterns.onOp(
    "Softmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Softmax-13 works along a single `axis` attribute (default -1). It is
      // lowered to aten.softmax.int with no dtype override.
      Torch::ValueTensorType resultType;
      Value input;
      int64_t axis;
      if (binder.tensorOperand(input) ||
          binder.s64IntegerAttr(axis, "axis", -1) ||
          binder.tensorResultType(resultType))
        return failure();

      // ONNX allows negative axis. Normalizing it needs the input rank, which
      // is only available when the tensor type carries sizes (hasSizes()).
      if (axis < 0) {
        auto inputType = input.getType().cast<Torch::ValueTensorType>();
        if (!inputType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op, "negative axis requires an input with a known rank");
        axis += inputType.getSizes().size();
      }

      Value constAxis = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      rewriter.replaceOpWithNewOp<Torch::AtenSoftmaxIntOp>(
          binder.op, resultType, input, constAxis, /*dtype=*/noneVal);
      return success();
    });
|
2023-12-15 00:53:47 +08:00
|
|
|
|
|
|
|
patterns.onOp(
    "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowered to aten.elu(x, alpha=alpha, scale=gamma, input_scale=1.0).
      Torch::ValueTensorType resultType;
      float alpha, gamma;
      Value operand;
      // The ONNX spec gives both attributes non-zero defaults
      // (alpha ~= 1.6732632, gamma ~= 1.0507010). They must be passed here
      // explicitly; otherwise an absent attribute takes the binder's generic
      // fallback value instead of the spec default.
      if (binder.tensorOperand(operand) ||
          binder.f32FloatAttr(alpha, "alpha", 1.67326319217681884765625f) ||
          binder.f32FloatAttr(gamma, "gamma", 1.05070102214813232421875f) ||
          binder.tensorResultType(resultType))
        return failure();

      Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
      Value vScale = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
      Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
          rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
      rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
          binder.op, resultType, operand, vAlpha, vScale, vInputScale);
      return success();
    });
|
2023-12-19 04:37:31 +08:00
|
|
|
patterns.onOp(
    "ReduceSum", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // ReduceSum-13 takes the reduction axes as an optional 1-D tensor
      // operand. When no axes are supplied (operand omitted or empty), every
      // dimension is reduced, unless `noop_with_empty_axes` is non-zero; in
      // that case the input is forwarded unchanged.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.op->getNumOperands() == 1) {
        if (binder.tensorOperand(data))
          return failure();
      } else if (binder.tensorOperands(data, axes)) {
        return failure();
      }
      if (binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      Location loc = binder.getLoc();
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
      // Both paths treat any non-zero `keepdims` as true.
      Value keepDimBool =
          rewriter.create<Torch::ConstantBoolOp>(loc, keepDims != 0);

      // Lowering used when there are no axes to reduce over.
      auto lowerWithoutAxes = [&]() -> LogicalResult {
        if (noop_with_empty_axes != 0) {
          rewriter.replaceOp(binder.op, data);
          return success();
        }
        rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
            binder.op, resultType, data, /*dim=*/noneVal,
            /*keepdim=*/keepDimBool, /*dtype=*/noneVal);
        return success();
      };
      if (!axes)
        return lowerWithoutAxes();

      Torch::BaseTensorType axesType =
          axes.getType().cast<Torch::BaseTensorType>();
      Torch::BaseTensorType dataType =
          data.getType().cast<Torch::BaseTensorType>();
      // getSizes() is only valid when the type carries sizes (hasSizes()).
      if (!axesType.hasSizes() || !dataType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "expected data and axes to have known ranks");
      ArrayRef<int64_t> sizes = axesType.getSizes();
      // sizes[0] is read below, so a rank-0 axes tensor must be rejected.
      if (sizes.size() != 1)
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected axes to be a 1-D tensor");
      int64_t numAxes = sizes[0];
      if (numAxes == 0)
        return lowerWithoutAxes();
      // Each axis is extracted individually below, so the count must be
      // static; with a dynamic count the loop would emit an empty dim list
      // that does not describe the requested reduction.
      if (numAxes == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "expected axes to have a static number of elements");

      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = axesType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      int64_t adjustmentInt = dataType.getSizes().size();
      Value adjustment = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), adjustmentInt));

      // convert axes (tensor) into torch int list while dealing with neg axis
      SmallVector<Value> dimList;
      for (int64_t i = 0; i < numAxes; i++) {
        // Go through the axes list and get each dim in the list
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), extract);
        // deal with neg axis: if (axis < 0) axis += rank
        Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(loc, isNegative);
        Value finalOffset =
            rewriter.create<Torch::AtenMulIntOp>(loc, isNegative, adjustment);
        Value finalDim =
            rewriter.create<Torch::AtenAddIntOp>(loc, dim, finalOffset);
        dimList.push_back(finalDim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          loc,
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      rewriter.replaceOpWithNewOp<Torch::AtenSumDimIntListOp>(
          binder.op, resultType, data, dimValueList, keepDimBool,
          /*dtype=*/noneVal);
      return success();
    });
|
|
|
|
patterns.onOp(
    "ReduceMean", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Averages `data` over the dimensions listed in the 1-D `axes` tensor
      // operand. An empty `axes` reduces every dimension, unless
      // `noop_with_empty_axes` is non-zero; in that case the input is
      // forwarded unchanged.
      // NOTE(review): in the ONNX spec ReduceMean only takes `axes` as an
      // input from opset 18; opset 13-17 models carry an `axes` attribute and
      // will not bind here. Confirm the registered version.
      Torch::ValueTensorType resultType;
      Value data;
      Value axes;
      int64_t keepDims;
      int64_t noop_with_empty_axes;
      if (binder.tensorOperands(data, axes) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
          binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                0))
        return failure();

      Torch::BaseTensorType axesType =
          axes.getType().cast<Torch::BaseTensorType>();
      Torch::BaseTensorType dataType =
          data.getType().cast<Torch::BaseTensorType>();
      // getSizes() is only valid when the type carries sizes (hasSizes()).
      if (!axesType.hasSizes() || !dataType.hasSizes())
        return rewriter.notifyMatchFailure(
            binder.op, "expected data and axes to have known ranks");
      ArrayRef<int64_t> sizes = axesType.getSizes();
      // sizes[0] is read below, so a rank-0 axes tensor must be rejected.
      if (sizes.size() != 1)
        return rewriter.notifyMatchFailure(binder.op,
                                           "expected axes to be a 1-D tensor");

      Location loc = binder.getLoc();
      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
      // Both paths treat any non-zero `keepdims` as true.
      Value keepDimBool =
          rewriter.create<Torch::ConstantBoolOp>(loc, keepDims != 0);

      int64_t numAxes = sizes[0];
      // deal with case when axes is empty
      if (numAxes == 0) {
        if (noop_with_empty_axes == 0) {
          rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
              binder.op, resultType, data, /*dim=*/noneVal, keepDimBool,
              /*dtype=*/noneVal);
        } else {
          rewriter.replaceOp(binder.op, data);
        }
        return success();
      }
      // Each axis is extracted individually below, so the count must be
      // static; with a dynamic count the loop would emit an empty dim list
      // that does not describe the requested reduction.
      if (numAxes == Torch::kUnknownSize)
        return rewriter.notifyMatchFailure(
            binder.op, "expected axes to have a static number of elements");

      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = axesType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      int64_t adjustmentInt = dataType.getSizes().size();
      Value adjustment = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), adjustmentInt));

      // convert axes (tensor) into torch int list while dealing with neg axis
      SmallVector<Value> dimList;
      for (int64_t i = 0; i < numAxes; i++) {
        // Go through the axes list and get each dim in the list
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, axes, zero, selectIndex);
        Value dim = rewriter.create<Torch::AtenItemOp>(
            loc, rewriter.getType<Torch::IntType>(), extract);
        // deal with neg axis: if (axis < 0) axis += rank
        Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
        isNegative = rewriter.create<Torch::AtenIntBoolOp>(loc, isNegative);
        Value finalOffset =
            rewriter.create<Torch::AtenMulIntOp>(loc, isNegative, adjustment);
        Value finalDim =
            rewriter.create<Torch::AtenAddIntOp>(loc, dim, finalOffset);
        dimList.push_back(finalDim);
      }
      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          loc,
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
          binder.op, resultType, data, dimValueList, keepDimBool,
          /*dtype=*/noneVal);
      return success();
    });
|
|
|
|
patterns.onOp(
|
|
|
|
"ReduceMin", 13,
|
|
|
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
// AtenAminOp allows us to pass a list of dims
|
|
|
|
Torch::ValueTensorType resultType;
|
|
|
|
Value data;
|
|
|
|
Value axes;
|
|
|
|
int64_t keepDims;
|
|
|
|
int64_t noop_with_empty_axes;
|
|
|
|
// Deal with case when no axes arg is passed
|
|
|
|
if (binder.op->getNumOperands() == 1) {
|
|
|
|
if (binder.tensorOperand(data) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
|
|
|
binder.s64IntegerAttr(noop_with_empty_axes,
|
|
|
|
"noop_with_empty_axes", 0))
|
|
|
|
return failure();
|
|
|
|
if (noop_with_empty_axes == 0) {
|
|
|
|
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
|
|
|
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
|
|
|
binder.getLoc(), keepDimsConstInt);
|
|
|
|
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
|
|
|
|
.getSizes()
|
|
|
|
.size();
|
|
|
|
SmallVector<Value> axesList;
|
|
|
|
for (int i = 0; i < numDims; i++) {
|
|
|
|
Value curr = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
|
|
|
axesList.push_back(curr);
|
|
|
|
}
|
|
|
|
Value axesValueList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(
|
|
|
|
Torch::IntType::get(binder.op->getContext())),
|
|
|
|
axesList);
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
|
|
|
binder.op, resultType, data, axesValueList, keepDimsBool);
|
|
|
|
} else {
|
|
|
|
rewriter.replaceOp(binder.op, data);
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
if (binder.tensorOperands(data, axes) ||
|
|
|
|
binder.tensorResultType(resultType) ||
|
|
|
|
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
|
|
|
|
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
|
|
|
|
0))
|
|
|
|
return failure();
|
|
|
|
Torch::BaseTensorType axesType =
|
|
|
|
axes.getType().cast<Torch::BaseTensorType>();
|
|
|
|
SmallVector<Value> dimList;
|
|
|
|
SmallVector<int64_t> selectSizes;
|
|
|
|
selectSizes.push_back(1);
|
|
|
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
|
|
|
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
|
|
|
|
auto sizes =
|
|
|
|
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
|
|
|
|
// deal with case when axes is empty
|
|
|
|
if (sizes.size() == 1 && sizes[0] == 0) {
|
|
|
|
if (noop_with_empty_axes == 0) {
|
|
|
|
// create dims list with all dims [0, data.getSizes().size())
|
|
|
|
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
|
|
|
|
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
|
|
|
|
binder.getLoc(), keepDimsConstInt);
|
|
|
|
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
|
|
|
|
.getSizes()
|
|
|
|
.size();
|
|
|
|
for (int i = 0; i < numDims; i++) {
|
|
|
|
Value curr = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
|
|
|
dimList.push_back(curr);
|
|
|
|
}
|
|
|
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(
|
|
|
|
Torch::IntType::get(binder.op->getContext())),
|
|
|
|
dimList);
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
|
|
|
binder.op, resultType, data, dimValueList, keepDimsBool);
|
|
|
|
} else {
|
|
|
|
rewriter.replaceOp(binder.op, data);
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
|
|
|
int64_t adjustmentInt =
|
|
|
|
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
|
|
|
|
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
|
|
|
adjustmentInt));
|
|
|
|
// convert axes (tensor) into torch int list while dealing with neg axis
|
|
|
|
for (int i = 0; i < sizes[0]; i++) {
|
|
|
|
// Go through the axes list and get each dim in the list
|
|
|
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
|
|
|
|
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
|
|
|
|
binder.getLoc(), selectResultType, axes, zero, selectIndex);
|
|
|
|
Value dim = rewriter.create<Torch::AtenItemOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
|
|
|
|
// deal with neg axis: if (axis < 0) axis += rank
|
|
|
|
Value isNegative =
|
|
|
|
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
|
|
|
|
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
|
|
|
|
isNegative);
|
|
|
|
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
|
|
|
binder.getLoc(), isNegative, adjustment);
|
|
|
|
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
|
|
|
binder.getLoc(), dim, finalOffset);
|
|
|
|
dimList.push_back(finalDim);
|
|
|
|
}
|
|
|
|
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
|
|
|
dimList);
|
|
|
|
Value keepDimBool;
|
|
|
|
if (keepDims == 1) {
|
|
|
|
keepDimBool =
|
|
|
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
|
|
|
} else {
|
|
|
|
keepDimBool =
|
|
|
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
|
|
|
|
binder.op, resultType, data, dimValueList, keepDimBool);
|
|
|
|
return success();
|
|
|
|
});
|
2023-12-16 03:37:49 +08:00
|
|
|
|
|
|
|
// ONNX Shape -> torch.aten._shape_as_tensor.
// Binds a single tensor operand and a value-tensor result type; the pattern
// does not match if either cannot be bound.
// NOTE(review): any `start`/`end` attributes on the ONNX op are not read by
// this lowering.
patterns.onOp(
    "Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      Torch::ValueTensorType outTy;
      const bool bindFailed =
          binder.tensorOperand(input) || binder.tensorResultType(outTy);
      if (bindFailed)
        return failure();

      rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(binder.op,
                                                                outTy, input);
      return success();
    });
|
2023-12-16 13:23:51 +08:00
|
|
|
|
|
|
|
// ONNX Sinh -> torch.aten.sinh (elementwise hyperbolic sine).
// Matches only when one tensor operand and a value-tensor result type bind.
patterns.onOp(
    "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value x;
      Torch::ValueTensorType yType;
      if (!binder.tensorOperand(x) && !binder.tensorResultType(yType)) {
        rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(binder.op, yType, x);
        return success();
      }
      return failure();
    });
|
2023-12-21 02:09:39 +08:00
|
|
|
|
2023-12-28 09:53:07 +08:00
|
|
|
// split with fixed-size parts
|
|
|
|
// Arguments:
|
|
|
|
// - input: the tensor to split
|
|
|
|
// Attributes:
|
|
|
|
// - axis: the axis along which to split the input
|
|
|
|
// - num_outputs: the number of outputs to produce
|
|
|
|
// Outputs:
|
|
|
|
// - outputs: the produced outputs. Variadic with num_outputs elements.
|
|
|
|
// Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
|
|
|
|
// tensors
|
|
|
|
// so we need to unpack the list
|
|
|
|
patterns.onOp(
|
|
|
|
"Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Value self;
|
|
|
|
int64_t axis;
|
|
|
|
int64_t num_outputs;
|
|
|
|
if (binder.tensorOperand(self))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Not converting to AtenSplitTensorOp due to input "
|
|
|
|
"tensor mismatch");
|
|
|
|
if (binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
return rewriter.notifyMatchFailure(binder.op,
|
|
|
|
"Failed to get axis attribute");
|
|
|
|
if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Failed to get num_outputs attribute");
|
|
|
|
|
|
|
|
auto result0Ty =
|
|
|
|
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto selfTy = self.getType().cast<Torch::ValueTensorType>();
|
|
|
|
|
|
|
|
int64_t dim = axis;
|
|
|
|
if (dim < 0)
|
|
|
|
dim += selfTy.getSizes().size();
|
|
|
|
|
|
|
|
// set intermediate shape to the shape of the first result
|
|
|
|
// if the results are of different shapes
|
|
|
|
// set the splitted axis to variable shape
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
|
|
|
for (auto result : binder.op->getResultTypes()) {
|
|
|
|
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
|
|
|
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
|
|
|
}
|
|
|
|
|
|
|
|
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
|
|
|
|
|
|
|
|
Value splitSize = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), num_outputs));
|
|
|
|
|
|
|
|
// TODO: Attempting to use the shape expected by the ONNX mlir as ground
|
|
|
|
// truth. For now just use dynamic shapes.
|
|
|
|
auto resultOuterType =
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
/*std::optional<llvm::ArrayRef<int64_t>>=*/intermediateShape,
|
|
|
|
result0Ty.getOptionalDtype()));
|
|
|
|
Torch::AtenSplitTensorOp new_op =
|
|
|
|
rewriter.create<Torch::AtenSplitTensorOp>(
|
|
|
|
binder.getLoc(), resultOuterType, self, splitSize, dimValue);
|
|
|
|
|
|
|
|
// the onnx op is variadic with multiple results, but AtenSplitWithSizes
|
|
|
|
// outputs a list so we need to unpack the list
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListUnpackOp>(
|
|
|
|
binder.op, binder.op->getResults().getType(), new_op.getResult());
|
|
|
|
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
|
|
|
|
// split with variable parts
|
|
|
|
// Arguments:
|
|
|
|
// - input: the tensor to split
|
|
|
|
// - split: the sizes of the splits to be produced
|
|
|
|
// Attributes:
|
|
|
|
// - axis: the axis along which to split the input
|
|
|
|
// - num_outputs: the number of outputs to produce
|
|
|
|
// Outputs:
|
|
|
|
// - outputs: the produced outputs. Variadic with num_outputs elements.
|
|
|
|
// Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of
|
|
|
|
// tensors
|
|
|
|
// so we need to unpack the list
|
|
|
|
patterns.onOp(
|
|
|
|
"Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
|
|
|
Value self;
|
|
|
|
Value split;
|
|
|
|
int64_t axis;
|
|
|
|
int64_t num_outputs;
|
|
|
|
if (binder.tensorOperandAtIndex(self, 0) ||
|
|
|
|
binder.tensorOperandAtIndex(split, 1))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Not converting to AtenSplitWithSizesOp due to input "
|
|
|
|
"tensor mismatch");
|
|
|
|
if (binder.s64IntegerAttr(axis, "axis", 0))
|
|
|
|
return rewriter.notifyMatchFailure(binder.op,
|
|
|
|
"Failed to get axis attribute");
|
|
|
|
if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
binder.op, "Failed to get num_outputs attribute");
|
|
|
|
|
|
|
|
auto result0Ty =
|
|
|
|
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
|
|
|
auto selfTy =
|
|
|
|
cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
|
|
|
|
|
|
|
|
int64_t dim = axis;
|
|
|
|
if (dim < 0)
|
|
|
|
dim += selfTy.getSizes().size();
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> intermediateShape(result0Ty.getSizes());
|
|
|
|
for (auto result : binder.op->getResultTypes()) {
|
|
|
|
int64_t d = result.cast<Torch::ValueTensorType>().getSizes()[dim];
|
|
|
|
intermediateShape[dim] = d == intermediateShape[dim] ? d : -1;
|
|
|
|
}
|
|
|
|
|
|
|
|
Torch::PrimTolistOp splitToList = rewriter.create<Torch::PrimTolistOp>(
|
|
|
|
binder.getLoc(),
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::IntType>()), split);
|
|
|
|
|
|
|
|
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
|
|
|
|
|
|
|
|
// TODO: Attempting to use the shape expected by the ONNX mlir as ground
|
|
|
|
// truth. For now just use dynamic shapes.
|
|
|
|
auto resultOuterType =
|
|
|
|
Torch::ListType::get(rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
/*std::optional<llvm::ArrayRef<int64_t>>=*/intermediateShape,
|
|
|
|
result0Ty.getOptionalDtype()));
|
|
|
|
Torch::AtenSplitWithSizesOp new_op =
|
|
|
|
rewriter.create<Torch::AtenSplitWithSizesOp>(
|
|
|
|
binder.getLoc(), resultOuterType, self,
|
|
|
|
splitToList.getResult(0), dimValue);
|
|
|
|
|
|
|
|
// the onnx op is variadic with multiple results, but AtenSplitWithSizes
|
|
|
|
// outputs a list so we need to unpack the list
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::PrimListUnpackOp>(
|
|
|
|
binder.op, binder.op->getResults().getType(), new_op.getResult());
|
|
|
|
|
|
|
|
return success();
|
|
|
|
});
|
|
|
|
|
2023-12-21 02:09:39 +08:00
|
|
|
// ONNX Tan -> torch.aten.tan (elementwise tangent).
// The pattern fails to match unless a single tensor operand and a
// value-tensor result type can be bound.
patterns.onOp(
    "Tan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Value input;
      Torch::ValueTensorType outputType;
      const bool bound = !binder.tensorOperand(input) &&
                         !binder.tensorResultType(outputType);
      if (!bound)
        return failure();

      rewriter.replaceOpWithNewOp<Torch::AtenTanOp>(binder.op, outputType,
                                                    input);
      return success();
    });
|
|
|
|
|
2023-12-16 07:30:05 +08:00
|
|
|
// ONNX Transpose -> a sequence of torch.aten.transpose.int swaps.
// `perm` defaults to reversing the axes. The permutation is realized by
// selection: for each position i, the axis that must end up at i is swapped
// into place, and the intermediate result type is updated after every swap.
patterns.onOp(
    "Transpose", 13,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Location loc = binder.getLoc();
      Torch::ValueTensorType resultType;
      Value operand;
      if (binder.tensorOperand(operand) ||
          binder.tensorResultType(resultType))
        return failure();
      auto operandType = operand.getType().cast<Torch::ValueTensorType>();
      TensorType tensorType = operandType.toBuiltinTensor();
      if (!tensorType || !tensorType.hasRank())
        return failure();

      // Default permutation is to reverse orders:
      int64_t rank = tensorType.getRank();
      llvm::SmallVector<int64_t> reverse(rank);
      for (int64_t i = 0; i < rank; ++i)
        reverse[i] = rank - i - 1;

      llvm::SmallVector<int64_t> permutations;
      if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse)))
        return rewriter.notifyMatchFailure(binder.op,
                                           "Failed to obtain permutations");

      if (static_cast<int64_t>(permutations.size()) != rank)
        return rewriter.notifyMatchFailure(
            binder.op, "Permutation length does not match operand rank");

      // `perm` must name every axis in [0, rank) exactly once; otherwise the
      // search for a swap target below would run past the end of `current`.
      llvm::SmallVector<bool> seen(rank, false);
      for (int64_t axis : permutations) {
        if (axis < 0 || axis >= rank || seen[axis])
          return rewriter.notifyMatchFailure(
              binder.op, "Permutation must contain each axis exactly once");
        seen[axis] = true;
      }

      auto constInt = [&](int64_t value) -> Value {
        return rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), value));
      };

      // Track sizes in the torch convention (-1 == unknown) expected by
      // Torch::ValueTensorType::get below, not the builtin tensor shape.
      llvm::SmallVector<int64_t> shape(operandType.getSizes());
      // current[j] is the original axis that currently sits at position j.
      llvm::SmallVector<int64_t> current(rank);
      for (int64_t i = 0; i < rank; ++i)
        current[i] = i;

      for (int64_t i = 0; i < rank; ++i) {
        if (current[i] == permutations[i])
          continue;

        // Locate the axis wanted at position i; after validation it is
        // guaranteed to sit at some position in (i, rank).
        int64_t target = i + 1;
        while (target < rank && current[target] != permutations[i])
          ++target;

        std::swap(shape[i], shape[target]);
        std::swap(current[i], current[target]);

        Value dim0 = constInt(i);
        Value dim1 = constInt(target);

        operand = rewriter.create<Torch::AtenTransposeIntOp>(
            loc,
            Torch::ValueTensorType::get(tensorType.getContext(), shape,
                                        operandType.getOptionalDtype()),
            operand, dim0, dim1);
      }

      rewriter.replaceOp(binder.op, operand);
      return success();
    });
|
2023-12-27 02:20:13 +08:00
|
|
|
|
|
|
|
// ONNX Reshape -> torch.aten.reshape.
// The 1-D `shape` operand is unpacked element by element (select + item) into
// a !torch.list<int>. With allowzero == 0 (the default read here), a 0 entry
// in `shape` is replaced by the corresponding input dimension; with
// allowzero != 0 the entries are used unchanged.
patterns.onOp(
    "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      Value data;
      Value shape;
      int64_t allowzero;
      if (binder.tensorOperands(data, shape) ||
          binder.tensorResultType(resultType) ||
          binder.s64IntegerAttr(allowzero, "allowzero", 0))
        return failure();

      // The number of target dims decides how many list elements are emitted,
      // so `shape` must be 1-D with a statically known length.
      Torch::BaseTensorType shapeType =
          shape.getType().cast<Torch::BaseTensorType>();
      if (!shapeType.hasSizes() || shapeType.getSizes().size() != 1 ||
          shapeType.getSizes()[0] < 0)
        return rewriter.notifyMatchFailure(
            binder.op, "Expected 'shape' to be a 1-D tensor of static length");
      int64_t numTargetDims = shapeType.getSizes()[0];

      Location loc = binder.getLoc();
      Type intType = rewriter.getType<Torch::IntType>();
      auto constInt = [&](int64_t value) -> Value {
        return rewriter.create<Torch::ConstantIntOp>(
            loc, intType,
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), value));
      };

      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = shapeType.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
      Value zero = constInt(0);

      // Reads shape[i] as a !torch.int.
      auto readShapeEntry = [&](int64_t i) -> Value {
        Value selectIndex = constInt(i);
        Value extract = rewriter.create<Torch::AtenSelectIntOp>(
            loc, selectResultType, shape, zero, selectIndex);
        return rewriter.create<Torch::AtenItemOp>(loc, intType, extract);
      };

      SmallVector<Value> dimList;
      if (allowzero == 0) {
        auto dataType = data.getType().cast<Torch::ValueTensorType>();
        if (!dataType.hasSizes())
          return rewriter.notifyMatchFailure(
              binder.op,
              "Expected 'data' to have a known rank when allowzero is 0");
        llvm::ArrayRef<int64_t> dataSizes = dataType.getSizes();
        int64_t dataRank = dataSizes.size();

        for (int64_t i = 0; i < numTargetDims; ++i) {
          Value dim = readShapeEntry(i);
          // finalDim = dim + (dim == 0) * inputDim[i]: a 0 entry copies the
          // input dimension at the same index.
          Value isZero = rewriter.create<Torch::AtenEqIntOp>(loc, dim, zero);
          isZero = rewriter.create<Torch::AtenIntBoolOp>(loc, isZero);

          Value adjustment;
          if (i >= dataRank) {
            // A 0 is assumed never to appear past the input rank, so no
            // adjustment is applied at those indices.
            adjustment = zero;
          } else if (dataSizes[i] >= 0) {
            adjustment = constInt(dataSizes[i]);
          } else {
            // Dynamic input dim: query it at runtime instead of baking in the
            // -1 "unknown" placeholder.
            adjustment = rewriter.create<Torch::AtenSizeIntOp>(
                loc, intType, data, constInt(i));
          }

          Value finalOffset =
              rewriter.create<Torch::AtenMulIntOp>(loc, isZero, adjustment);
          Value finalDim =
              rewriter.create<Torch::AtenAddIntOp>(loc, dim, finalOffset);
          dimList.push_back(finalDim);
        }
      } else {
        for (int64_t i = 0; i < numTargetDims; ++i)
          dimList.push_back(readShapeEntry(i));
      }

      Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
          loc,
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          dimList);
      rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(binder.op, resultType,
                                                        data, dimValueList);
      return success();
    });
|
2023-12-15 00:53:47 +08:00
|
|
|
}
|