2021-10-16 06:23:59 +08:00
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
//
|
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
|
|
|
|
|
2022-02-09 04:57:23 +08:00
|
|
|
|
#include "mlir/IR/BuiltinDialect.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
2022-12-09 01:26:38 +08:00
|
|
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2022-04-26 20:18:09 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-05-10 21:15:59 +08:00
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
2021-10-16 06:23:59 +08:00
|
|
|
|
#include "llvm/ADT/StringExtras.h"
|
2022-12-09 01:26:38 +08:00
|
|
|
|
#include "llvm/ADT/StringSet.h"
|
2022-05-10 21:15:59 +08:00
|
|
|
|
#include <cstdint>
|
2024-03-28 03:42:10 +08:00
|
|
|
|
#include <set>
|
2021-10-16 06:23:59 +08:00
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::torch;
|
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
2022-03-11 01:25:21 +08:00
|
|
|
|
// Helper function to check whether the `dtype` is None or Float type.
|
|
|
|
|
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>())
|
|
|
|
|
return true;
|
|
|
|
|
int64_t dtypeInt;
|
|
|
|
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
|
|
|
|
return false;
|
2023-01-21 02:40:13 +08:00
|
|
|
|
FailureOr<Type> resDtype =
|
2022-03-11 01:25:21 +08:00
|
|
|
|
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
2023-01-21 02:40:13 +08:00
|
|
|
|
if (failed(resDtype))
|
|
|
|
|
return false;
|
|
|
|
|
return resDtype->isa<mlir::FloatType>();
|
2022-03-11 01:25:21 +08:00
|
|
|
|
}
|
|
|
|
|
|
2022-02-01 03:56:32 +08:00
|
|
|
|
// Helper function to compute the return type of the reduction function.
|
|
|
|
|
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
|
|
|
|
|
// the input tensor.
|
|
|
|
|
static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
|
2022-06-29 15:23:57 +08:00
|
|
|
|
BaseTensorType tensorType, Value dim,
|
|
|
|
|
bool keepDim) {
|
2021-11-08 23:56:40 +08:00
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
|
int64_t dimInt;
|
|
|
|
|
if (tensorType.hasSizes()) {
|
|
|
|
|
ArrayRef<int64_t> inputShape = tensorType.getSizes();
|
|
|
|
|
int64_t inputRank = inputShape.size();
|
|
|
|
|
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
|
|
|
|
dimInt = toPositiveDim(dimInt, inputRank);
|
|
|
|
|
if (!isValidDim(dimInt, inputRank)) {
|
|
|
|
|
(void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
sizes.append(inputShape.begin(), inputShape.end());
|
2022-02-01 03:56:32 +08:00
|
|
|
|
// The dimension to be reduced is set to 1 when `keepDim` is true else it
|
|
|
|
|
// is removed.
|
|
|
|
|
if (keepDim)
|
|
|
|
|
sizes[dimInt] = 1;
|
|
|
|
|
else
|
2022-11-23 17:36:44 +08:00
|
|
|
|
sizes.erase(sizes.begin() + dimInt);
|
2021-11-08 23:56:40 +08:00
|
|
|
|
} else {
|
2022-02-01 03:56:32 +08:00
|
|
|
|
unsigned reducedRank = keepDim ? inputRank : inputRank - 1;
|
|
|
|
|
sizes.resize(reducedRank, kUnknownSize);
|
2021-11-08 23:56:40 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type resultType = tensorType.getWithSizesAndDtype(
|
2024-03-07 08:48:21 +08:00
|
|
|
|
!tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
|
|
|
|
: llvm::ArrayRef(sizes),
|
2023-01-04 06:19:18 +08:00
|
|
|
|
tensorType.getOptionalDtype());
|
2022-02-01 03:56:32 +08:00
|
|
|
|
return resultType;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reduction function to calculate sum along given `dim`.
|
|
|
|
|
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Operation *op, Value input, Value dim,
|
|
|
|
|
bool keepDim) {
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(dim.getType()), dim);
|
|
|
|
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
|
|
|
|
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
2022-06-29 15:23:57 +08:00
|
|
|
|
Type resultType = computeReductionType(
|
|
|
|
|
rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim);
|
2022-02-01 03:56:32 +08:00
|
|
|
|
if (!resultType)
|
|
|
|
|
return nullptr;
|
|
|
|
|
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
|
|
|
|
|
keepDimCst, dtype);
|
|
|
|
|
}
|
|
|
|
|
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is, when the shape is known), for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there is an error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
// Reduction function to calculate max along given `dim`.
|
2022-02-01 03:56:32 +08:00
|
|
|
|
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Operation *op, Value input, Value dim,
|
|
|
|
|
bool keepDim) {
|
|
|
|
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
|
|
|
|
BaseTensorType valueType =
|
2022-06-29 15:23:57 +08:00
|
|
|
|
computeReductionType(rewriter, op, input.getType().cast<BaseTensorType>(),
|
|
|
|
|
dim, keepDim)
|
2022-02-01 03:56:32 +08:00
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
if (!valueType)
|
|
|
|
|
return nullptr;
|
|
|
|
|
BaseTensorType indexType =
|
|
|
|
|
valueType
|
|
|
|
|
.getWithSizesAndDtype(
|
2022-12-20 18:17:27 +08:00
|
|
|
|
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
2023-01-25 09:29:42 +08:00
|
|
|
|
: llvm::ArrayRef(valueType.getSizes()),
|
2022-02-01 03:56:32 +08:00
|
|
|
|
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
return rewriter
|
|
|
|
|
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
|
2022-12-08 04:20:41 +08:00
|
|
|
|
.getValues();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
2021-11-19 02:02:20 +08:00
|
|
|
|
// Helper for creating `aten::sub_tensor_op`.
|
2021-11-19 20:18:41 +08:00
|
|
|
|
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Type tensorType, Value lhs, Value rhs) {
|
2021-11-19 02:02:20 +08:00
|
|
|
|
Value alpha =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
|
|
|
|
Value sub =
|
|
|
|
|
rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
|
|
|
|
|
return sub;
|
|
|
|
|
}
|
|
|
|
|
|
2021-11-19 20:18:41 +08:00
|
|
|
|
// Share code between `softmax_backward` and `log_softmax_backward` ops.
|
|
|
|
|
// Returns x - y * sum(z, dim).
|
|
|
|
|
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
|
|
|
|
|
Location loc, Operation *op,
|
|
|
|
|
Type tensorType, Value x,
|
|
|
|
|
Value y, Value z, Value dim) {
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value sum =
|
|
|
|
|
createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
|
2021-11-19 20:18:41 +08:00
|
|
|
|
if (!sum)
|
|
|
|
|
return nullptr;
|
|
|
|
|
auto broadcastSizeType =
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op->getContext()));
|
|
|
|
|
Value broadcastSize = rewriter.create<AtenSizeOp>(loc, broadcastSizeType, z);
|
|
|
|
|
Value sumBroadcast =
|
|
|
|
|
rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
|
|
|
|
|
Value temp =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, y, sumBroadcast);
|
|
|
|
|
|
|
|
|
|
Value sub = createTensorSub(rewriter, loc, tensorType, x, temp);
|
|
|
|
|
return sub;
|
|
|
|
|
}
|
|
|
|
|
|
2023-01-02 22:34:39 +08:00
|
|
|
|
// Returns the permutation of [0, inputRank) that moves dimension `srcDimInt`
// to position `dstDimInt`. Every other dimension keeps its relative order.
// Assumes both indices are already normalized into [0, inputRank); presumably
// the caller validates them (TODO confirm).
static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
                                                       int64_t dstDimInt,
                                                       unsigned inputRank) {
  SmallVector<int64_t> order;
  order.reserve(inputRank);
  for (int64_t d = 0; d < static_cast<int64_t>(inputRank); ++d) {
    if (d != srcDimInt)
      order.push_back(d);
  }
  order.insert(order.begin() + dstDimInt, srcDimInt);
  return order;
}
|
|
|
|
|
|
2024-03-28 03:42:10 +08:00
|
|
|
|
static bool
|
|
|
|
|
rewriteEquationWithEllipsisSlicing(std::string &equation,
|
|
|
|
|
SmallVector<int64_t> &inputRanks) {
|
|
|
|
|
// split equation into input and result
|
|
|
|
|
size_t arrowPos = equation.find("->");
|
|
|
|
|
if (arrowPos == std::string::npos) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::string inputStr = equation.substr(0, arrowPos);
|
|
|
|
|
std::string resultStr = equation.substr(arrowPos + 2);
|
|
|
|
|
|
|
|
|
|
// split input into tokens
|
|
|
|
|
SmallVector<std::string> inputTokens;
|
|
|
|
|
size_t start = 0;
|
|
|
|
|
size_t end = 0;
|
|
|
|
|
std::set<char> usedTokens;
|
|
|
|
|
while (end < inputStr.size()) {
|
|
|
|
|
end = inputStr.find(",", start);
|
|
|
|
|
if (end == std::string::npos) {
|
|
|
|
|
end = inputStr.size();
|
|
|
|
|
}
|
|
|
|
|
std::string token = inputStr.substr(start, end - start);
|
|
|
|
|
inputTokens.push_back(token);
|
|
|
|
|
start = end + 1;
|
|
|
|
|
}
|
|
|
|
|
if (inputTokens.size() != inputRanks.size()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// find the rank which ellipsis represents, and max ellipsis rank because a
|
|
|
|
|
// tensor can be broadcasted
|
|
|
|
|
SmallVector<int64_t> ellipsisRanks;
|
|
|
|
|
int maxEllipsisRank = 0;
|
|
|
|
|
for (const auto &[token, inputRank] : llvm::zip(inputTokens, inputRanks)) {
|
|
|
|
|
int explictRank = 0;
|
|
|
|
|
for (auto c : token) {
|
|
|
|
|
if (std::isalpha(c)) {
|
|
|
|
|
usedTokens.insert(c);
|
|
|
|
|
explictRank++;
|
|
|
|
|
} else if (c == '.' || c == ' ') {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int ellipsisRank = inputRank - explictRank;
|
|
|
|
|
if (ellipsisRank > maxEllipsisRank) {
|
|
|
|
|
maxEllipsisRank = ellipsisRank;
|
|
|
|
|
}
|
|
|
|
|
if (ellipsisRank < 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
ellipsisRanks.push_back(inputRank - explictRank);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto isTokenUsed = [&usedTokens](char c) {
|
|
|
|
|
return usedTokens.find(c) != usedTokens.end();
|
|
|
|
|
};
|
|
|
|
|
std::string ellipsisToken;
|
|
|
|
|
int usedCount = 0;
|
|
|
|
|
// Iterate over the alphabet to create a new token for ellipsis
|
|
|
|
|
for (char c = 'a'; c <= 'z'; ++c) {
|
|
|
|
|
if (!isTokenUsed(c)) {
|
|
|
|
|
ellipsisToken.push_back(c);
|
|
|
|
|
usedCount++;
|
|
|
|
|
if (usedCount == maxEllipsisRank) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// replace ellipsis with ellipsisToken
|
|
|
|
|
for (size_t i = 0; i < inputTokens.size(); i++) {
|
|
|
|
|
size_t ellipsisPos = inputTokens[i].find("...");
|
|
|
|
|
if (ellipsisPos == std::string::npos) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (ellipsisRanks[i] == maxEllipsisRank) {
|
|
|
|
|
inputTokens[i].replace(ellipsisPos, 3, ellipsisToken);
|
|
|
|
|
} else if (ellipsisRanks[i] == 0) {
|
|
|
|
|
inputTokens[i].replace(ellipsisPos, 3, "");
|
|
|
|
|
} else {
|
|
|
|
|
inputTokens[i].replace(
|
|
|
|
|
ellipsisPos, 3,
|
|
|
|
|
ellipsisToken.substr(ellipsisToken.size() - ellipsisRanks[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// replace ellipsis in result
|
|
|
|
|
size_t ellipsisPos = resultStr.find("...");
|
|
|
|
|
if (ellipsisPos != std::string::npos) {
|
|
|
|
|
resultStr.replace(ellipsisPos, 3, ellipsisToken);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// join input and result
|
|
|
|
|
equation = llvm::join(inputTokens, ",") + " -> " + resultStr;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
2023-12-10 12:30:37 +08:00
|
|
|
|
static bool parseEquation(const std::string &equation,
|
|
|
|
|
SmallVector<SmallVector<char>> &inputTokens,
|
|
|
|
|
SmallVector<char> &resultTokens) {
|
|
|
|
|
SmallVector<char> inputToken;
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
enum EquationVariable { kIsInput, kIsResult };
|
|
|
|
|
EquationVariable currentVariable = kIsInput;
|
|
|
|
|
while (index < equation.size()) {
|
|
|
|
|
if (std::isalpha(equation[index])) {
|
|
|
|
|
if (currentVariable == kIsInput) {
|
|
|
|
|
inputToken.push_back(equation[index]);
|
|
|
|
|
} else {
|
|
|
|
|
resultTokens.push_back(equation[index]);
|
|
|
|
|
}
|
|
|
|
|
} else if (equation[index] == ',') {
|
|
|
|
|
inputTokens.push_back(inputToken);
|
|
|
|
|
inputToken.clear();
|
|
|
|
|
} else if ((index < (equation.size() - 1)) &&
|
|
|
|
|
(equation.substr(index, 2).find("->") != std::string::npos)) {
|
|
|
|
|
inputTokens.push_back(inputToken);
|
|
|
|
|
inputToken.clear();
|
|
|
|
|
currentVariable = kIsResult;
|
|
|
|
|
index++;
|
2024-02-15 10:18:11 +08:00
|
|
|
|
} else if (equation[index] != ' ') {
|
2023-12-10 12:30:37 +08:00
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
index++;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] =>
|
|
|
|
|
// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
|
|
|
|
|
static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Value input, int64_t batchDimsLength,
|
|
|
|
|
int64_t contractingDimsLength,
|
|
|
|
|
int64_t otherDimsLength,
|
|
|
|
|
int64_t reduceDimsLength, bool isLhs) {
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
|
|
|
|
reduceDimsLength;
|
|
|
|
|
SmallVector<Value> inputShapeTensor;
|
|
|
|
|
for (auto i = 0; i < inputRank; ++i) {
|
|
|
|
|
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
|
|
|
|
|
loc, input,
|
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc,
|
|
|
|
|
rewriter.getI64IntegerAttr(i))));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> outShapeTensor;
|
|
|
|
|
Value constOne =
|
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
auto dimOffset = 0;
|
|
|
|
|
|
|
|
|
|
auto appendDims = [&](int64_t dimLength) {
|
|
|
|
|
Value prod = constOne;
|
|
|
|
|
for (auto i = 0; i < dimLength; ++i) {
|
|
|
|
|
prod = rewriter.create<AtenMulIntOp>(loc, prod,
|
|
|
|
|
inputShapeTensor[i + dimOffset]);
|
|
|
|
|
}
|
|
|
|
|
outShapeTensor.emplace_back(prod);
|
|
|
|
|
dimOffset += dimLength;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
appendDims(batchDimsLength);
|
|
|
|
|
if (!isLhs)
|
|
|
|
|
appendDims(contractingDimsLength);
|
|
|
|
|
appendDims(otherDimsLength + reduceDimsLength);
|
|
|
|
|
if (isLhs)
|
|
|
|
|
appendDims(contractingDimsLength);
|
|
|
|
|
|
|
|
|
|
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
|
|
|
|
outShapeTensor);
|
|
|
|
|
|
|
|
|
|
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
|
|
|
|
inputType.getOptionalDtype());
|
|
|
|
|
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
|
|
|
|
|
outShapeValue);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// classify every dim token into different categories. Note that although we
|
|
|
|
|
// parse out reduce dims, we delay their execution until
|
|
|
|
|
// `performLastPermuteAndReduce`.
|
|
|
|
|
static void parseDimTokens(
|
|
|
|
|
SmallVector<char> &lhsTokens, SmallVector<char> &rhsTokens,
|
|
|
|
|
SmallVector<char> &finalResultTokens, SmallVector<char> &contractingDims,
|
|
|
|
|
SmallVector<char> &lhsReduceDims, SmallVector<char> &rhsReduceDims,
|
|
|
|
|
SmallVector<char> &batchingDims, SmallVector<char> &lhsOtherDims,
|
|
|
|
|
SmallVector<char> &rhsOtherDims) {
|
|
|
|
|
llvm::SmallDenseSet<char> lhsTokenSet(lhsTokens.begin(), lhsTokens.end());
|
|
|
|
|
llvm::SmallDenseSet<char> rhsTokenSet(rhsTokens.begin(), rhsTokens.end());
|
|
|
|
|
llvm::SmallDenseSet<char> finalResultTokenSet(finalResultTokens.begin(),
|
|
|
|
|
finalResultTokens.end());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < lhsTokens.size(); ++i) {
|
|
|
|
|
bool rhsContains = rhsTokenSet.contains(lhsTokens[i]);
|
|
|
|
|
bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]);
|
|
|
|
|
// batching dim
|
|
|
|
|
if (rhsContains && finalResultConatins) {
|
|
|
|
|
batchingDims.push_back(lhsTokens[i]);
|
|
|
|
|
// reduce dim of lhs
|
|
|
|
|
} else if (!rhsContains && !finalResultConatins) {
|
|
|
|
|
lhsReduceDims.push_back(lhsTokens[i]);
|
|
|
|
|
// other dim of lhs
|
|
|
|
|
} else if (finalResultConatins) {
|
|
|
|
|
lhsOtherDims.push_back(lhsTokens[i]);
|
|
|
|
|
// contracting dim of lhs
|
|
|
|
|
} else if (rhsContains) {
|
|
|
|
|
contractingDims.push_back(lhsTokens[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < rhsTokens.size(); ++i) {
|
|
|
|
|
bool lhsContains = lhsTokenSet.contains(rhsTokens[i]);
|
|
|
|
|
bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]);
|
|
|
|
|
// batching dim
|
|
|
|
|
if (lhsContains && finalResultConatins) {
|
|
|
|
|
// reduce dim of rhs
|
|
|
|
|
} else if (!lhsContains && !finalResultConatins) {
|
|
|
|
|
rhsReduceDims.push_back(rhsTokens[i]);
|
|
|
|
|
// other dim of rhs
|
|
|
|
|
} else if (finalResultConatins) {
|
|
|
|
|
rhsOtherDims.push_back(rhsTokens[i]);
|
|
|
|
|
// contracting dim of rhs
|
|
|
|
|
} else if (lhsContains) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Appends the "ideal" result layout produced by the batched matmul to
// `resultTokens`, in this order:
//   [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
//    *rhsReduceDims]
static void generateIdealReusltDimTokens(SmallVector<char> &batchingDims,
                                         SmallVector<char> &lhsOtherDims,
                                         SmallVector<char> &rhsOtherDims,
                                         SmallVector<char> &lhsReduceDims,
                                         SmallVector<char> &rhsReduceDims,
                                         SmallVector<char> &resultTokens) {
  SmallVector<char> *groupsInOrder[] = {&batchingDims, &lhsOtherDims,
                                        &lhsReduceDims, &rhsOtherDims,
                                        &rhsReduceDims};
  for (SmallVector<char> *group : groupsInOrder)
    resultTokens.append(group->begin(), group->end());
}
|
|
|
|
|
|
|
|
|
|
// Permutes `input`, whose dimensions are labelled by `dimTokens`, into the
// layout expected by `collapseDimForMatmul`:
//   lhs (isLhs = true):  [*batchingDims, *otherDims, *reduceDims,
//                         *contractingDims]
//   rhs (isLhs = false): [*batchingDims, *contractingDims, *otherDims,
//                         *reduceDims]
// Every letter in the category lists must occur in `dimTokens`. The result
// type carries no static sizes.
static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
                                    Value input, SmallVector<char> &dimTokens,
                                    SmallVector<char> &batchingDims,
                                    SmallVector<char> &contractingDims,
                                    SmallVector<char> &otherDims,
                                    SmallVector<char> &reduceDims, bool isLhs) {
  auto inputType = input.getType().cast<BaseTensorType>();
  llvm::SmallDenseMap<char, int64_t> dimTokenMap;
  for (size_t idx = 0; idx < dimTokens.size(); ++idx)
    dimTokenMap[dimTokens[idx]] = idx;

  SmallVector<Value> permuteVec;
  // Take the group by const reference. The previous by-value parameter copied
  // the vector on every call and shadowed the outer `dimTokens`.
  auto appendDims = [&](const SmallVector<char> &group) {
    for (char d : group) {
      // Use lookup() rather than operator[] so that a missing letter is not
      // silently inserted with index 0.
      assert(dimTokenMap.count(d) && "dim letter missing from operand tokens");
      permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(dimTokenMap.lookup(d))));
    }
  };

  appendDims(batchingDims);
  if (!isLhs)
    appendDims(contractingDims);
  appendDims(otherDims);
  appendDims(reduceDims);
  if (isLhs)
    appendDims(contractingDims);

  Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
      permuteVec);
  auto outType = inputType.getWithSizesAndDtype(std::nullopt,
                                                inputType.getOptionalDtype());
  return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
}
|
|
|
|
|
|
|
|
|
|
// Contracts `lhs` and `rhs` (labelled by `lhsTokens` / `rhsTokens`) with one
// batched `aten.matmul`, given the letters of the final einsum result
// (`finalResultTokens`).
//
// Steps:
//   1. Permute both operands and collapse each to rank 3.
//   2. Multiply them.
//   3. Reshape the product to the "ideal" layout
//        [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
//         *rhsReduceDims]
//      and append those letters to `resultTokens`.
// The reshaped value is stored in `result`. Reduce dims are carried through
// and not summed here. A letter present in both operands takes the runtime
// max of the two operand sizes. A pure Hadamard product (no contracting and
// no "other" dims) is reported as a match failure.
static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
                                   Value lhs, SmallVector<char> &lhsTokens,
                                   Value rhs, SmallVector<char> &rhsTokens,
                                   Value &result,
                                   SmallVector<char> &resultTokens,
                                   SmallVector<char> &finalResultTokens) {
  auto lhsType = lhs.getType().cast<BaseTensorType>();
  auto rhsType = rhs.getType().cast<BaseTensorType>();
  // Results use lhs's dtype, falling back to rhs's when lhs has none.
  Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                        : rhsType.getOptionalDtype();

  // Maps each dimension letter of `operand` to its runtime size.
  auto querySizes = [&](Value operand, ArrayRef<char> tokens) {
    llvm::SmallDenseMap<char, Value> sizeOf;
    for (size_t idx = 0, e = tokens.size(); idx < e; ++idx) {
      Value dimIndex = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(idx));
      sizeOf[tokens[idx]] =
          rewriter.create<AtenSizeIntOp>(loc, operand, dimIndex);
    }
    return sizeOf;
  };
  llvm::SmallDenseMap<char, Value> lhsDimShapeMap = querySizes(lhs, lhsTokens);
  llvm::SmallDenseMap<char, Value> rhsDimShapeMap = querySizes(rhs, rhsTokens);

  // parse batch, contracting, other, reduce dims of lhs and rhs
  SmallVector<char> contractingDims;
  SmallVector<char> lhsReduceDims;
  SmallVector<char> rhsReduceDims;
  SmallVector<char> lhsOtherDims;
  SmallVector<char> rhsOtherDims;
  SmallVector<char> batchingDims;
  parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims,
                 lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims,
                 rhsOtherDims);

  // Runtime size of each categorized letter. Letters shared by both operands
  // take the max of the two sizes.
  llvm::SmallDenseMap<char, Value> outDimShapeMap;
  SmallVector<char> *dimGroups[] = {&contractingDims, &batchingDims,
                                    &lhsReduceDims,   &rhsReduceDims,
                                    &lhsOtherDims,    &rhsOtherDims};
  for (SmallVector<char> *group : dimGroups) {
    for (char d : *group) {
      auto lhsIt = lhsDimShapeMap.find(d);
      auto rhsIt = rhsDimShapeMap.find(d);
      bool inLhs = lhsIt != lhsDimShapeMap.end();
      bool inRhs = rhsIt != rhsDimShapeMap.end();
      if (inLhs && inRhs)
        outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
            loc, lhsIt->second, rhsIt->second);
      else if (inLhs)
        outDimShapeMap[d] = lhsIt->second;
      else if (inRhs)
        outDimShapeMap[d] = rhsIt->second;
    }
  }

  bool isHadamardProduct = contractingDims.empty() && lhsOtherDims.empty() &&
                           rhsOtherDims.empty();
  if (isHadamardProduct)
    return rewriter.notifyMatchFailure(
        loc, "Hadamard product is currently not supported");

  // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims]
  Value permutedLhs = permuteTensorForMatmul(
      rewriter, loc, lhs, lhsTokens, batchingDims, contractingDims,
      lhsOtherDims, lhsReduceDims, /*isLhs=*/true);
  // [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims]
  Value permutedRhs = permuteTensorForMatmul(
      rewriter, loc, rhs, rhsTokens, batchingDims, contractingDims,
      rhsOtherDims, rhsReduceDims, /*isLhs=*/false);
  // [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
  Value lhsMatrix = collapseDimForMatmul(
      rewriter, loc, permutedLhs, batchingDims.size(), contractingDims.size(),
      lhsOtherDims.size(), lhsReduceDims.size(), /*isLhs=*/true);
  // [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd]
  Value rhsMatrix = collapseDimForMatmul(
      rewriter, loc, permutedRhs, batchingDims.size(), contractingDims.size(),
      rhsOtherDims.size(), rhsReduceDims.size(), /*isLhs=*/false);

  Type unsizedOutType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
  Value product = rewriter.create<Torch::AtenMatmulOp>(loc, unsizedOutType,
                                                       lhsMatrix, rhsMatrix);

  generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
                               lhsReduceDims, rhsReduceDims, resultTokens);

  // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] =>
  // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
  //  *rhsReduceDims]
  SmallVector<Value> idealShape;
  for (char d : resultTokens)
    idealShape.push_back(outDimShapeMap[d]);
  Value idealShapeList = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
      idealShape);
  result = rewriter.create<Torch::AtenReshapeOp>(loc, unsizedOutType, product,
                                                 idealShapeList);
  return success();
}
|
|
|
|
|
|
|
|
|
|
// Finishes an einsum in two steps:
//   1. Sum `input` (labelled by `inputTokens`) over every dimension whose
//      letter does not occur in `outTokens`.
//   2. Permute the remaining dimensions into the order given by `outTokens`.
// The permute gets `outType`. The intermediate sum, if emitted, has no static
// sizes and keeps the input dtype.
static Value performLastReduceAndPermute(PatternRewriter &rewriter,
                                         Location loc, Type outType,
                                         Value input,
                                         SmallVector<char> &inputTokens,
                                         SmallVector<char> &outTokens) {
  auto inputType = input.getType().cast<BaseTensorType>();

  llvm::SmallDenseSet<char> keptLetters(outTokens.begin(), outTokens.end());
  SmallVector<int64_t> reducedDims;
  // Position of each kept letter once the reduced dims have been removed.
  llvm::SmallDenseMap<char, int64_t> positionAfterSum;
  int64_t nextPosition = 0;
  for (size_t i = 0, e = inputTokens.size(); i < e; ++i) {
    char letter = inputTokens[i];
    if (keptLetters.contains(letter))
      positionAfterSum[letter] = nextPosition++;
    else
      reducedDims.push_back(i);
  }

  Type intListType =
      Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));

  if (!reducedDims.empty()) {
    SmallVector<Value> reducedDimValues;
    for (int64_t d : reducedDims)
      reducedDimValues.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(d)));
    Value reducedDimList = rewriter.create<Torch::PrimListConstructOp>(
        loc, intListType, reducedDimValues);
    Value noKeepDim = rewriter.create<Torch::ConstantBoolOp>(
        loc, rewriter.getBoolAttr(false));
    Value noDtype = rewriter.create<Torch::ConstantNoneOp>(loc);
    Type sumType = inputType.getWithSizesAndDtype(std::nullopt,
                                                  inputType.getOptionalDtype());
    input = rewriter.create<Torch::AtenSumDimIntListOp>(
        loc, sumType, input, reducedDimList, noKeepDim, noDtype);
  }

  SmallVector<Value> permutation;
  for (char letter : outTokens)
    permutation.push_back(rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(positionAfterSum.lookup(letter))));
  Value permutationList = rewriter.create<Torch::PrimListConstructOp>(
      loc, intListType, permutation);
  return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input,
                                               permutationList);
}
|
|
|
|
|
|
2022-11-23 02:37:28 +08:00
|
|
|
|
namespace {
|
|
|
|
|
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
|
|
|
|
/// number of dimensions across which the max needs to be computed.
|
|
|
|
|
/// Eg:
|
|
|
|
|
/// INPUT:
|
|
|
|
|
/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
|
|
|
|
|
///
|
|
|
|
|
/// OUTPUT:
|
|
|
|
|
/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1
|
|
|
|
|
/// input_2 = aten.max.dim(input_1, 1, keepdim) #2
|
|
|
|
|
/// final_output = aten.max.dim(input_2, 0, keepdim) #3
|
|
|
|
|
///
|
|
|
|
|
/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
|
|
|
|
|
/// of the `aten.amax` op and create an `aten.amax.dim` op.
|
|
|
|
|
/// Input tensor to the next `aten.amax.dim` op is thus the output of the
|
|
|
|
|
/// previous `aten.amax.dim` op.
|
|
|
|
|
class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenAmaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
SmallVector<int64_t, 4> dims;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2022-11-23 02:37:28 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"non-const dim parameter unsupported");
|
|
|
|
|
|
|
|
|
|
bool keepDim;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
|
2022-11-23 02:37:28 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a constant boolean value for keepDim");
|
|
|
|
|
|
2023-06-22 11:26:51 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
auto inputTy = input.getType().dyn_cast<Torch::ValueTensorType>();
|
|
|
|
|
if (!inputTy || !inputTy.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"Expected input type having sizes");
|
|
|
|
|
}
|
2022-11-23 02:37:28 +08:00
|
|
|
|
// For every dimension included in `dim` of the op, iterated over in
|
|
|
|
|
// reverse order, we create a call to aten.max.dim.
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
std::sort(dims.rbegin(), dims.rend());
|
2022-12-13 00:56:28 +08:00
|
|
|
|
for (int64_t dimInt : dims) {
|
2023-06-22 11:26:51 +08:00
|
|
|
|
int64_t inputRank = inputTy.getSizes().size();
|
2023-04-07 19:49:35 +08:00
|
|
|
|
dimInt = toPositiveDim(dimInt, inputRank);
|
|
|
|
|
if (!isValidDim(dimInt, inputRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
2022-11-23 02:37:28 +08:00
|
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
2022-12-13 00:56:28 +08:00
|
|
|
|
loc, rewriter.getI64IntegerAttr(dimInt));
|
2022-11-23 02:37:28 +08:00
|
|
|
|
// The input to the next invocation of aten.max.dim is the output of the
|
|
|
|
|
// previous aten.max.dim op.
|
|
|
|
|
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, input);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // end namespace
|
|
|
|
|
|
[Torch Dialect] Decompose AtenTriuOp (#2561)
Decompose like:
```
import torch
def my_triu(x, diag):
    rows = torch.ops.aten.size(x, -2)
    cols = torch.ops.aten.size(x, -1)
    row_indices = torch.ops.aten.arange(rows).unsqueeze(1)
    col_indices = torch.ops.aten.arange(cols).unsqueeze(0)
    cond = torch.ops.aten.ge(
        col_indices, torch.ops.aten.add(row_indices, diag))
    return torch.ops.aten.where(cond, x, 0)
x = torch.rand(5, 7)
assert torch.allclose(my_triu(x, 0), torch.triu(x, 0))
assert torch.allclose(my_triu(x, 1), torch.triu(x, 1))
assert torch.allclose(my_triu(x, 2), torch.triu(x, 2))
assert torch.allclose(my_triu(x, -1), torch.triu(x, -1))
```
---------
Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>
2023-11-29 10:35:26 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTriuOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
|
|
|
|
|
}
|
|
|
|
|
if (inputType.getSizes().size() < 2) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
|
|
|
|
Value cstZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value cstOne =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
|
|
|
|
|
Value rowDim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(-2));
|
|
|
|
|
Value colDim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(-1));
|
|
|
|
|
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
|
|
|
|
|
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
|
|
|
|
|
|
|
|
|
|
Value rowArange = rewriter.create<AtenArangeOp>(
|
|
|
|
|
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none);
|
|
|
|
|
Value colArange = rewriter.create<AtenArangeOp>(
|
|
|
|
|
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none);
|
|
|
|
|
|
|
|
|
|
Value unsqueezeRowArange =
|
|
|
|
|
rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
|
|
|
|
|
Value unsqueezeColArange =
|
|
|
|
|
rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
|
|
|
|
|
|
|
|
|
|
Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
|
|
|
|
|
|
|
|
|
|
Value condTensor = rewriter.create<AtenGeTensorOp>(
|
|
|
|
|
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
|
|
|
|
|
op, op.getResult().getType(), condTensor, input, cstZero);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSizeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
MLIRContext *context = op.getContext();
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeRank)
|
2021-11-08 23:56:40 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
|
unsigned rank = *maybeRank;
|
2021-11-08 23:56:40 +08:00
|
|
|
|
SmallVector<Value> sizes;
|
2022-12-13 00:56:28 +08:00
|
|
|
|
for (unsigned i = 0; i < rank; i++) {
|
2021-11-08 23:56:40 +08:00
|
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
|
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value sizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
|
|
|
|
|
rewriter.replaceOp(op, sizeList);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-03 12:09:21 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSelectIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value start = op.getIndex();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
Value self = op.getSelf();
|
2022-02-12 03:34:05 +08:00
|
|
|
|
|
2024-04-01 21:25:02 +08:00
|
|
|
|
auto resultTy = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have sizes and dtype");
|
|
|
|
|
}
|
|
|
|
|
|
2023-01-18 02:14:14 +08:00
|
|
|
|
// convert `start` to non-negative: start += int(start < 0) * dimSize
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
|
|
|
|
|
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
|
|
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
|
|
|
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
|
|
|
|
|
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
|
|
|
|
|
|
2021-12-03 12:09:21 +08:00
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
2022-02-12 03:34:05 +08:00
|
|
|
|
Value startPlusOne =
|
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
|
|
|
|
Value slice = rewriter.create<AtenSliceTensorOp>(
|
2022-06-29 15:23:57 +08:00
|
|
|
|
loc,
|
|
|
|
|
computeReductionType(rewriter, op,
|
|
|
|
|
self.getType().cast<BaseTensorType>(), dim,
|
|
|
|
|
/*keepDim=*/true),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
2022-02-12 03:34:05 +08:00
|
|
|
|
|
2024-03-27 03:41:40 +08:00
|
|
|
|
auto sliceTy = cast<BaseTensorType>(slice.getType());
|
|
|
|
|
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
|
|
|
|
|
rewriter.replaceOp(op, slice);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-12 03:34:05 +08:00
|
|
|
|
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
|
|
|
|
|
// slicing, while `aten.select.int` does.
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
slice, op.getDim());
|
2021-12-03 12:09:21 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-06 07:01:21 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposePrimTolistOp : public OpRewritePattern<PrimTolistOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(PrimTolistOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
auto self = op.getOperands()[0];
|
|
|
|
|
auto selfTy = dyn_cast<Torch::BaseTensorType>(self.getType());
|
|
|
|
|
if (!selfTy || !selfTy.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unknown self shape");
|
|
|
|
|
|
|
|
|
|
int64_t rank = selfTy.getSizes().size();
|
|
|
|
|
if (rank != 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Expected rank-1");
|
|
|
|
|
|
|
|
|
|
int64_t length = selfTy.getSizes().back();
|
|
|
|
|
if (length == Torch::kUnknownSize)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Tolist length is unknown");
|
|
|
|
|
|
|
|
|
|
auto resultTy = dyn_cast<Torch::ListType>(op.getType(0));
|
|
|
|
|
if (!resultTy)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Result type is not list");
|
|
|
|
|
|
|
|
|
|
auto scalarTy = resultTy.getContainedType();
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
auto extractTy = rewriter.getType<ValueTensorType>(
|
|
|
|
|
llvm::SmallVector<int64_t>{1}, selfTy.getOptionalDtype());
|
|
|
|
|
llvm::SmallVector<Value> results;
|
|
|
|
|
llvm::SmallVector<int64_t> sizes(selfTy.getSizes());
|
|
|
|
|
for (int64_t i = 0; i < length; ++i) {
|
|
|
|
|
Value iv =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
|
Value extract = rewriter.create<AtenSelectIntOp>(
|
|
|
|
|
loc, extractTy, self, /*dim=*/zero, /*index=*/iv);
|
|
|
|
|
Value scalar = rewriter.create<AtenItemOp>(loc, scalarTy, extract);
|
|
|
|
|
results.push_back(scalar);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, results);
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSplitSizesOp : public OpRewritePattern<AtenSplitSizesOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSplitSizesOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
|
|
|
|
|
op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSplitWithSizesOp
|
|
|
|
|
: public OpRewritePattern<AtenSplitWithSizesOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSplitWithSizesOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
SmallVector<Value> splitSizes;
|
|
|
|
|
if (!getListConstructElements(op.getSplitSizes(), splitSizes))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unable to get sizes");
|
|
|
|
|
|
|
|
|
|
if (splitSizes.empty())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "No split sizes");
|
|
|
|
|
|
|
|
|
|
auto selfTy = dyn_cast<BaseTensorType>(self.getType());
|
|
|
|
|
if (!selfTy || !selfTy.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Self shape unknown");
|
|
|
|
|
|
|
|
|
|
int64_t rank = selfTy.getSizes().size();
|
|
|
|
|
auto resultTy = dyn_cast<Torch::ListType>(op.getResult().getType());
|
|
|
|
|
if (!resultTy)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Result type not a list");
|
|
|
|
|
|
|
|
|
|
auto sliceTy =
|
|
|
|
|
dyn_cast_or_null<Torch::BaseTensorType>(resultTy.getContainedType());
|
|
|
|
|
if (!isa<Torch::BaseTensorType>(sliceTy))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Slice type is unknown");
|
|
|
|
|
|
|
|
|
|
int64_t dimInt = 0;
|
|
|
|
|
bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt));
|
|
|
|
|
if (dimInt < 0)
|
|
|
|
|
dimInt += rank;
|
|
|
|
|
|
|
|
|
|
auto intTy = rewriter.getType<Torch::IntType>();
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value begin =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
|
|
|
|
|
llvm::SmallVector<Value> slices;
|
|
|
|
|
llvm::SmallVector<int64_t> sliceSizes(sliceTy.getSizes());
|
|
|
|
|
int64_t defaultLength = !hasDim ? Torch::kUnknownSize : sliceSizes[dimInt];
|
|
|
|
|
for (auto size : splitSizes) {
|
|
|
|
|
Value end = rewriter.create<AtenAddIntOp>(loc, intTy, begin, size);
|
|
|
|
|
|
|
|
|
|
int64_t sizeInt;
|
|
|
|
|
if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) {
|
|
|
|
|
sliceSizes[dimInt] = sizeInt;
|
|
|
|
|
} else if (hasDim) {
|
|
|
|
|
sliceSizes[dimInt] = defaultLength;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sliceTy = rewriter.getType<ValueTensorType>(sliceSizes,
|
|
|
|
|
sliceTy.getOptionalDtype());
|
|
|
|
|
Value slice = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
|
loc, sliceTy, op.getSelf(),
|
|
|
|
|
/*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one);
|
|
|
|
|
slices.push_back(slice);
|
|
|
|
|
begin = end;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, slices);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-25 03:26:37 +08:00
|
|
|
|
namespace {
|
2022-08-01 20:32:35 +08:00
|
|
|
|
class DecomposeAtenNarrowOp : public OpRewritePattern<AtenNarrowOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNarrowOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value start = op.getStart();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
Value length = op.getLength();
|
2022-08-01 20:32:35 +08:00
|
|
|
|
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value startPlusLength =
|
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
|
2022-11-06 20:44:05 +08:00
|
|
|
|
|
2022-08-01 20:32:35 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getResult().getType(), op.getSelf(), /*dim=*/dim,
|
|
|
|
|
/*start=*/start,
|
2022-08-01 20:32:35 +08:00
|
|
|
|
/*end=*/startPlusLength, /*step=*/one);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-07-20 16:46:44 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.narrow.Tensor` to `aten.narrow` op
|
|
|
|
|
class DecomposeAtenNarrowTensorOp
|
|
|
|
|
: public OpRewritePattern<AtenNarrowTensorOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNarrowTensorOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto *context = op.getContext();
|
|
|
|
|
// PyTorch makes sure that `start` param is an 0-dim integral tensor.
|
|
|
|
|
// REF: https://pytorch.org/docs/stable/generated/torch.narrow.html.
|
|
|
|
|
auto start = rewriter.create<Torch::AtenScalarImplicitOp>(
|
|
|
|
|
loc, Torch::IntType::get(context), op.getStart());
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenNarrowOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-10-26 10:36:18 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenGluOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
|
|
|
|
|
auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>();
|
|
|
|
|
if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected output type having sizes and dtype");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
|
|
|
Value two =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
|
|
|
|
|
|
|
|
|
|
Value remainder = rewriter.create<AtenRemainderIntOp>(loc, dimSize, two);
|
|
|
|
|
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, zero);
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-10-26 10:36:18 +08:00
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
|
loc, eqOrNot,
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2"));
|
2023-10-26 10:36:18 +08:00
|
|
|
|
|
|
|
|
|
Value splitLength = rewriter.create<AtenFloordivIntOp>(loc, dimSize, two);
|
|
|
|
|
Value a = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim, zero,
|
|
|
|
|
splitLength);
|
|
|
|
|
Value b = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim,
|
|
|
|
|
splitLength, splitLength);
|
|
|
|
|
// a⊗σ(b)
|
|
|
|
|
Value sigmoidB = rewriter.create<AtenSigmoidOp>(loc, outputTy, b);
|
|
|
|
|
Value result = rewriter.create<AtenMulTensorOp>(loc, outputTy, a, sigmoidB);
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-01 20:32:35 +08:00
|
|
|
|
namespace {
|
2023-11-01 11:56:54 +08:00
|
|
|
|
class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
|
2022-03-25 03:26:37 +08:00
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-07-12 01:56:12 +08:00
|
|
|
|
LogicalResult matchAndRewrite(AtenZeroOp op,
|
2022-03-25 03:26:37 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
|
|
|
|
|
rewriter.getI64IntegerAttr(0));
|
2023-11-01 11:56:54 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(),
|
|
|
|
|
op.getSelf(), zero);
|
2022-03-25 03:26:37 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-11-02 02:23:28 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenEyeOp : public OpRewritePattern<AtenEyeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenEyeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value n = op.getN();
|
|
|
|
|
Value m = op.getN();
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEyeMOp>(op, op.getType(), n, m,
|
|
|
|
|
op.getDtype(), op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
int64_t n;
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-02 02:23:28 +08:00
|
|
|
|
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"unimplemented: n must be constant");
|
|
|
|
|
int64_t m;
|
|
|
|
|
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"unimplemented: m must be constant");
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
|
|
|
|
if (!outType)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Only tensor types input are currently supported");
|
|
|
|
|
if (!outType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
|
|
|
|
if (n < 0) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
|
|
|
|
|
}
|
|
|
|
|
if (m < 0) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto context = op.getContext();
|
|
|
|
|
auto int64Dtype = getDtypeIntValueForType(
|
|
|
|
|
rewriter, loc,
|
|
|
|
|
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
|
|
|
|
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
|
|
|
|
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
|
|
|
|
|
Value rangeN = rewriter.create<AtenArangeOp>(
|
|
|
|
|
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
|
|
|
|
|
/*device=*/op.getDevice(), /*pin_memory=*/none);
|
|
|
|
|
|
|
|
|
|
auto arangeType1 =
|
|
|
|
|
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
|
|
|
|
|
Value rangeM = rewriter.create<AtenArangeOp>(
|
|
|
|
|
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none);
|
|
|
|
|
|
|
|
|
|
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(-1));
|
|
|
|
|
auto unsqzTensorInfo =
|
|
|
|
|
unsqueezeTensor(rewriter, op, rangeN, /*dim=*/constMinusOne);
|
|
|
|
|
if (failed(unsqzTensorInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"cannot generate unsqueeze tensor");
|
|
|
|
|
}
|
|
|
|
|
Value unsqzRangeN = *unsqzTensorInfo;
|
|
|
|
|
|
|
|
|
|
// compare unsqueezed input with boundaries
|
|
|
|
|
auto eqType = ValueTensorType::get(
|
|
|
|
|
context, op.getType().cast<BaseTensorType>().getSizes(),
|
|
|
|
|
IntegerType::get(context, 1));
|
|
|
|
|
Value eqTensor =
|
|
|
|
|
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
|
|
|
|
|
|
|
|
|
|
Value dtype = op.getDtype();
|
|
|
|
|
if (dtype.getType().isa<Torch::BoolType>()) {
|
|
|
|
|
rewriter.replaceOp(op, eqTensor);
|
|
|
|
|
return success();
|
|
|
|
|
} else {
|
|
|
|
|
auto zero =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
|
|
|
|
auto one =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value outTensor =
|
|
|
|
|
rewriter.create<AtenWhereScalarOp>(loc, outType, eqTensor, one, zero);
|
|
|
|
|
rewriter.replaceOp(op, outTensor);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-06-07 10:06:27 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenIsnanOp : public OpRewritePattern<AtenIsnanOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIsnanOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
// Create a new aten.ne operation with the same type and input value.
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenNeTensorOp>(op, op.getType(), input, input);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-11-04 22:26:01 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenIsinfOp : public OpRewritePattern<AtenIsinfOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIsinfOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
|
|
|
|
|
mlir::FloatType f64Type = rewriter.getF64Type();
|
|
|
|
|
Value inf = rewriter.create<ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getFloatAttr(
|
|
|
|
|
f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
|
|
|
|
|
Value abs = rewriter.create<AtenAbsOp>(loc, self.getType(), self);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), abs, inf);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-01-16 14:29:34 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenIsneginfOp : public OpRewritePattern<AtenIsneginfOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIsneginfOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
mlir::FloatType f64Type = rewriter.getF64Type();
|
|
|
|
|
Value inf = rewriter.create<ConstantFloatOp>(
|
|
|
|
|
op.getLoc(),
|
|
|
|
|
rewriter.getFloatAttr(
|
|
|
|
|
f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
inf);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIsposinfOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
mlir::FloatType f64Type = rewriter.getF64Type();
|
|
|
|
|
Value inf = rewriter.create<ConstantFloatOp>(
|
|
|
|
|
op.getLoc(),
|
|
|
|
|
rewriter.getFloatAttr(f64Type,
|
|
|
|
|
APFloat::getInf(f64Type.getFloatSemantics())));
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
inf);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-17 23:54:03 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenReshapeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2021-12-17 23:54:03 +08:00
|
|
|
|
// TODO: Handle non value tensor type operands.
|
|
|
|
|
if (!input.getType().isa<ValueTensorType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: only value tensor type operands are supported");
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getShape());
|
2021-12-17 23:54:03 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-12-10 12:30:37 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
|
|
|
|
|
// operation and permute operation. Currently, this pass doesn't support
|
|
|
|
|
// Hadamard product. The basic idea is that:
|
|
|
|
|
// Step 1: split the string equation to input/result tokens and find
|
|
|
|
|
// batchingDims, contractingDims, otherDims and reduceDims.
|
|
|
|
|
// Step 2: permute and reshape input tensors suitable
|
|
|
|
|
// for matmul operations.
|
|
|
|
|
// Step 3: use AtenMatmulOp to get the result.
|
|
|
|
|
// Step 4: iteratively execute step 2 & 3 until we get the final result.
|
|
|
|
|
// Step 5: perform remaining permute and reduce operations.
|
|
|
|
|
// notice: support static shape only
|
|
|
|
|
|
|
|
|
|
class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenEinsumOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> inputTensors;
|
|
|
|
|
if (!getListConstructElements(op.getTensors(), inputTensors)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "input should comes from a PrimListConstructOp");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto allTensorHasSizes = [](Value tensor) {
|
|
|
|
|
auto type = tensor.getType().dyn_cast<BaseTensorType>();
|
|
|
|
|
if (!type || !type.hasSizes())
|
|
|
|
|
return false;
|
|
|
|
|
return true;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (!llvm::all_of(inputTensors, allTensorHasSizes)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"all input tensors should have sizes");
|
|
|
|
|
}
|
|
|
|
|
|
2024-03-28 03:42:10 +08:00
|
|
|
|
std::string equation;
|
|
|
|
|
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
|
|
|
|
|
}
|
|
|
|
|
// if "..." in equation, modify it
|
|
|
|
|
if (equation.find("...") != std::string::npos) {
|
|
|
|
|
SmallVector<int64_t> inputRanks;
|
|
|
|
|
for (Value tensor : inputTensors) {
|
|
|
|
|
auto type = tensor.getType().cast<BaseTensorType>();
|
|
|
|
|
inputRanks.push_back(type.getSizes().size());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!rewriteEquationWithEllipsisSlicing(equation, inputRanks)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unexpected character in equations encountered");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
SmallVector<char> resultTokens;
|
|
|
|
|
SmallVector<SmallVector<char>> inputTokens;
|
|
|
|
|
if (!parseEquation(equation, inputTokens, resultTokens)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unexpected character in equations encountered");
|
|
|
|
|
}
|
|
|
|
|
|
2023-12-10 12:30:37 +08:00
|
|
|
|
SmallVector<char> lhsTokens = inputTokens[0];
|
|
|
|
|
Value lhs = inputTensors[0];
|
|
|
|
|
Value result;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < inputTensors.size(); ++i) {
|
|
|
|
|
auto rhs = inputTensors[i];
|
|
|
|
|
auto rhsTokens = inputTokens[i];
|
|
|
|
|
SmallVector<char> outTokens;
|
|
|
|
|
if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens,
|
|
|
|
|
result, outTokens, resultTokens))) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
lhs = result;
|
|
|
|
|
lhsTokens = outTokens;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs,
|
|
|
|
|
lhsTokens, resultTokens);
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-02-10 00:00:24 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Calculate the trace of the input tensor as the sum over its diagonal
|
|
|
|
|
// elements. This computation is performed as:
|
|
|
|
|
//
|
|
|
|
|
// Step1: Obtain the diagonal using AtenDiagonalOp
|
|
|
|
|
// Step2: Compute the trace using AtenSumOp.
|
|
|
|
|
//
|
|
|
|
|
// It is verified that the input tensor has rank two.
|
|
|
|
|
class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTraceOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
std::optional<unsigned> inRank = getTensorRank(self);
|
|
|
|
|
if (inRank != 2)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected input tensor to have rank 2.");
|
|
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
BaseTensorType inputType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
Value output = op.getResult();
|
|
|
|
|
BaseTensorType outputType = output.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
|
|
|
|
int64_t diagonalSize = std::min(inputShape[0], inputShape[1]);
|
|
|
|
|
SmallVector<int64_t> diagonalShape{diagonalSize};
|
|
|
|
|
Type elementType = inputType.getOptionalDtype();
|
|
|
|
|
Type diagonalType = inputType.getWithSizesAndDtype(
|
|
|
|
|
llvm::ArrayRef(diagonalShape), elementType);
|
|
|
|
|
|
|
|
|
|
Value diagonal = rewriter.create<AtenDiagonalOp>(
|
|
|
|
|
loc, diagonalType, /*input=*/self, /*offset=*/zero, /*dim1=*/zero,
|
|
|
|
|
/*dim2=*/one);
|
|
|
|
|
Value sum = rewriter.create<AtenSumOp>(loc, outputType, /*self=*/diagonal,
|
|
|
|
|
/*dtype=*/none);
|
|
|
|
|
rewriter.replaceOp(op, sum);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-25 13:49:02 +08:00
|
|
|
|
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
|
|
|
|
// exp(x)/sum(exp(x)).
|
2022-02-01 03:56:32 +08:00
|
|
|
|
// To avoid overflow we use the following decomposition rule:
|
|
|
|
|
// x_max = max(input, dim, keepdim = True)
|
|
|
|
|
// unnorm = aten.exp(input - x_max)
|
|
|
|
|
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
|
2021-11-25 13:49:02 +08:00
|
|
|
|
template <typename OpTy>
|
2022-11-25 13:56:37 +08:00
|
|
|
|
static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
|
2024-03-13 06:07:45 +08:00
|
|
|
|
Type accumulatorType, PatternRewriter &rewriter) {
|
2021-11-25 13:49:02 +08:00
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value dim = op.getDim();
|
2024-03-13 06:07:45 +08:00
|
|
|
|
if (resultType != accumulatorType)
|
|
|
|
|
self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
|
2022-02-01 03:56:32 +08:00
|
|
|
|
Value xMax =
|
|
|
|
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
2024-03-13 06:07:45 +08:00
|
|
|
|
|
2022-02-01 03:56:32 +08:00
|
|
|
|
if (!xMax)
|
|
|
|
|
return nullptr;
|
2024-03-13 06:07:45 +08:00
|
|
|
|
Value unNormalized =
|
|
|
|
|
createTensorSub(rewriter, loc, self.getType(), self, xMax);
|
2022-02-01 03:56:32 +08:00
|
|
|
|
Value unNormalizedExp =
|
2024-03-13 06:07:45 +08:00
|
|
|
|
rewriter.create<AtenExpOp>(loc, self.getType(), unNormalized);
|
2022-02-01 03:56:32 +08:00
|
|
|
|
Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
|
|
|
|
|
/*keepDim=*/true);
|
2021-11-25 13:49:02 +08:00
|
|
|
|
if (!sum)
|
|
|
|
|
return nullptr;
|
2024-03-13 06:07:45 +08:00
|
|
|
|
|
|
|
|
|
Value result = rewriter.create<AtenDivTensorOp>(loc, self.getType(),
|
|
|
|
|
unNormalizedExp, sum);
|
|
|
|
|
if (resultType != accumulatorType)
|
|
|
|
|
result = convertTensorToDtype(rewriter, loc, result,
|
2024-04-11 21:47:35 +08:00
|
|
|
|
cast<BaseTensorType>(resultType).getDtype());
|
2024-03-13 06:07:45 +08:00
|
|
|
|
|
|
|
|
|
return result;
|
2021-11-25 13:49:02 +08:00
|
|
|
|
}
|
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
|
// Decompose softmax into: exp(x) / sum(exp(x))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2023-08-04 16:02:13 +08:00
|
|
|
|
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resultTensorType.hasDtype()) {
|
2021-10-16 06:23:59 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-08-04 16:02:13 +08:00
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
|
|
|
|
Type resultTensorDtype = resultTensorType.getDtype();
|
2024-04-11 21:47:35 +08:00
|
|
|
|
if (!isa<mlir::FloatType>(resultTensorDtype))
|
2023-08-04 16:02:13 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"Only support floating-point type");
|
|
|
|
|
|
|
|
|
|
// If `dtype` arg is non-none then convert the input to `dtype`.
|
|
|
|
|
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
|
self = rewriter.create<AtenToDtypeOp>(
|
|
|
|
|
loc, resultTensorType, self,
|
|
|
|
|
getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
|
|
|
|
|
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
|
|
|
|
|
}
|
2021-11-08 23:56:40 +08:00
|
|
|
|
|
2024-03-13 06:07:45 +08:00
|
|
|
|
Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
|
|
|
|
|
|
|
|
|
|
Value result = getSoftmaxResult(op, self, resultTensorType,
|
|
|
|
|
accumulatorTensorType, rewriter);
|
2021-11-25 13:49:02 +08:00
|
|
|
|
if (!result)
|
2021-11-08 23:56:40 +08:00
|
|
|
|
return failure();
|
2021-11-25 13:49:02 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
|
|
|
|
result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2021-11-25 13:49:02 +08:00
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
bool halfToFloat;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
2021-11-25 13:49:02 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a boolean value for half_to_float");
|
|
|
|
|
|
2022-11-25 13:56:37 +08:00
|
|
|
|
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!resultTensorType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
|
|
|
|
Type resultTensorDtype = resultTensorType.getDtype();
|
2022-11-25 13:56:37 +08:00
|
|
|
|
// `torch.ops.aten._softmax`'s softmax with half to float conversion is not
|
|
|
|
|
// supported on CPU, but we go ahead with the decomposing.
|
|
|
|
|
// TODO: Add an e2e test once upstream support is added.
|
|
|
|
|
// If `half_to_float` is set, we convert the input's elemental type to match
|
|
|
|
|
// that of output's.
|
|
|
|
|
if (halfToFloat) {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
|
self = rewriter.create<AtenToDtypeOp>(
|
|
|
|
|
loc, resultTensorType, self,
|
2023-01-04 06:19:18 +08:00
|
|
|
|
getDtypeIntValueForType(rewriter, loc, resultTensorDtype),
|
2022-11-25 13:56:37 +08:00
|
|
|
|
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
|
|
|
|
|
}
|
2024-03-13 06:07:45 +08:00
|
|
|
|
|
|
|
|
|
Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
|
|
|
|
|
|
|
|
|
|
Value result = getSoftmaxResult(op, self, resultTensorType,
|
|
|
|
|
accumulatorTensorType, rewriter);
|
2021-11-25 13:49:02 +08:00
|
|
|
|
if (!result)
|
|
|
|
|
return op.emitError("failed to get softmax result");
|
2022-11-25 13:56:37 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
|
2021-10-16 06:23:59 +08:00
|
|
|
|
result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-08 23:56:40 +08:00
|
|
|
|
// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
|
|
|
|
// newGrad = gradOutput * output
|
|
|
|
|
// result = newGrad - output * sum(newGrad, dim))
|
|
|
|
|
//
|
|
|
|
|
// Refer to
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_SoftmaxBackwardDataOp
|
|
|
|
|
: public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value gradOutput = op.getGradOutput();
|
|
|
|
|
Value output = op.getOutput();
|
|
|
|
|
Value dim = op.getDim();
|
2021-11-08 23:56:40 +08:00
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
|
|
Value newGrad =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
|
2021-11-19 20:18:41 +08:00
|
|
|
|
Value result = createSoftmaxBackwardCommonKernel(
|
|
|
|
|
rewriter, loc, op, tensorType, newGrad, output, newGrad, dim);
|
|
|
|
|
if (!result)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op,
|
|
|
|
|
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
|
|
|
|
rewriter.replaceOp(op, result);
|
2021-11-08 23:56:40 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-09 20:25:04 +08:00
|
|
|
|
// AtenTanhBackwardOp(gradOutput, output) =>
|
|
|
|
|
// result = gradOutput * (1 - output^2)
|
|
|
|
|
// To get away from broadcasts the above formula is expanded i.e.,
|
|
|
|
|
// result = gradOutput - (gradOutput * output^2)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenTanhBackwardOp
|
|
|
|
|
: public OpRewritePattern<AtenTanhBackwardOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value gradOutput = op.getGradOutput();
|
2021-11-09 20:25:04 +08:00
|
|
|
|
|
|
|
|
|
// `output` is the value flowing out from tanh. Hence, tanh(x) = output.
|
2021-11-19 20:18:41 +08:00
|
|
|
|
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value output = op.getOutput();
|
2021-11-09 20:25:04 +08:00
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
|
|
Value tanhSquare =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
|
|
|
|
|
Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, tensorType, tanhSquare, gradOutput);
|
|
|
|
|
|
2021-11-19 20:18:41 +08:00
|
|
|
|
Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
|
2022-02-15 21:14:32 +08:00
|
|
|
|
gradMulTanhSquare);
|
2021-11-09 20:25:04 +08:00
|
|
|
|
rewriter.replaceOp(op, newGrad);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-11-19 02:02:20 +08:00
|
|
|
|
// Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) =>
|
|
|
|
|
// result = gradOutput - (exp(output) * sum(gradOutput, dim))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_LogSoftmaxBackwardDataOp
|
|
|
|
|
: public OpRewritePattern<Aten_LogSoftmaxBackwardDataOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value gradOutput = op.getGradOutput();
|
|
|
|
|
Value output = op.getOutput();
|
|
|
|
|
Value dim = op.getDim();
|
2021-11-19 02:02:20 +08:00
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
|
|
|
|
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
2021-11-19 20:18:41 +08:00
|
|
|
|
Value result = createSoftmaxBackwardCommonKernel(
|
|
|
|
|
rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim);
|
|
|
|
|
if (!result)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op,
|
|
|
|
|
"nullptr returned by createSoftmaxBackwardCommonKernel function.");
|
|
|
|
|
rewriter.replaceOp(op, result);
|
2021-11-19 02:02:20 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-02-28 14:48:07 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Torch::AtenAminOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
llvm::SmallVector<int64_t> dimList;
|
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "dims not foldable constants");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool keepdim;
|
|
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
|
|
|
|
|
|
|
|
|
|
Value reduction = op.getSelf();
|
|
|
|
|
auto resultTy = cast<Torch::ValueTensorType>(op.getType());
|
|
|
|
|
auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
|
|
|
|
|
llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
|
|
|
|
|
|
|
|
|
|
for (auto dim : dimList) {
|
|
|
|
|
auto dimValue = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(dim));
|
|
|
|
|
reductionShape[dim] = 1;
|
|
|
|
|
if (!keepdim) {
|
|
|
|
|
for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
|
|
|
|
|
reductionShape[i] = reductionShape[i + 1];
|
|
|
|
|
reductionShape.resize(reductionShape.size() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
reductionTy = rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
|
reductionShape, resultTy.getOptionalDtype());
|
|
|
|
|
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
|
|
|
|
|
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
|
|
|
|
|
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
|
|
|
|
|
reduction = rewriter
|
|
|
|
|
.create<Torch::AtenMinDimOp>(loc, types, reduction,
|
|
|
|
|
dimValue, op.getKeepdim())
|
|
|
|
|
.getResult(0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, reduction);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
|
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into
|
|
|
|
|
// `AtenMinDimOp`
|
2022-01-25 16:53:55 +08:00
|
|
|
|
namespace {
|
2023-12-06 22:45:30 +08:00
|
|
|
|
template <typename OpTy, typename DecompOpTy>
|
|
|
|
|
class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
|
2022-01-25 16:53:55 +08:00
|
|
|
|
public:
|
2023-12-06 22:45:30 +08:00
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
2022-01-25 16:53:55 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
Value keepDim = op.getKeepdim();
|
|
|
|
|
Value result = op.getResult();
|
2022-01-25 16:53:55 +08:00
|
|
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeInputRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input tensor to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-01-25 16:53:55 +08:00
|
|
|
|
if (!indicesTensorType.hasSizes())
|
|
|
|
|
return failure();
|
|
|
|
|
BaseTensorType valueTensorType =
|
|
|
|
|
inputType
|
2023-01-04 06:19:18 +08:00
|
|
|
|
.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(),
|
|
|
|
|
inputType.getOptionalDtype())
|
2022-01-25 16:53:55 +08:00
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
|
2024-01-30 01:59:33 +08:00
|
|
|
|
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
|
|
|
|
// first the input tensor is flattened to 1d tensor and then the reduction
|
|
|
|
|
// happens on the 0th dimension.
|
2022-01-25 16:53:55 +08:00
|
|
|
|
if (dim.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
BaseTensorType flattenType =
|
2023-01-04 06:19:18 +08:00
|
|
|
|
inputType
|
|
|
|
|
.getWithSizesAndDtype({kUnknownSize},
|
|
|
|
|
inputType.getOptionalDtype())
|
2022-01-25 16:53:55 +08:00
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value end = rewriter.create<ConstantIntOp>(
|
2022-12-13 00:56:28 +08:00
|
|
|
|
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
2022-01-25 16:53:55 +08:00
|
|
|
|
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
|
|
|
|
|
dim, end);
|
|
|
|
|
}
|
|
|
|
|
|
2023-12-06 22:45:30 +08:00
|
|
|
|
Value resultArg =
|
2024-01-30 01:59:33 +08:00
|
|
|
|
rewriter
|
|
|
|
|
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType, input,
|
|
|
|
|
dim, keepDim)
|
|
|
|
|
.getIndices();
|
|
|
|
|
|
2023-12-06 22:45:30 +08:00
|
|
|
|
rewriter.replaceOp(op, resultArg);
|
2022-01-25 16:53:55 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-02-03 10:20:47 +08:00
|
|
|
|
// Decompose `aten.bucketize` into the following op sequence:
|
|
|
|
|
//
|
|
|
|
|
// def aten_bucketize(input, boundaries, out_int32, right):
|
|
|
|
|
// unsqz_input = input.unsqueeze(-1)
|
|
|
|
|
// if not right:
|
|
|
|
|
// comparison = unsqz_input <= boundaries
|
|
|
|
|
// else:
|
|
|
|
|
// comparison = unsqz_input < boundaries
|
|
|
|
|
// indices = torch.argmax(comparison.float(), dim=-1)
|
|
|
|
|
// within_bound = comparison[..., -1]
|
|
|
|
|
// result = torch.where(within_bound, indices, boundaries.shape[0])
|
|
|
|
|
// if out_int32:
|
|
|
|
|
// result = result.int()
|
|
|
|
|
// return result
|
|
|
|
|
//
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenBucketizeTensorOp
|
|
|
|
|
: public OpRewritePattern<AtenBucketizeTensorOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: input must have known sizes");
|
|
|
|
|
}
|
|
|
|
|
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
|
|
|
|
|
|
|
|
|
Value boundaries = op.getBoundaries();
|
|
|
|
|
auto boundariesType = boundaries.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"unimplemented: boundaries must have "
|
|
|
|
|
"known sizes and must be a 1D array");
|
|
|
|
|
}
|
|
|
|
|
int64_t boundariesSize = boundariesType.getSizes()[0];
|
|
|
|
|
|
|
|
|
|
bool outInt32;
|
|
|
|
|
if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: out_int32 must be a constant bool");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool right;
|
|
|
|
|
if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: right must be a constant bool");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// unsqueeze input at the last dim to make it broadcastable with boundaries
|
|
|
|
|
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(-1));
|
|
|
|
|
auto unsqzTensorInfo =
|
|
|
|
|
unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne);
|
|
|
|
|
if (failed(unsqzTensorInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"cannot generate unsqueeze tensor");
|
|
|
|
|
}
|
|
|
|
|
Value unsqzInput = *unsqzTensorInfo;
|
|
|
|
|
|
|
|
|
|
// compare unsqueezed input with boundaries
|
|
|
|
|
SmallVector<int64_t> compareShape(inputShape);
|
|
|
|
|
compareShape.push_back(boundariesSize);
|
|
|
|
|
Type compareType =
|
|
|
|
|
inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type());
|
|
|
|
|
Value compare;
|
|
|
|
|
if (!right) {
|
|
|
|
|
compare = rewriter.create<AtenLeTensorOp>(loc, compareType, unsqzInput,
|
|
|
|
|
boundaries);
|
|
|
|
|
} else {
|
|
|
|
|
compare = rewriter.create<AtenLtTensorOp>(loc, compareType, unsqzInput,
|
|
|
|
|
boundaries);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// convert the comparison results to float32 as the argmax op input,
|
|
|
|
|
// which does not support integer dtype in LINALG backend
|
|
|
|
|
Value compareF32 =
|
|
|
|
|
convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type());
|
|
|
|
|
|
|
|
|
|
// get the first boundary index where the input element is less than (or
|
|
|
|
|
// equal to) the boundary value
|
|
|
|
|
Type indicesType = inputType.getWithSizesAndDtype(
|
|
|
|
|
inputShape, rewriter.getIntegerType(64, IntegerType::Signed));
|
|
|
|
|
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value indices = rewriter.create<AtenArgmaxOp>(loc, indicesType, compareF32,
|
|
|
|
|
/*dim=*/constMinusOne,
|
|
|
|
|
/*keepdim=*/constFalse);
|
|
|
|
|
|
|
|
|
|
// get the comparison results between each input element and the rightmost
|
|
|
|
|
// boundary value
|
|
|
|
|
Type withinUpperBoundType =
|
|
|
|
|
inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type());
|
|
|
|
|
Value withinUpperBound = rewriter.create<AtenSelectIntOp>(
|
|
|
|
|
loc, withinUpperBoundType, compare, /*dim=*/constMinusOne,
|
|
|
|
|
/*index=*/constMinusOne);
|
|
|
|
|
|
|
|
|
|
// If the input element is less than (or equal to) the rightmost boundary,
|
|
|
|
|
// take the max index as result. Otherwise, the element is beyond the
|
|
|
|
|
// rightmost boundary, so take the boundary size.
|
|
|
|
|
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value upperBound =
|
|
|
|
|
rewriter.create<AtenSizeIntOp>(loc, boundaries, /*dim=*/constZero);
|
|
|
|
|
Value result = rewriter.create<AtenWhereScalarOtherOp>(
|
|
|
|
|
loc, indicesType, withinUpperBound, indices, upperBound);
|
|
|
|
|
|
|
|
|
|
if (outInt32) {
|
|
|
|
|
result = convertTensorToDtype(
|
|
|
|
|
rewriter, loc, result,
|
|
|
|
|
rewriter.getIntegerType(32, IntegerType::Signed));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-11 16:39:34 +08:00
|
|
|
|
// To avoid overflow we use the following decomposition rule:
|
|
|
|
|
// x_max = aten.max(x, dim, keepdim=True)[0]
|
|
|
|
|
// shifted = x - x_max
|
|
|
|
|
// shifted_logsumexp = aten.log(aten.sum(aten.exp(shifted), dim, keepdim=True))
|
|
|
|
|
// log_softmax = shifted - shifted_logsumexp
|
|
|
|
|
template <typename OpTy>
|
|
|
|
|
static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
Value self = op.getSelf();
|
2022-02-11 16:39:34 +08:00
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
Value xMax =
|
|
|
|
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
|
|
|
|
if (!xMax)
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
|
|
Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax);
|
|
|
|
|
Value shiftedExp = rewriter.create<AtenExpOp>(loc, tensorType, shifted);
|
|
|
|
|
Value shiftedSumExp =
|
|
|
|
|
createSumAlongDimension(rewriter, loc, op, shiftedExp, dim,
|
|
|
|
|
/*keepDim=*/true);
|
|
|
|
|
if (!shiftedSumExp)
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
|
|
|
|
Value shiftedLogSumExp =
|
|
|
|
|
rewriter.create<AtenLogOp>(loc, shiftedSumExp.getType(), shiftedSumExp);
|
|
|
|
|
Value result =
|
|
|
|
|
createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp);
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
2021-11-03 01:06:04 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLogSoftmaxIntOp
|
|
|
|
|
: public OpRewritePattern<AtenLogSoftmaxIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
if (!op.getDtype().getType().isa<Torch::NoneType>())
|
2021-11-03 01:06:04 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented non-None dtype for log_softmax");
|
|
|
|
|
|
|
|
|
|
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
|
|
|
|
|
2022-02-11 16:39:34 +08:00
|
|
|
|
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
|
|
|
|
|
if (!logSoftmax)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "getLogSoftmaxResult function returned nullptr");
|
|
|
|
|
rewriter.replaceOp(op, logSoftmax);
|
2021-11-03 01:06:04 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-10 15:05:23 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-02-11 16:39:34 +08:00
|
|
|
|
bool halfToFloat;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
2022-02-11 16:39:34 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a boolean value for half_to_float");
|
|
|
|
|
|
|
|
|
|
// Currently, setting `halfToFloat` is not supported as the E2E testing for
|
|
|
|
|
// the same is not present on CPU.
|
|
|
|
|
if (halfToFloat)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "halfToFloat is currently not supported.");
|
|
|
|
|
Value _logSoftmax = getLogSoftmaxResult(op, rewriter);
|
|
|
|
|
if (!_logSoftmax)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "getLogSoftmaxResult function returned nullptr");
|
|
|
|
|
rewriter.replaceOp(op, _logSoftmax);
|
2022-02-10 15:05:23 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
|
2021-10-21 13:15:10 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMatmulOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value lhs = op.getSelf();
|
|
|
|
|
Value rhs = op.getOther();
|
2021-10-21 13:15:10 +08:00
|
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeLhsRank = getTensorRank(lhs);
|
|
|
|
|
std::optional<unsigned> maybeRhsRank = getTensorRank(rhs);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeLhsRank || !maybeRhsRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input tensors to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned lhsRank = *maybeLhsRank;
|
|
|
|
|
unsigned rhsRank = *maybeRhsRank;
|
2021-10-21 13:15:10 +08:00
|
|
|
|
|
2022-12-09 01:26:38 +08:00
|
|
|
|
if (lhsRank == 2 && rhsRank == 2) {
|
|
|
|
|
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
|
2021-10-21 13:15:10 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
} else if (lhsRank == 3 && rhsRank == 3) {
|
|
|
|
|
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
|
2021-10-21 13:15:10 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
} else {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
2021-10-21 13:15:10 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-10-04 02:32:17 +08:00
|
|
|
|
// Decompose aten.mv into: aten.matmul.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMvOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value lhs = op.getSelf();
|
|
|
|
|
Value rhs = op.getVec();
|
2022-10-04 02:32:17 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-14 03:17:22 +08:00
|
|
|
|
// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
|
|
|
|
|
// aten.add.Tensor and aten.mull.Tensor. See
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
|
|
|
|
|
// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1):
|
|
|
|
|
// broadcast_shape = compute_broadcast_shape(self, other)
|
|
|
|
|
// a = torch.broadcast_to(self, broadcast_shape)
|
|
|
|
|
// b = torch.broadcast_to(other, broadcast_shape)
|
|
|
|
|
// idx = torch.arange(3)
|
|
|
|
|
// return a.index_select(dim, (idx + 1) % 3) *
|
|
|
|
|
// b.index_select(dim, (idx + 2) % 3) -
|
|
|
|
|
// a.index_select(dim, (idx + 2) % 3) *
|
|
|
|
|
// b.index_select(dim, (idx + 1) % 3)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLinalgCrossOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value other = op.getOther();
|
|
|
|
|
Type opType = op.getType();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
|
|
|
|
|
auto resType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type dtype = resType.getDtype();
|
2024-04-11 21:47:35 +08:00
|
|
|
|
if (isa<mlir::ComplexType>(dtype)) {
|
2024-03-14 03:17:22 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "lowering of aten.linalg_cross for complex inputs dtype is "
|
|
|
|
|
"currently unimplemented");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// calculate common shape for broadcast
|
|
|
|
|
SmallVector<int64_t> broadcastShape;
|
|
|
|
|
SmallVector<Value> broadcastShapeValue;
|
|
|
|
|
computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
|
|
|
|
|
broadcastShapeValue);
|
|
|
|
|
|
|
|
|
|
Type broadcastType = ValueTensorType::get(
|
|
|
|
|
op.getContext(), llvm::ArrayRef(broadcastShape), dtype);
|
|
|
|
|
|
|
|
|
|
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
broadcastShapeValue);
|
|
|
|
|
|
|
|
|
|
// broadcast tensors to common shape
|
|
|
|
|
auto a = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, self,
|
|
|
|
|
indexBroadcastShapeTorchList);
|
|
|
|
|
auto b = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, other,
|
|
|
|
|
indexBroadcastShapeTorchList);
|
|
|
|
|
|
|
|
|
|
// create constants
|
|
|
|
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value constTwo = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(2));
|
|
|
|
|
Value constThree = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(3));
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
|
|
|
|
|
// idx = torch.arange(3)
|
2024-04-11 21:47:35 +08:00
|
|
|
|
auto outType = dyn_cast<BaseTensorType>(opType);
|
2024-03-14 03:17:22 +08:00
|
|
|
|
auto arangeType = outType.getWithSizesAndDtype(
|
|
|
|
|
llvm::ArrayRef<int64_t>(3),
|
|
|
|
|
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
|
|
|
|
|
auto idx = rewriter.create<AtenArangeOp>(
|
|
|
|
|
loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none);
|
|
|
|
|
|
|
|
|
|
// (idx + 1) and (idx + 2)
|
|
|
|
|
auto idxPlusOne = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
|
|
|
|
|
constOne, constOne);
|
|
|
|
|
auto idxPlusTwo = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
|
|
|
|
|
constTwo, constOne);
|
|
|
|
|
|
|
|
|
|
// (idx + 1) % 3 and (idx + 2) % 3
|
|
|
|
|
auto idxPlusOneRemainderThree = rewriter.create<AtenRemainderScalarOp>(
|
|
|
|
|
loc, arangeType, idxPlusOne, constThree);
|
|
|
|
|
auto idxPlusTwoRemainderThree = rewriter.create<AtenRemainderScalarOp>(
|
|
|
|
|
loc, arangeType, idxPlusTwo, constThree);
|
|
|
|
|
|
|
|
|
|
// a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3)
|
|
|
|
|
auto idxSelectAPlusOne = rewriter.create<AtenIndexSelectOp>(
|
|
|
|
|
loc, opType, a, dim, idxPlusOneRemainderThree);
|
|
|
|
|
auto idxSelectBPlusTwo = rewriter.create<AtenIndexSelectOp>(
|
|
|
|
|
loc, opType, b, dim, idxPlusTwoRemainderThree);
|
|
|
|
|
auto firstMul = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo);
|
|
|
|
|
|
|
|
|
|
// a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
|
|
|
|
|
auto idxSelectAPlusTwo = rewriter.create<AtenIndexSelectOp>(
|
|
|
|
|
loc, opType, a, dim, idxPlusTwoRemainderThree);
|
|
|
|
|
auto idxSelectBPlusOne = rewriter.create<AtenIndexSelectOp>(
|
|
|
|
|
loc, opType, b, dim, idxPlusOneRemainderThree);
|
|
|
|
|
auto secondMul = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne);
|
|
|
|
|
|
|
|
|
|
// subtract the results of the two multiplications from above
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, opType, firstMul,
|
|
|
|
|
secondMul, constOne);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
|
|
|
|
|
// prims.collapse operations.
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
//
|
2023-11-23 04:31:06 +08:00
|
|
|
|
// If input is a tensor of shape
|
|
|
|
|
// (*leading_dims, C*r*r, H, W),
|
|
|
|
|
//
|
|
|
|
|
// where leading_dims is of size N, then
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
// X = pixel_shuffle(input, upscale_factor)
|
|
|
|
|
//
|
|
|
|
|
// gets replaced with
|
2023-11-23 04:31:06 +08:00
|
|
|
|
// X = input.split_dim(...) # shape (*leading_dims, C, r*r, H, W)
|
|
|
|
|
// X = X.split_dim(...) # shape (*leading_dims, C, r, r, H, W)
|
|
|
|
|
// X = X.permute(0, ..., N, N+3, N+1, N+4, N+2)
|
|
|
|
|
// # shape (*leading_dims, C, H, r, W, r)
|
|
|
|
|
// X = X.collapse(...) # shape (*leading_dims, C, r, H, r*W)
|
|
|
|
|
// X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W)
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
//
|
|
|
|
|
// 'r' above is referred to as the 'upscale factor' or just 'factor' below.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenPixelShuffleOp
|
|
|
|
|
: public OpRewritePattern<AtenPixelShuffleOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value inValue = op.getSelf();
|
|
|
|
|
auto inType = inValue.getType().cast<BaseTensorType>();
|
|
|
|
|
auto maybeSizes = inType.getOptionalSizes();
|
|
|
|
|
if (!maybeSizes) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected input tensor to have known rank.");
|
|
|
|
|
}
|
|
|
|
|
auto inShape = maybeSizes.value();
|
|
|
|
|
auto inRank = inShape.size();
|
|
|
|
|
|
|
|
|
|
// The input tensor must have at least 3 dimensions: (1) the channel
|
|
|
|
|
// dimension which gets smaller by 'factor*factor', (2) the H channel which
|
|
|
|
|
// gets larger by 'factor' and (3) the W channel which get larger by
|
|
|
|
|
// 'factor'. The total number of dimensions is 3 + N, where N is the number
|
|
|
|
|
// of leading dimensions, and N >= 0 so the input must have rank at least 3.
|
|
|
|
|
if (inRank < 3)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected input tensor to have rank greater than 2.");
|
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
const auto inOptionalDType = inType.getOptionalDtype();
|
|
|
|
|
|
|
|
|
|
auto getTypeFromShape = [inOptionalDType](auto &&vals) {
|
|
|
|
|
// Get a vector of integers from a vector of Values.
|
|
|
|
|
auto getIntShape = [](auto &&vals) {
|
|
|
|
|
SmallVector<int64_t> shape;
|
|
|
|
|
shape.reserve(vals.size());
|
|
|
|
|
for (auto v : vals) {
|
|
|
|
|
int64_t cst_val;
|
|
|
|
|
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
|
|
|
|
|
shape.push_back(cst_val);
|
|
|
|
|
} else {
|
|
|
|
|
shape.push_back(kUnknownSize);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return shape;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
const auto intShape = getIntShape(vals);
|
|
|
|
|
return ValueTensorType::get(vals[0].getContext(),
|
|
|
|
|
llvm::ArrayRef(intShape), inOptionalDType);
|
|
|
|
|
};
|
|
|
|
|
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
auto nLeadingDims = inRank - 3;
|
|
|
|
|
|
|
|
|
|
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
|
|
|
|
|
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
|
|
|
|
|
// folded to a ConstantOp.
|
|
|
|
|
auto getDimSize = [&](uint64_t i) -> Value {
|
|
|
|
|
Value dim =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
|
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto inC = getDimSize(inRank - 3);
|
|
|
|
|
auto inH = getDimSize(inRank - 2);
|
|
|
|
|
auto inW = getDimSize(inRank - 1);
|
|
|
|
|
|
|
|
|
|
auto factor = op.getUpscaleFactor();
|
|
|
|
|
|
|
|
|
|
Value factorSquared =
|
|
|
|
|
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
|
2023-11-23 04:31:06 +08:00
|
|
|
|
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
Value outC =
|
|
|
|
|
rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);
|
|
|
|
|
|
|
|
|
|
Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
|
|
|
|
|
Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);
|
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> dimensionConstants;
|
|
|
|
|
dimensionConstants.reserve(inRank + 2);
|
|
|
|
|
for (unsigned i = 0; i < inRank + 2; ++i) {
|
|
|
|
|
dimensionConstants.push_back(
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
}
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> leadingDims;
|
|
|
|
|
leadingDims.reserve(nLeadingDims);
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
for (unsigned i = 0; i < nLeadingDims; ++i) {
|
2023-11-23 04:31:06 +08:00
|
|
|
|
Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
|
|
|
|
|
loc, inValue, dimensionConstants[i]);
|
|
|
|
|
leadingDims.push_back(leadingDimSize);
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
}
|
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> partiallyExpandedShape = leadingDims;
|
|
|
|
|
partiallyExpandedShape.append({outC, factorSquared, inH, inW});
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> prePermuteShape = leadingDims;
|
|
|
|
|
prePermuteShape.append({outC, factor, factor, inH, inW});
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> postPermuteShape = leadingDims;
|
|
|
|
|
postPermuteShape.append({outC, inH, factor, inW, factor});
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> partiallyCollapsedShape = leadingDims;
|
|
|
|
|
partiallyCollapsedShape.append({outC, inH, factor, outW});
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> outShape = leadingDims;
|
|
|
|
|
outShape.append({outC, outH, outW});
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
SmallVector<Value> permutation{dimensionConstants.begin(),
|
|
|
|
|
dimensionConstants.begin() + nLeadingDims};
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
|
|
|
|
|
for (uint64_t d : permutationTail) {
|
2023-11-23 04:31:06 +08:00
|
|
|
|
permutation.push_back(dimensionConstants[nLeadingDims + d]);
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
|
|
|
|
permutation);
|
|
|
|
|
|
2023-11-23 04:31:06 +08:00
|
|
|
|
// Split input channel inC -> (inC, factorSquared)
|
|
|
|
|
auto partiallyExpanded =
|
|
|
|
|
rewriter
|
|
|
|
|
.create<PrimsSplitDimOp>(
|
|
|
|
|
loc, getTypeFromShape(partiallyExpandedShape), inValue,
|
|
|
|
|
dimensionConstants[nLeadingDims], outC)
|
|
|
|
|
.getResult();
|
|
|
|
|
|
|
|
|
|
// Split new dimension factorSquared -> (factor, factor)
|
|
|
|
|
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
|
|
|
|
|
loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
|
|
|
|
|
dimensionConstants[nLeadingDims + 1], factor);
|
|
|
|
|
|
|
|
|
|
// Perform the permutation
|
|
|
|
|
auto permuted =
|
|
|
|
|
rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
|
|
|
|
|
fullyExpanded, permuteDimsOrder);
|
|
|
|
|
|
|
|
|
|
// Collapse final 2 dimension
|
|
|
|
|
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
|
|
|
|
|
loc, getTypeFromShape(partiallyCollapsedShape), permuted,
|
|
|
|
|
dimensionConstants[nLeadingDims + 3],
|
|
|
|
|
dimensionConstants[nLeadingDims + 4]);
|
|
|
|
|
|
|
|
|
|
// Collapse back to original rank
|
|
|
|
|
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
|
|
|
|
|
op, op.getType(), partiallyCollapsed,
|
|
|
|
|
dimensionConstants[nLeadingDims + 1],
|
|
|
|
|
dimensionConstants[nLeadingDims + 2]);
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
2023-11-23 04:31:06 +08:00
|
|
|
|
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
|
|
|
|
Value input) {
|
2022-02-15 21:14:32 +08:00
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
2022-02-09 04:57:23 +08:00
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value cst6 =
|
2022-02-15 21:14:32 +08:00
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(6));
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6);
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value relu6Out =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, inputType, relu, sixTensor);
|
|
|
|
|
return relu6Out;
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-23 20:39:15 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRelu6Op : public OpRewritePattern<AtenRelu6Op> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value relu6 = getRelu6Results(rewriter, loc, op.getSelf());
|
2022-09-23 20:39:15 +08:00
|
|
|
|
rewriter.replaceOp(op, relu6);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
// Hardswish(x) = x * Relu6(x+3)/6
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenHardswishOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Type inputType = input.getType();
|
|
|
|
|
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(3));
|
|
|
|
|
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(6));
|
|
|
|
|
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
|
|
|
|
Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree);
|
|
|
|
|
Value divTensor =
|
|
|
|
|
rewriter.create<AtenDivScalarOp>(loc, inputType, relu6, constantSix);
|
|
|
|
|
Value mulTensor =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, inputType, divTensor, input);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, mulTensor);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-01-04 00:30:16 +08:00
|
|
|
|
// LeakyRelu = max(0,x) + negative_slope * min(0,x)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLeakyReluOp : public OpRewritePattern<AtenLeakyReluOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLeakyReluOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value negativeSlope = op.getNegativeSlope();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2023-01-04 00:30:16 +08:00
|
|
|
|
|
|
|
|
|
Value constantZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constantOne =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
|
|
|
|
Value positiveOutput =
|
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
|
|
|
|
|
Value negativeOutput =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
|
|
|
|
|
Value scaledNegativeOutput = rewriter.create<AtenMulScalarOp>(
|
|
|
|
|
loc, resType, negativeOutput, negativeSlope);
|
|
|
|
|
Value leakyReluOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, resType, positiveOutput, scaledNegativeOutput, constantOne);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, leakyReluOutput);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// LeakyReluBackward = max(0,grad) + negative_slope * min(0,x)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLeakyReluBackwardOp
|
|
|
|
|
: public OpRewritePattern<AtenLeakyReluBackwardOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLeakyReluBackwardOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value gradOutput = op.getGradOutput();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value negativeSlope = op.getNegativeSlope();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2023-01-04 00:30:16 +08:00
|
|
|
|
|
|
|
|
|
bool selfIsResult = false;
|
|
|
|
|
if (!matchPattern(op.getSelfIsResult(),
|
|
|
|
|
m_TorchConstantBool(&selfIsResult)) ||
|
|
|
|
|
selfIsResult)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: self_is_result should be false");
|
|
|
|
|
|
|
|
|
|
Value constantZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constantOne =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
|
|
|
|
Value positiveOutput =
|
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, gradOutput);
|
|
|
|
|
Value negativeOutput =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
|
|
|
|
|
Value scaledNegativeOutput = rewriter.create<AtenMulScalarOp>(
|
|
|
|
|
loc, resType, negativeOutput, negativeSlope);
|
|
|
|
|
Value leakyReluBackwardOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, resType, positiveOutput, scaledNegativeOutput, constantOne);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, leakyReluBackwardOutput);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-29 08:05:00 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenPreluOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value weight = op.getWeight();
|
2024-04-03 02:41:56 +08:00
|
|
|
|
auto resType = op.getType().cast<ValueTensorType>();
|
|
|
|
|
auto boolTensorType = rewriter.getType<ValueTensorType>(
|
|
|
|
|
resType.getOptionalSizes(), rewriter.getI1Type());
|
2024-03-29 08:05:00 +08:00
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
|
|
|
|
Value inputMulWeight =
|
2024-04-03 02:41:56 +08:00
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, resType, input, weight);
|
2024-03-29 08:05:00 +08:00
|
|
|
|
Value lessThanZero =
|
2024-04-03 02:41:56 +08:00
|
|
|
|
rewriter.create<AtenLtScalarOp>(loc, boolTensorType, input, zero);
|
2024-03-29 08:05:00 +08:00
|
|
|
|
Value preluOutput = rewriter.create<AtenWhereSelfOp>(
|
|
|
|
|
loc, resType, lessThanZero, inputMulWeight, input);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, preluOutput);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-02-01 01:39:38 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLerpScalarOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
|
|
|
|
Value cstOne =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
auto start = op.getSelf();
|
|
|
|
|
auto inputType = start.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
|
|
|
|
|
start, cstOne);
|
|
|
|
|
|
|
|
|
|
auto weightedDelta =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight());
|
|
|
|
|
auto lerp = rewriter.create<AtenAddTensorOp>(loc, inputType, start,
|
|
|
|
|
weightedDelta, cstOne);
|
|
|
|
|
rewriter.replaceOp(op, lerp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-08-24 00:04:14 +08:00
|
|
|
|
// Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenEluOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value alpha = op.getAlpha();
|
|
|
|
|
Value scale = op.getScale();
|
|
|
|
|
Value inputScale = op.getInputScale();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value constantZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constantOne =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value maxZeroX =
|
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
|
|
|
|
|
Value positiveOutput =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
|
|
|
|
|
Value minZeroX =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
|
|
|
|
|
Value scaledMinZeroX =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
|
2023-08-24 00:04:14 +08:00
|
|
|
|
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
|
|
|
|
|
constantOne, constantOne);
|
|
|
|
|
Value scaledExpXM1 =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
|
|
|
|
|
Value negativeOutput =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
|
2023-08-24 00:04:14 +08:00
|
|
|
|
|
|
|
|
|
Value eluOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, resType, positiveOutput, negativeOutput, constantOne);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, eluOutput);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-12-14 12:28:08 +08:00
|
|
|
|
// Selu = scale * (max(0,x) + min(0,alpha * (exp(x) − 1)))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSeluOp : public OpRewritePattern<AtenSeluOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSeluOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Define λ and α
|
|
|
|
|
double scale = 1.0507009873554804934193349852946;
|
|
|
|
|
double alpha = 1.6732632423543772848170429916717;
|
2024-01-30 01:59:33 +08:00
|
|
|
|
|
2023-12-14 12:28:08 +08:00
|
|
|
|
// Create constants for λ and α
|
2024-01-30 01:59:33 +08:00
|
|
|
|
Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr(scale));
|
|
|
|
|
Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr(alpha));
|
2023-12-14 12:28:08 +08:00
|
|
|
|
|
|
|
|
|
// Create zero tensor for comparison
|
|
|
|
|
Value constantZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
|
|
|
|
|
|
|
|
|
|
// Calculate positive and negative parts
|
|
|
|
|
Value constantOne =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
2024-01-30 01:59:33 +08:00
|
|
|
|
Value positiveOutput =
|
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
|
2023-12-14 12:28:08 +08:00
|
|
|
|
Value minZeroX =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
|
|
|
|
|
Value expInput = rewriter.create<AtenExpOp>(loc, resType, minZeroX);
|
2024-01-30 01:59:33 +08:00
|
|
|
|
Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(
|
|
|
|
|
loc, resType, expInput, constantOne, constantOne);
|
|
|
|
|
Value negativeOutput = rewriter.create<AtenMulScalarOp>(
|
|
|
|
|
loc, resType, expInputMinusOne, alphaVal);
|
2023-12-14 12:28:08 +08:00
|
|
|
|
|
|
|
|
|
// Multiply the result by λ
|
|
|
|
|
Value seluOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, resType, positiveOutput, negativeOutput, constantOne);
|
2024-01-30 01:59:33 +08:00
|
|
|
|
seluOutput =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);
|
2023-12-14 12:28:08 +08:00
|
|
|
|
|
|
|
|
|
// Replace the original operation
|
|
|
|
|
rewriter.replaceOp(op, seluOutput);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-17 12:08:07 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value lhs = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> lhsRank = getTensorRank(lhs);
|
2021-12-17 12:08:07 +08:00
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!lhsRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
} else if (*lhsRank > 2) {
|
2021-12-17 12:08:07 +08:00
|
|
|
|
std::string errorMessage =
|
|
|
|
|
"t() expects a tensor with <=2 dimensions, but self is " +
|
2022-12-13 00:56:28 +08:00
|
|
|
|
std::to_string(*lhsRank) + "D";
|
2021-12-17 12:08:07 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
|
2022-12-13 00:56:28 +08:00
|
|
|
|
} else if (*lhsRank < 2)
|
2021-12-17 12:08:07 +08:00
|
|
|
|
rewriter.replaceOp(op, lhs);
|
|
|
|
|
else {
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
|
|
|
|
|
zero, one);
|
|
|
|
|
}
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-03-11 09:25:25 +08:00
|
|
|
|
// Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenStackOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
SmallVector<Value> tensors;
|
|
|
|
|
if (!getListConstructElements(op.getTensors(), tensors)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: the tensor list is not from list construct");
|
|
|
|
|
}
|
|
|
|
|
// Ensure all tensors have known sizes
|
|
|
|
|
for (Value tensor : tensors) {
|
|
|
|
|
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: one tensor does not have known sizes");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> unsqueezedTensors;
|
|
|
|
|
for (Value tensor : tensors) {
|
|
|
|
|
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, tensor, op.getDim());
|
|
|
|
|
if (failed(unsqueezedInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "cannot generate unsqueeze tensor op");
|
|
|
|
|
}
|
|
|
|
|
unsqueezedTensors.push_back(*unsqueezedInfo);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type listElemType =
|
|
|
|
|
op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
|
|
|
|
|
/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
|
|
|
|
|
Type listType = Torch::ListType::get(listElemType);
|
|
|
|
|
Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), listType, unsqueezedTensors);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(),
|
|
|
|
|
unsqueezedTensorList, op.getDim());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-24 08:36:05 +08:00
|
|
|
|
// Decompose aten.roll into aten.slice and aten.cat ops.
|
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.roll.html
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRollOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
SmallVector<Value> shifts;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!getListConstructElements(op.getShifts(), shifts))
|
2022-08-24 08:36:05 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: shifts not list of Scalar");
|
|
|
|
|
SmallVector<Value> dims;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!getListConstructElements(op.getDims(), dims))
|
2022-08-24 08:36:05 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: dims not list of Scalar");
|
|
|
|
|
|
|
|
|
|
if (shifts.size() != dims.size())
|
|
|
|
|
return op.emitError("list sizes of shifts and dims are not the same");
|
|
|
|
|
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
2022-12-08 04:20:41 +08:00
|
|
|
|
auto self = op.getSelf();
|
2022-08-24 08:36:05 +08:00
|
|
|
|
auto selfTy = self.getType().cast<BaseTensorType>();
|
|
|
|
|
// roll(input, shift, dim) = cat({
|
|
|
|
|
// slice(input, dim, -shift, none),
|
|
|
|
|
// slice(input, dim, 0, -shift)}, dim)
|
|
|
|
|
auto imitateRoll = [&](Value input, Value shift, Value dim,
|
|
|
|
|
int64_t cstDim) {
|
|
|
|
|
Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
|
|
|
|
|
ArrayRef<int64_t> inputShape = selfTy.getSizes();
|
|
|
|
|
SmallVector<int64_t> sizes;
|
|
|
|
|
sizes.append(inputShape.begin(), inputShape.end());
|
2022-11-29 20:33:31 +08:00
|
|
|
|
sizes[cstDim] = kUnknownSize;
|
2023-01-25 09:29:42 +08:00
|
|
|
|
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes),
|
2023-01-04 06:19:18 +08:00
|
|
|
|
selfTy.getOptionalDtype());
|
2022-08-24 08:36:05 +08:00
|
|
|
|
Value slice0 = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
|
loc, sliceTy, input, dim, negShift, constNone, constOne);
|
|
|
|
|
Value slice1 = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
|
loc, sliceTy, input, dim, constZero, negShift, constOne);
|
|
|
|
|
|
|
|
|
|
Type listType = Torch::ListType::get(sliceTy);
|
|
|
|
|
Value slices = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
|
|
|
|
|
return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
|
|
|
|
|
};
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeRank)
|
2022-08-24 08:36:05 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
|
unsigned rank = *maybeRank;
|
2022-08-24 08:36:05 +08:00
|
|
|
|
Value output = self;
|
|
|
|
|
auto nShifts = shifts.size();
|
|
|
|
|
for (size_t k = 0; k < nShifts; ++k) {
|
|
|
|
|
auto dim = dims[k];
|
|
|
|
|
int64_t cstDim = -1;
|
|
|
|
|
if (!matchPattern(dim, m_TorchConstantInt(&cstDim)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: dim must be constant");
|
|
|
|
|
|
|
|
|
|
cstDim = toPositiveDim(cstDim, rank);
|
|
|
|
|
output = imitateRoll(output, shifts[k], dim, cstDim);
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, output);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
// Decompose aten.repeat into aten.squeeze, aten.unsqueeze, and aten.broadcast.
|
2022-07-01 13:02:31 +08:00
|
|
|
|
//
|
|
|
|
|
// Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRepeatOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2024-03-07 02:19:18 +08:00
|
|
|
|
auto selfTy = cast<BaseTensorType>(self.getType());
|
|
|
|
|
if (!selfTy.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: no implementation for rankless tensor");
|
2022-07-01 13:02:31 +08:00
|
|
|
|
|
|
|
|
|
SmallVector<Value> repeats;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!getListConstructElements(op.getRepeats(), repeats))
|
2022-07-01 13:02:31 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: repeats not list of Scalar");
|
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
int64_t rank = selfTy.getSizes().size();
|
|
|
|
|
if (rank > static_cast<int64_t>(repeats.size())) {
|
2022-07-01 13:02:31 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "repeats are not matched with self's rank");
|
|
|
|
|
}
|
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
int64_t repeatSz = repeats.size();
|
|
|
|
|
int64_t batch = repeatSz - rank;
|
2022-07-01 13:02:31 +08:00
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
if (!selfTy.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "input sizes unknown");
|
|
|
|
|
|
|
|
|
|
// Materialize out 1 dimensions to broadcast along. This includes
|
|
|
|
|
// materializing out preceding batch dimensions:
|
|
|
|
|
for (int i = 0; i < repeatSz; ++i) {
|
|
|
|
|
auto oldSizes = selfTy.getSizes();
|
|
|
|
|
llvm::SmallVector<int64_t> sizes;
|
|
|
|
|
int64_t squeezeDim = i < batch ? i : i * 2 - batch;
|
2022-07-01 13:02:31 +08:00
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
for (int j = 0; j < squeezeDim; ++j)
|
|
|
|
|
sizes.push_back(oldSizes[j]);
|
|
|
|
|
sizes.push_back(1);
|
|
|
|
|
for (int j = squeezeDim, s = oldSizes.size(); j < s; j++)
|
|
|
|
|
sizes.push_back(oldSizes[j]);
|
|
|
|
|
|
|
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(loc, squeezeDim);
|
|
|
|
|
selfTy =
|
|
|
|
|
rewriter.getType<ValueTensorType>(sizes, selfTy.getOptionalDtype());
|
|
|
|
|
self = rewriter.create<AtenUnsqueezeOp>(loc, selfTy, self, dim);
|
2022-07-01 13:02:31 +08:00
|
|
|
|
}
|
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
llvm::SmallVector<Value> lengths;
|
|
|
|
|
for (int i = 0; i < repeatSz; ++i) {
|
|
|
|
|
if (i < batch) {
|
|
|
|
|
lengths.push_back(repeats[i]);
|
|
|
|
|
continue;
|
2022-07-01 13:02:31 +08:00
|
|
|
|
}
|
|
|
|
|
|
2024-03-07 02:19:18 +08:00
|
|
|
|
Value iv = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch));
|
|
|
|
|
Value dim = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/iv);
|
|
|
|
|
lengths.push_back(repeats[i]);
|
|
|
|
|
lengths.push_back(dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value lengthv = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(rewriter.getType<IntType>()), lengths);
|
|
|
|
|
|
|
|
|
|
llvm::SmallVector<int64_t> expandShape(selfTy.getSizes());
|
|
|
|
|
for (int i = 0; i < repeatSz; ++i) {
|
|
|
|
|
int64_t repeatDim = i < batch ? i : i * 2 - batch;
|
|
|
|
|
int64_t repeat;
|
|
|
|
|
if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat)))
|
|
|
|
|
repeat = Torch::kUnknownSize;
|
|
|
|
|
expandShape[repeatDim] = repeat;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto mulDim = [](int64_t lhs, int64_t rhs) {
|
|
|
|
|
if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize)
|
|
|
|
|
return Torch::kUnknownSize;
|
|
|
|
|
return lhs * rhs;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
|
|
|
|
|
expandShape, selfTy.getOptionalDtype());
|
|
|
|
|
Value expand =
|
|
|
|
|
rewriter.create<AtenBroadcastToOp>(loc, expandTy, self, lengthv);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < rank; ++i) {
|
|
|
|
|
auto oldShape = expandTy.getSizes();
|
|
|
|
|
llvm::SmallVector<int64_t> newShape;
|
|
|
|
|
int64_t flattenDim = i + batch;
|
|
|
|
|
for (int j = 0; j < flattenDim; ++j)
|
|
|
|
|
newShape.push_back(oldShape[j]);
|
|
|
|
|
newShape.push_back(
|
|
|
|
|
mulDim(oldShape[flattenDim], oldShape[flattenDim + 1]));
|
|
|
|
|
for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j)
|
|
|
|
|
newShape.push_back(oldShape[j]);
|
|
|
|
|
|
|
|
|
|
expandTy = rewriter.getType<ValueTensorType>(newShape,
|
|
|
|
|
expandTy.getOptionalDtype());
|
|
|
|
|
|
|
|
|
|
// Used to keep the return type the same on the last flatten:
|
|
|
|
|
expandTy = i < rank - 1 ? expandTy : cast<BaseTensorType>(op.getType());
|
|
|
|
|
|
|
|
|
|
Value start = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(flattenDim));
|
|
|
|
|
Value end = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(flattenDim + 1));
|
|
|
|
|
expand = rewriter.create<AtenFlattenUsingIntsOp>(loc, expandTy, expand,
|
|
|
|
|
start, end);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, expand);
|
2022-07-01 13:02:31 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-23 11:52:54 +08:00
|
|
|
|
// Decompose aten.flatten.using_ints into aten.view op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenFlattenUsingIntsOp
|
|
|
|
|
: public OpRewritePattern<AtenFlattenUsingIntsOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-08-23 11:52:54 +08:00
|
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeRank)
|
2022-08-23 11:52:54 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
2022-12-13 00:56:28 +08:00
|
|
|
|
unsigned rank = *maybeRank;
|
2022-08-23 11:52:54 +08:00
|
|
|
|
|
|
|
|
|
int64_t start, end;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&start)) ||
|
|
|
|
|
!matchPattern(op.getEndDim(), m_TorchConstantInt(&end))) {
|
2022-08-23 11:52:54 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: requires start and end dims to be constants");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<Value, 4> newSizes;
|
|
|
|
|
if (rank == 0) {
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
newSizes.push_back(one);
|
|
|
|
|
} else {
|
|
|
|
|
start = toPositiveDim(start, rank);
|
|
|
|
|
end = toPositiveDim(end, rank);
|
|
|
|
|
|
|
|
|
|
if (start > end) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected end dim larger than start dim");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
newSizes.reserve(rank - end + start);
|
2022-08-26 06:00:01 +08:00
|
|
|
|
for (int64_t k = 0; k < start; ++k) {
|
2022-08-23 11:52:54 +08:00
|
|
|
|
Value dim =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(k));
|
|
|
|
|
newSizes.push_back(
|
|
|
|
|
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
|
|
|
|
|
}
|
2022-11-24 21:02:59 +08:00
|
|
|
|
Value flattenDimSize =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
2022-08-23 11:52:54 +08:00
|
|
|
|
newSizes.push_back(flattenDimSize);
|
2022-08-26 06:00:01 +08:00
|
|
|
|
for (int64_t k = end + 1; k < rank; ++k) {
|
2022-08-23 11:52:54 +08:00
|
|
|
|
Value dim =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(k));
|
|
|
|
|
newSizes.push_back(
|
|
|
|
|
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Value newSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), newSizes);
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
2022-08-23 11:52:54 +08:00
|
|
|
|
newSizeList);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-10-31 15:36:16 +08:00
|
|
|
|
// Decompose aten.unflatten.int into aten.view op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenUnflattenIntOp
|
|
|
|
|
: public OpRewritePattern<AtenUnflattenIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenUnflattenIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!outputTensorType.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: output must have known sizes");
|
|
|
|
|
|
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
|
|
|
|
if (!maybeRank)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
|
|
|
|
unsigned inputRank = *maybeRank;
|
|
|
|
|
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
|
|
|
|
|
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
2023-11-23 04:31:06 +08:00
|
|
|
|
"Expected input type having sizes");
|
2023-10-31 15:36:16 +08:00
|
|
|
|
}
|
|
|
|
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> sizesInts;
|
|
|
|
|
if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizesInts)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "sizes must be a list of constant ints");
|
|
|
|
|
|
|
|
|
|
bool inferred = false;
|
|
|
|
|
if (llvm::count(sizesInts, -1) > 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only one of sizes' elements can be -1");
|
|
|
|
|
|
|
|
|
|
int64_t dimInt;
|
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: requires dim to be constants");
|
|
|
|
|
|
|
|
|
|
dimInt = toPositiveDim(dimInt, inputRank);
|
|
|
|
|
if (!isValidDim(dimInt, inputRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> sizesTorchInt;
|
|
|
|
|
if (!getListConstructElements(op.getSizes(), sizesTorchInt))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: sizes not list of Scalar");
|
|
|
|
|
|
|
|
|
|
// Create new sizes based on the unflattened dimension.
|
|
|
|
|
SmallVector<Value> newSizes;
|
|
|
|
|
for (int64_t i = 0; i < inputRank; ++i) {
|
|
|
|
|
Value dimValue =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
|
Value dimSize =
|
|
|
|
|
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dimValue);
|
|
|
|
|
if (i == dimInt) {
|
2023-11-23 04:31:06 +08:00
|
|
|
|
int64_t inferredSizeInt = inputShape[i];
|
2023-10-31 15:36:16 +08:00
|
|
|
|
int64_t inferredDim;
|
|
|
|
|
for (unsigned j = 0; j < sizesInts.size(); ++j) {
|
|
|
|
|
if (sizesInts[j] == -1) {
|
|
|
|
|
inferred = true;
|
|
|
|
|
inferredDim = j;
|
|
|
|
|
} else {
|
|
|
|
|
Value sizeValue = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(sizesInts[j]));
|
|
|
|
|
newSizes.push_back(sizeValue);
|
|
|
|
|
inferredSizeInt = inferredSizeInt / sizesInts[j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (inferred) {
|
2023-11-23 04:31:06 +08:00
|
|
|
|
Value inferredSize = rewriter.create<ConstantIntOp>(
|
2023-10-31 15:36:16 +08:00
|
|
|
|
loc, rewriter.getI64IntegerAttr(inferredSizeInt));
|
2023-11-23 04:31:06 +08:00
|
|
|
|
newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize);
|
2023-10-31 15:36:16 +08:00
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
newSizes.push_back(dimSize);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create the AtenViewOp to replace the original op.
|
|
|
|
|
Value newSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), newSizes);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
newSizeList);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.expand into aten.broadcast_to op.
|
2021-11-03 00:48:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenExpandOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
bool implicit = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getImplicit(), m_TorchConstantBool(&implicit)) ||
|
2021-11-03 00:48:29 +08:00
|
|
|
|
implicit) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: requires implicit to be false");
|
|
|
|
|
}
|
2023-11-01 11:56:54 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
|
|
|
|
|
op.getSelf(), op.getSize());
|
2021-11-03 00:48:29 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-12 01:21:36 +08:00
|
|
|
|
// Decompose aten.where.Scalar into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenWhereScalarOp : public OpRewritePattern<AtenWhereScalarOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value otherTensor =
|
|
|
|
|
createRank0Tensor(rewriter, loc, resType, op.getOther());
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
2022-03-12 01:21:36 +08:00
|
|
|
|
selfTensor, otherTensor);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Decompose aten.where.ScalarOther into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenWhereScalarOtherOp
|
|
|
|
|
: public OpRewritePattern<AtenWhereScalarOtherOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value otherTensor =
|
|
|
|
|
createRank0Tensor(rewriter, loc, resType, op.getOther());
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
|
|
|
|
op.getSelf(), otherTensor);
|
2022-03-12 01:21:36 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Decompose aten.where.ScalarSelf into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenWhereScalarSelfOp
|
|
|
|
|
: public OpRewritePattern<AtenWhereScalarSelfOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
|
|
|
|
selfTensor, op.getOther());
|
2022-03-12 01:21:36 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-01-16 14:29:34 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenNanToNumOp : public OpRewritePattern<AtenNanToNumOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNanToNumOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
mlir::FloatType f64Type = rewriter.getF64Type();
|
|
|
|
|
Value nan = op.getNan();
|
|
|
|
|
Value posinf = op.getPosinf();
|
|
|
|
|
Value neginf = op.getNeginf();
|
|
|
|
|
auto baseType =
|
|
|
|
|
ValueTensorType::getWithLeastStaticInformation(op.getContext());
|
|
|
|
|
if (dyn_cast_or_null<ConstantNoneOp>(nan.getDefiningOp()))
|
|
|
|
|
nan = rewriter.create<ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getFloatAttr(
|
|
|
|
|
f64Type, APFloat::getZero(f64Type.getFloatSemantics())));
|
|
|
|
|
if (dyn_cast_or_null<ConstantNoneOp>(posinf.getDefiningOp()))
|
|
|
|
|
posinf = rewriter.create<ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getFloatAttr(
|
|
|
|
|
f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
|
|
|
|
|
if (dyn_cast_or_null<ConstantNoneOp>(neginf.getDefiningOp()))
|
|
|
|
|
neginf = rewriter.create<ConstantFloatOp>(
|
|
|
|
|
loc,
|
|
|
|
|
rewriter.getFloatAttr(
|
|
|
|
|
f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
|
|
|
|
|
Value isNan =
|
|
|
|
|
rewriter.create<Torch::AtenIsnanOp>(loc, baseType, op.getSelf());
|
|
|
|
|
Value where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
|
|
|
|
|
loc, baseType, isNan, nan, op.getSelf());
|
|
|
|
|
Value isposinf =
|
|
|
|
|
rewriter.create<Torch::AtenIsposinfOp>(loc, baseType, where);
|
|
|
|
|
where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
|
|
|
|
|
loc, baseType, isposinf, posinf, where);
|
|
|
|
|
Value isneginf =
|
|
|
|
|
rewriter.create<Torch::AtenIsneginfOp>(loc, baseType, where);
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenWhereScalarSelfOp>(
|
|
|
|
|
op, op.getType(), isneginf, neginf, where);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-02-11 05:58:39 +08:00
|
|
|
|
// Decompose aten.masked_fill.Scalar into aten.where.self op.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMaskedFillScalarOp
|
|
|
|
|
: public OpRewritePattern<AtenMaskedFillScalarOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2023-02-11 05:58:39 +08:00
|
|
|
|
Value mask = op.getMask();
|
|
|
|
|
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
|
2023-11-01 11:56:54 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, value,
|
|
|
|
|
op.getSelf());
|
2023-02-11 05:58:39 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
2022-04-08 12:47:57 +08:00
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-22 11:17:36 +08:00
|
|
|
|
// Decompose aten._convolution-like to aten.convolution
|
2022-07-08 14:44:03 +08:00
|
|
|
|
namespace {
|
2022-11-04 15:57:29 +08:00
|
|
|
|
template <typename ConvolutionLikeOp>
|
2022-08-22 11:17:36 +08:00
|
|
|
|
class DecomposeAten_ConvolutionLikeOp
|
|
|
|
|
: public OpRewritePattern<ConvolutionLikeOp> {
|
2022-07-08 14:44:03 +08:00
|
|
|
|
public:
|
2022-08-22 11:17:36 +08:00
|
|
|
|
using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(ConvolutionLikeOp op,
|
2022-07-08 14:44:03 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
2022-07-08 14:44:03 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-01-24 13:30:03 +08:00
|
|
|
|
namespace {
|
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
|
static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter,
|
|
|
|
|
Location loc, Value input,
|
|
|
|
|
int64_t dimA,
|
|
|
|
|
int64_t dimB,
|
|
|
|
|
Value &transposed) {
|
|
|
|
|
Type transposedType;
|
|
|
|
|
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
|
|
|
|
|
dimA, dimB, transposedType)))
|
|
|
|
|
return failure();
|
|
|
|
|
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(dimA));
|
|
|
|
|
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(dimB));
|
|
|
|
|
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
|
loc, transposedType, input, cstDimA, cstDimB);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
2024-01-24 13:30:03 +08:00
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
|
class DecomposeAtenConvTbcOp : public OpRewritePattern<AtenConvTbcOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConvTbcOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value emptyList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
SmallVector<Value>());
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
Value oneList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
SmallVector<Value>{rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
op.getLoc(), rewriter.getI64IntegerAttr(1))});
|
|
|
|
|
Value padding = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
SmallVector<Value>{op.getPad()});
|
|
|
|
|
Value groups = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
op.getLoc(), rewriter.getI64IntegerAttr(1));
|
|
|
|
|
|
|
|
|
|
// convtbc has WNC layout for input and output
|
|
|
|
|
// and WCF layout for weight
|
|
|
|
|
// whereas Convolution is going to use Conv1DNcwFcwOp for 1d
|
|
|
|
|
// which means we need the inputs in NCW and the weight in FCW
|
|
|
|
|
Value selfWnc = op.getSelf();
|
|
|
|
|
Value selfNwc;
|
|
|
|
|
Value selfNcw;
|
|
|
|
|
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc,
|
|
|
|
|
0, 1, selfNwc)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"failed to transpose input to Nwc");
|
|
|
|
|
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc,
|
|
|
|
|
1, 2, selfNcw)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"failed to transpose input to Ncw");
|
2024-01-24 13:30:03 +08:00
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
|
Value weightWcf = op.getWeight();
|
|
|
|
|
Value weightFcw;
|
|
|
|
|
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
|
|
|
|
|
weightWcf, 0, 2, weightFcw)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"failed to transpose weight to Fcw");
|
|
|
|
|
|
|
|
|
|
Value outputNcw = rewriter.create<AtenConvolutionOp>(
|
|
|
|
|
op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(),
|
|
|
|
|
/*stride*/ oneList,
|
|
|
|
|
/*padding*/ padding, /*dilation*/ oneList,
|
|
|
|
|
/*transpose*/ cstFalse, /*output_padding*/ emptyList, groups);
|
|
|
|
|
|
|
|
|
|
// convert output from Ncw to Wnc
|
|
|
|
|
Value outputNwc;
|
|
|
|
|
Value outputWnc;
|
|
|
|
|
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
|
|
|
|
|
outputNcw, 1, 2, outputNwc)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"failed to transpose output to Nwc");
|
|
|
|
|
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
|
|
|
|
|
outputNwc, 0, 1, outputWnc)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"failed to transpose output to Wnc");
|
|
|
|
|
rewriter.replaceOp(op, outputWnc);
|
2024-01-24 13:30:03 +08:00
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
2024-01-24 13:30:03 +08:00
|
|
|
|
|
|
|
|
|
// Decompose aten.conv1d to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenConv1dOp : public OpRewritePattern<AtenConv1dOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConv1dOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Value emptyList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
SmallVector<Value>());
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
|
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList,
|
|
|
|
|
op.getGroups());
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-08 12:47:57 +08:00
|
|
|
|
// Decompose aten.conv2d to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConv2dOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Value emptyList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
SmallVector<Value>());
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList,
|
|
|
|
|
op.getGroups());
|
2022-04-08 12:47:57 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-01-24 13:30:03 +08:00
|
|
|
|
// Decompose aten.conv3d to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenConv3dOp : public OpRewritePattern<AtenConv3dOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConv3dOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Value emptyList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
SmallVector<Value>());
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
|
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList,
|
|
|
|
|
op.getGroups());
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-25 00:19:35 +08:00
|
|
|
|
// Decompose aten.conv_transpose2d to aten.convolution
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenConvTranspose2dOp
|
|
|
|
|
: public OpRewritePattern<AtenConvTranspose2dInputOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(),
|
|
|
|
|
/*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
|
2022-11-04 15:57:29 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
// The convolution backward op is decomposed as follows:
|
|
|
|
|
// inputH, inputW = input.shape[2:]
|
|
|
|
|
// output_padding_ = [
|
|
|
|
|
// inputH
|
|
|
|
|
// - 1
|
|
|
|
|
// + 2 * padding_[0]
|
|
|
|
|
// - dilation_[0] * (weight.shape[2] - 1)
|
|
|
|
|
// - (grad_output.shape[2] - 1) * stride_[0],
|
|
|
|
|
// inputW
|
|
|
|
|
// - 1
|
|
|
|
|
// + 2 * padding_[1]
|
|
|
|
|
// - dilation_[1] * (weight.shape[3] - 1)
|
|
|
|
|
// - (grad_output.shape[3] - 1) * stride_[1],
|
|
|
|
|
// ]
|
|
|
|
|
//
|
|
|
|
|
// decomp_grad_input = torch.nn.functional.conv_transpose2d(
|
|
|
|
|
// grad_output,
|
|
|
|
|
// weight,
|
|
|
|
|
// None,
|
|
|
|
|
// stride_,
|
|
|
|
|
// padding_,
|
|
|
|
|
// output_padding_,
|
|
|
|
|
// groups_,
|
|
|
|
|
// dilation_,
|
|
|
|
|
// )
|
|
|
|
|
//
|
|
|
|
|
// input_transposed = torch.ops.aten.transpose(input, 0, 1)
|
|
|
|
|
// grad_output_transposed = grad_output.view(
|
|
|
|
|
// grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:]
|
|
|
|
|
// )
|
|
|
|
|
// decomp_grad_weight = torch.ops.aten.convolution(
|
|
|
|
|
// input_transposed,
|
|
|
|
|
// grad_output_transposed,
|
|
|
|
|
// bias=None,
|
|
|
|
|
// stride=dilation_,
|
|
|
|
|
// padding=padding_,
|
|
|
|
|
// dilation=stride_,
|
|
|
|
|
// transposed=False,
|
|
|
|
|
// output_padding=[0, 0],
|
|
|
|
|
// groups=input.shape[0],
|
|
|
|
|
// )
|
|
|
|
|
// decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2])
|
|
|
|
|
// decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3])
|
|
|
|
|
// decomp_grad_weight = decomp_grad_weight.view(
|
|
|
|
|
// input_transposed.shape[0],
|
|
|
|
|
// input_transposed.shape[1],
|
|
|
|
|
// grad_output.shape[1],
|
|
|
|
|
// *decomp_grad_weight.shape[2:]
|
|
|
|
|
// )
|
|
|
|
|
// decomp_grad_weight = decomp_grad_weight.movedim(0, 2)
|
|
|
|
|
// decomp_grad_weight = decomp_grad_weight.sum(dim=0)
|
|
|
|
|
//
|
|
|
|
|
// decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3])
|
2022-11-04 15:57:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenConvolutionBackwardOp
|
|
|
|
|
: public OpRewritePattern<AtenConvolutionBackwardOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-08 22:15:31 +08:00
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value weight = op.getWeight();
|
2022-12-13 00:56:28 +08:00
|
|
|
|
Value gradOutput = op.getGradOutput();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeGradRank = getTensorRank(gradOutput);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeGradRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"expected grad output to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned gradRank = *maybeGradRank;
|
2022-12-08 22:15:31 +08:00
|
|
|
|
if (gradRank != 4)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: only 2D convolutions supported.");
|
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
|
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(2));
|
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
|
|
|
|
loc, rewriter.getBoolAttr(false));
|
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
SmallVector<Value> padding, dilation, stride;
|
|
|
|
|
SmallVector<int64_t, 2> paddingInt, dilationInt, strideInt,
|
|
|
|
|
outputPaddingInt;
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt)))
|
2022-11-04 15:57:29 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 22:15:31 +08:00
|
|
|
|
op, "padding must be a list of constant ints");
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "stride must be a list of constant ints");
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "dilation must be a list of constant ints");
|
|
|
|
|
if (!llvm::all_of(dilationInt,
|
|
|
|
|
[](int64_t dilationVal) { return dilationVal == 1; }))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: only dilations of 1 supported.");
|
|
|
|
|
|
|
|
|
|
if (!matchPattern(op.getOutputPadding(),
|
|
|
|
|
m_TorchListOfConstantInts(outputPaddingInt)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "output padding must be a list of constant ints");
|
|
|
|
|
if (!llvm::all_of(outputPaddingInt,
|
|
|
|
|
[](int64_t outPad) { return outPad == 0; }))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: only output padding of 0 supported.");
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
|
|
|
|
SmallVector<bool> outMask;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask)))
|
2022-11-04 15:57:29 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant bool output_mask is supported.");
|
2022-12-08 22:15:31 +08:00
|
|
|
|
for (unsigned i = 0; i < outMask.size(); i++) {
|
|
|
|
|
if (outMask[i] == false) {
|
|
|
|
|
Value result = op->getResults()[i];
|
|
|
|
|
if (!result.getUsers().empty())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: false value supported for output_mask only "
|
|
|
|
|
"when the result tensor corresponding to that has no users.");
|
|
|
|
|
}
|
|
|
|
|
}
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
|
|
|
|
bool transposed;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
|
2022-11-04 15:57:29 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 22:15:31 +08:00
|
|
|
|
op, "transposed arg should be a constant bool.");
|
2022-11-04 15:57:29 +08:00
|
|
|
|
if (transposed)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: transposed convolutions are not supported.");
|
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
getListConstructElements(op.getPadding(), padding);
|
|
|
|
|
getListConstructElements(op.getStride(), stride);
|
|
|
|
|
getListConstructElements(op.getDilation(), dilation);
|
|
|
|
|
|
|
|
|
|
// Computing Grad Input.
|
|
|
|
|
// Calculate output padding for first convolution.
|
|
|
|
|
// output_padding_ = [
|
|
|
|
|
// inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2]
|
|
|
|
|
// - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 *
|
|
|
|
|
// padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) -
|
|
|
|
|
// ((grad_out.size()[3] - 1) * stride_[1]),
|
|
|
|
|
// ]
|
|
|
|
|
SmallVector<Value> outputPaddingValues;
|
2022-12-13 00:56:28 +08:00
|
|
|
|
for (unsigned i = 2; i < gradRank; i++) {
|
2022-11-04 15:57:29 +08:00
|
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i));
|
2022-12-08 22:15:31 +08:00
|
|
|
|
Value inputVecDim =
|
|
|
|
|
rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
2022-11-04 15:57:29 +08:00
|
|
|
|
Value gradOutDim =
|
|
|
|
|
rewriter.create<Torch::AtenSizeIntOp>(loc, gradOutput, dim);
|
2022-12-08 22:15:31 +08:00
|
|
|
|
Value weightDim = rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
|
|
|
|
|
Value inputVecDimMinusOne =
|
|
|
|
|
rewriter.create<Torch::AtenSubIntOp>(loc, inputVecDim, cstOne);
|
|
|
|
|
Value gradOutDimMinusOne =
|
|
|
|
|
rewriter.create<Torch::AtenSubIntOp>(loc, gradOutDim, cstOne);
|
|
|
|
|
Value weightDimMinusOne =
|
|
|
|
|
rewriter.create<Torch::AtenSubIntOp>(loc, weightDim, cstOne);
|
|
|
|
|
Value twoTimesPadding =
|
|
|
|
|
rewriter.create<Torch::AtenMulIntOp>(loc, padding[i - 2], cstTwo);
|
|
|
|
|
Value tmpA = rewriter.create<Torch::AtenMulIntOp>(loc, weightDimMinusOne,
|
|
|
|
|
dilation[i - 2]);
|
|
|
|
|
Value tmpB = rewriter.create<Torch::AtenMulIntOp>(loc, gradOutDimMinusOne,
|
|
|
|
|
stride[i - 2]);
|
|
|
|
|
Value outputPaddingVal = rewriter.create<AtenAddIntOp>(
|
|
|
|
|
loc, inputVecDimMinusOne, twoTimesPadding);
|
|
|
|
|
outputPaddingVal =
|
|
|
|
|
rewriter.create<AtenSubIntOp>(loc, outputPaddingVal, tmpA);
|
|
|
|
|
outputPaddingVal =
|
|
|
|
|
rewriter.create<AtenSubIntOp>(loc, outputPaddingVal, tmpB);
|
|
|
|
|
outputPaddingValues.push_back(outputPaddingVal);
|
2022-11-04 15:57:29 +08:00
|
|
|
|
}
|
2022-12-08 22:15:31 +08:00
|
|
|
|
Value outputPaddingForGradInput =
|
|
|
|
|
rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), outputPaddingValues);
|
|
|
|
|
|
|
|
|
|
Value gradInput = rewriter.create<Torch::AtenConvTranspose2dInputOp>(
|
|
|
|
|
loc, op.getResultTypes()[0], gradOutput, weight, cstNone,
|
|
|
|
|
op.getStride(), op.getPadding(), outputPaddingForGradInput,
|
|
|
|
|
op.getGroups(), op.getDilation());
|
|
|
|
|
|
|
|
|
|
Type transposedType;
|
|
|
|
|
if (failed(getTransposedType(input.getType().cast<BaseTensorType>(), 0, 1,
|
|
|
|
|
transposedType)))
|
|
|
|
|
return failure();
|
2022-11-04 15:57:29 +08:00
|
|
|
|
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
2022-12-08 22:15:31 +08:00
|
|
|
|
loc, transposedType, input, cstZero, cstOne);
|
|
|
|
|
|
|
|
|
|
// For the cases where the stride is non-unit, we compute the `GradWeight`
|
|
|
|
|
// through this implementation.
|
|
|
|
|
Value gradWeight;
|
|
|
|
|
if (!llvm::all_of(strideInt, [](int64_t stride) { return stride == 1; })) {
|
|
|
|
|
// Computing Grad Weight.
|
|
|
|
|
SmallVector<Value, 4> gradOutputSize;
|
|
|
|
|
for (unsigned i = 0; i < gradRank; i++) {
|
|
|
|
|
gradOutputSize.push_back(rewriter.create<Torch::AtenSizeIntOp>(
|
|
|
|
|
loc, gradOutput,
|
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i))));
|
|
|
|
|
}
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
2022-12-08 22:15:31 +08:00
|
|
|
|
Value gradOutputViewDimZero = rewriter.create<Torch::AtenMulIntOp>(
|
|
|
|
|
loc, gradOutputSize[0], gradOutputSize[1]);
|
|
|
|
|
Value gradOutputViewShapeList =
|
|
|
|
|
rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2],
|
|
|
|
|
gradOutputSize[3]});
|
|
|
|
|
|
|
|
|
|
BaseTensorType gradOutputTy = gradOutput.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!gradOutputTy.hasSizes())
|
|
|
|
|
return failure();
|
|
|
|
|
SmallVector<int64_t> gradOutputSizesInt(gradOutputTy.getSizes());
|
|
|
|
|
SmallVector<int64_t> gradOutputViewSizesInt(gradOutputSizesInt);
|
|
|
|
|
if (gradOutputViewSizesInt[0] != kUnknownSize &&
|
|
|
|
|
gradOutputViewSizesInt[1] != kUnknownSize)
|
|
|
|
|
gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1];
|
|
|
|
|
else
|
|
|
|
|
gradOutputViewSizesInt[0] = kUnknownSize;
|
|
|
|
|
gradOutputViewSizesInt[1] = 1;
|
|
|
|
|
BaseTensorType gradOutputTypeForView =
|
|
|
|
|
gradOutputTy
|
|
|
|
|
.getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt),
|
|
|
|
|
gradOutputTy.getOptionalDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
Value gradOutputView = rewriter.create<Torch::AtenViewOp>(
|
|
|
|
|
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
|
|
|
|
|
|
|
|
|
BaseTensorType inputTransposedTy =
|
|
|
|
|
inputTransposed.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputTransposedTy.hasSizes())
|
|
|
|
|
return failure();
|
|
|
|
|
SmallVector<int64_t> inputTransposedSizesInt(
|
|
|
|
|
inputTransposedTy.getSizes());
|
|
|
|
|
SmallVector<int64_t> gradWeightSizesInt{inputTransposedSizesInt[0],
|
|
|
|
|
gradOutputViewSizesInt[0]};
|
|
|
|
|
for (unsigned i = 2; i < gradRank; i++) {
|
|
|
|
|
if (inputTransposedSizesInt[i] != kUnknownSize &&
|
|
|
|
|
gradOutputViewSizesInt[i] != kUnknownSize) {
|
|
|
|
|
int64_t kernelSizeInt =
|
|
|
|
|
strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1;
|
|
|
|
|
gradWeightSizesInt.push_back(
|
|
|
|
|
((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) -
|
|
|
|
|
kernelSizeInt) /
|
|
|
|
|
dilationInt[i - 2]) +
|
|
|
|
|
1);
|
|
|
|
|
} else {
|
|
|
|
|
gradWeightSizesInt.push_back(kUnknownSize);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BaseTensorType gradWeightTy =
|
|
|
|
|
inputTransposedTy
|
|
|
|
|
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
|
|
|
|
|
inputTransposedTy.getOptionalDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
Value numGroup = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
|
|
|
|
|
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
|
|
|
|
loc, gradWeightTy, inputTransposed, gradOutputView, cstNone,
|
|
|
|
|
/*stride=*/op.getDilation(), op.getPadding(),
|
|
|
|
|
/*dilation=*/op.getStride(), op.getTransposed(),
|
|
|
|
|
op.getOutputPadding(), numGroup);
|
|
|
|
|
|
|
|
|
|
BaseTensorType weightTy = weight.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!weightTy.hasSizes())
|
|
|
|
|
return failure();
|
|
|
|
|
SmallVector<int64_t> weightSizes(weightTy.getSizes());
|
|
|
|
|
for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) {
|
|
|
|
|
gradWeightSizesInt[i + 2] = weightSizes[i + 2];
|
|
|
|
|
BaseTensorType gradWeightNarrowTy =
|
|
|
|
|
gradWeightTy
|
|
|
|
|
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt),
|
|
|
|
|
gradWeightTy.getOptionalDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
Value dim = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i + 2));
|
|
|
|
|
Value length = rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
|
|
|
|
|
gradWeight = rewriter.create<Torch::AtenNarrowOp>(
|
|
|
|
|
loc, gradWeightNarrowTy, gradWeight, dim, /*start=*/cstZero,
|
|
|
|
|
length);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t, 5> gradWeightViewShapeInt{
|
|
|
|
|
inputTransposedSizesInt[0], inputTransposedSizesInt[1]};
|
|
|
|
|
gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]);
|
|
|
|
|
gradWeightViewShapeInt.insert(
|
|
|
|
|
gradWeightViewShapeInt.end(),
|
|
|
|
|
{gradWeightSizesInt[2], gradWeightSizesInt[3]});
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> gradWeightViewShapeValue;
|
|
|
|
|
for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) {
|
|
|
|
|
gradWeightViewShapeValue.push_back(
|
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i])));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value gradWeightViewShapeList =
|
|
|
|
|
rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
gradWeightViewShapeValue);
|
|
|
|
|
|
|
|
|
|
BaseTensorType gradWeightTypeForView =
|
|
|
|
|
gradWeightTy
|
|
|
|
|
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt),
|
|
|
|
|
gradWeightTy.getOptionalDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
|
|
|
|
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
|
|
|
|
|
|
|
|
|
gradWeightTy = gradWeight.getType().cast<BaseTensorType>();
|
|
|
|
|
SmallVector<int64_t, 5> gradWeightDimsOrder =
|
|
|
|
|
computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size());
|
|
|
|
|
SmallVector<int64_t, 5> gradWeightMoveDimShape;
|
|
|
|
|
for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) {
|
|
|
|
|
gradWeightMoveDimShape.push_back(
|
|
|
|
|
gradWeightViewShapeInt[gradWeightDimsOrder[i]]);
|
|
|
|
|
}
|
|
|
|
|
BaseTensorType gradWeightTypeForMoveDim =
|
|
|
|
|
gradWeightTy
|
|
|
|
|
.getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape),
|
|
|
|
|
gradWeightTy.getOptionalDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
gradWeight = rewriter.create<AtenMovedimIntOp>(
|
|
|
|
|
loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero,
|
|
|
|
|
/*destination=*/cstTwo);
|
|
|
|
|
|
|
|
|
|
Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
llvm::ArrayRef{cstZero});
|
|
|
|
|
gradWeight = rewriter.create<Torch::AtenSumDimIntListOp>(
|
|
|
|
|
loc, op.getResultTypes()[1], /*self=*/gradWeight, /*dim=*/gradIntList,
|
|
|
|
|
/*keepdim=*/cstFalse,
|
|
|
|
|
/*dtype=*/cstNone);
|
|
|
|
|
} else {
|
|
|
|
|
if (failed(getTransposedType(gradOutput.getType().cast<BaseTensorType>(),
|
|
|
|
|
0, 1, transposedType)))
|
|
|
|
|
return failure();
|
|
|
|
|
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
|
loc, transposedType, gradOutput, cstZero, cstOne);
|
|
|
|
|
// Convolve input with grad_output.
|
|
|
|
|
if (failed(
|
|
|
|
|
getTransposedType(op.getResultTypes()[1].cast<BaseTensorType>(),
|
|
|
|
|
0, 1, transposedType)))
|
|
|
|
|
return failure();
|
|
|
|
|
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
|
|
|
|
loc, transposedType, inputTransposed, gradOutputTransposed, cstNone,
|
|
|
|
|
op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
|
|
|
|
op.getOutputPadding(), op.getGroups());
|
|
|
|
|
gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
|
|
|
loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Computing Grad Bias.
|
2022-11-04 15:57:29 +08:00
|
|
|
|
SmallVector<Value> dimIntList{cstZero};
|
2022-12-13 00:56:28 +08:00
|
|
|
|
for (unsigned i = 2; i < gradRank; i++)
|
2022-11-04 15:57:29 +08:00
|
|
|
|
dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
dimIntList);
|
2022-12-08 22:15:31 +08:00
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
|
// Sum grad_output along dim 1.
|
|
|
|
|
Value gradBias = rewriter.create<Torch::AtenSumDimIntListOp>(
|
|
|
|
|
loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse,
|
|
|
|
|
cstNone);
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
|
rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});
|
2022-08-25 00:19:35 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
2021-11-11 17:02:13 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenAddmmOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value mat1 = op.getMat1();
|
|
|
|
|
Value mat2 = op.getMat2();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> mat1Rank = getTensorRank(mat1);
|
|
|
|
|
std::optional<unsigned> mat2Rank = getTensorRank(mat2);
|
2021-11-11 17:02:13 +08:00
|
|
|
|
|
|
|
|
|
// The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!mat1Rank || !mat2Rank || *mat1Rank != 2 || *mat2Rank != 2) {
|
2021-11-11 17:02:13 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: Handle integer type operands.
|
2023-01-04 06:19:18 +08:00
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
2021-11-11 17:02:13 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: non-floating point dtype");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// matrix multiplication: matmul = mat1 @ mat2
|
|
|
|
|
Value matmul = rewriter.create<AtenMmOp>(loc, op.getType(), mat1, mat2);
|
|
|
|
|
// scaledInput = self * beta
|
|
|
|
|
Value scaledInput = rewriter.create<AtenMulScalarOp>(loc, input.getType(),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
input, op.getBeta());
|
2021-11-11 17:02:13 +08:00
|
|
|
|
// result = scaledInput + alpha * matmul
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), scaledInput,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
matmul, op.getAlpha());
|
2021-11-11 17:02:13 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
// Decompose aten.mean into: sum(x)/div(numTensorElements).
|
2021-11-19 23:59:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMeanOp : public OpRewritePattern<AtenMeanOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMeanOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value output = op.getResult();
|
2021-11-19 23:59:29 +08:00
|
|
|
|
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
|
2022-12-08 01:51:37 +08:00
|
|
|
|
Value sum =
|
|
|
|
|
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.getDtype());
|
2021-11-19 23:59:29 +08:00
|
|
|
|
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
2022-12-08 01:51:37 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
|
|
|
|
|
numTensorElements);
|
2021-11-19 23:59:29 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-11 01:25:21 +08:00
|
|
|
|
// productDimSize = product(size(dim) for dim in dims)
|
|
|
|
|
// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMeanDimOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeInputRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value dimList = op.getDim();
|
|
|
|
|
Value keepDim = op.getKeepdim();
|
|
|
|
|
Value dtype = op.getDtype();
|
2022-12-08 01:51:37 +08:00
|
|
|
|
Type outputType = op.getType();
|
2022-03-11 01:25:21 +08:00
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
|
|
|
|
!isNoneOrFloatDtype(context, dtype)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only floating-point type is supported");
|
|
|
|
|
}
|
|
|
|
|
|
2022-08-03 00:08:06 +08:00
|
|
|
|
SmallVector<Value> dimListElements;
|
|
|
|
|
if (!getListConstructElements(dimList, dimListElements) &&
|
|
|
|
|
!dimList.getType().isa<Torch::NoneType>()) {
|
2022-03-11 01:25:21 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-08-03 00:08:06 +08:00
|
|
|
|
op, "expected `dim` to be `None` or constructed from list construct");
|
2022-03-11 01:25:21 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compute sum along dimensions specified in `dimList`.
|
|
|
|
|
Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
|
2022-12-08 01:51:37 +08:00
|
|
|
|
loc, outputType, input, dimList, keepDim, dtype);
|
2022-03-11 01:25:21 +08:00
|
|
|
|
|
|
|
|
|
// `productDimSize` is product of sizes of dimensions to be reduced.
|
2022-07-28 22:24:24 +08:00
|
|
|
|
Value productDimSize;
|
|
|
|
|
// Case: Reduce along all dims.
|
2022-08-03 00:08:06 +08:00
|
|
|
|
if (dimListElements.empty() && inputRank != 0) {
|
2022-07-28 22:24:24 +08:00
|
|
|
|
productDimSize = rewriter.create<AtenNumelOp>(loc, input);
|
|
|
|
|
} else {
|
|
|
|
|
productDimSize = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
2022-08-03 00:08:06 +08:00
|
|
|
|
for (Value dim : dimListElements) {
|
2022-07-28 22:24:24 +08:00
|
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
|
|
|
|
productDimSize =
|
|
|
|
|
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
|
|
|
|
}
|
2022-03-11 01:25:21 +08:00
|
|
|
|
}
|
2022-12-08 01:51:37 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
|
|
|
|
|
productDimSize);
|
2022-03-11 01:25:21 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-01-30 01:10:50 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSquareOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-01-30 01:10:50 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), self, self);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-02 01:30:58 +08:00
|
|
|
|
// Silu(x) = sigmoid(x) * x
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSiluOp : public OpRewritePattern<AtenSiluOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSiluOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-03-02 01:30:58 +08:00
|
|
|
|
Value sigmoid =
|
|
|
|
|
rewriter.create<AtenSigmoidOp>(op.getLoc(), op.getType(), self);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), sigmoid,
|
|
|
|
|
self);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-17 00:34:03 +08:00
|
|
|
|
// pDash = 1.0 - p
|
|
|
|
|
// boolMask = aten.rand_like(input) < pDash
|
|
|
|
|
// dropout(input, p, train=True) = (boolMask * input) / pDash
|
|
|
|
|
// dropout(input, p, train=False) = input
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenDropoutOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value prob = op.getP();
|
2022-02-17 00:34:03 +08:00
|
|
|
|
bool train = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train)))
|
2022-02-17 00:34:03 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"train must be a boolean constant");
|
|
|
|
|
if (!train) {
|
|
|
|
|
rewriter.replaceOp(op, input);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only support floating type input for training mode");
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value floatOne =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
|
|
|
|
|
Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
|
|
|
|
loc, inputType, input, oneMinusP, /*generator=*/noneVal);
|
|
|
|
|
Value maskedInput =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, op.getType(), maskedInput,
|
|
|
|
|
oneMinusP);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
2023-06-27 14:19:33 +08:00
|
|
|
|
|
|
|
|
|
class DeomposeAtenNativeDropoutOp
|
|
|
|
|
: public OpRewritePattern<AtenNativeDropoutOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op->getContext();
|
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value prob = op.getP();
|
|
|
|
|
bool train = false;
|
|
|
|
|
if (!op.getTrain().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "train must be a boolean constant or none");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
if (!train) {
|
|
|
|
|
Value i1Type =
|
|
|
|
|
getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1));
|
|
|
|
|
Value inputSize = rewriter.create<AtenSizeOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), input);
|
|
|
|
|
Value trueValue = rewriter.create<ConstantIntOp>(loc, 1);
|
|
|
|
|
Value trueMask = rewriter.create<AtenFullOp>(
|
|
|
|
|
loc, op->getResultTypes()[1], inputSize, trueValue, i1Type,
|
|
|
|
|
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
|
|
|
|
|
rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only support floating type input for training mode");
|
|
|
|
|
}
|
|
|
|
|
Value floatOne =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
|
|
|
|
|
Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
|
|
|
|
|
loc, inputType, input, oneMinusP, /*generator=*/noneVal);
|
|
|
|
|
Value maskedInput =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
|
|
|
|
|
Value output = rewriter.create<AtenDivScalarOp>(
|
|
|
|
|
loc, op->getResultTypes()[0], maskedInput, oneMinusP);
|
|
|
|
|
rewriter.replaceOp(
|
|
|
|
|
op, ArrayRef<Value>{
|
|
|
|
|
output, convertTensorToDtype(rewriter, loc, boolMask,
|
|
|
|
|
IntegerType::get(context, 1))});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
2022-02-17 00:34:03 +08:00
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
|
// Decompose aten.var into: aten.var.dim op.
|
2022-01-30 01:10:50 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenVarOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeInputRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-01-30 01:10:50 +08:00
|
|
|
|
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
|
2022-03-10 08:44:22 +08:00
|
|
|
|
if (!rank0FloatTensorTy.hasSizes() ||
|
|
|
|
|
rank0FloatTensorTy.getSizes().size() != 0) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected aten.var to have a rank 0 tensor type");
|
|
|
|
|
}
|
2022-01-30 01:10:50 +08:00
|
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
|
SmallVector<Value> dims;
|
|
|
|
|
for (unsigned i = 0; i < inputRank; i++)
|
|
|
|
|
dims.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dims);
|
2022-01-30 01:10:50 +08:00
|
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenVarDimOp>(op, rank0FloatTensorTy, self,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
dimList, op.getUnbiased(),
|
2022-06-29 15:23:57 +08:00
|
|
|
|
/*keepdim=*/cstFalse);
|
2022-01-30 01:10:50 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// Decompose aten.std to sqrt(var(x))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenStdOp : public OpRewritePattern<AtenStdOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenStdOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-01-30 01:10:50 +08:00
|
|
|
|
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputTensorTy.hasDtype() ||
|
|
|
|
|
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"Only aten.std support floating type");
|
|
|
|
|
}
|
|
|
|
|
Value var = rewriter.create<AtenVarOp>(op->getLoc(), op.getType(),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getSelf(), op.getUnbiased());
|
2022-01-30 01:10:50 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), var);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-07-18 03:00:29 +08:00
|
|
|
|
// Softplus(x, beta, threshold) =
|
|
|
|
|
// x * beta > threshold ? x : log(1 + exp(x * beta)) / beta
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenSoftplusOp : public OpRewritePattern<AtenSoftplusOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSoftplusOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-07-18 03:00:29 +08:00
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
|
Value inputTimesBeta =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.getBeta());
|
2022-07-18 03:00:29 +08:00
|
|
|
|
|
|
|
|
|
// out = log1p(exp(input * beta)) / beta
|
|
|
|
|
Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
|
|
|
|
|
Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
|
2022-07-22 20:42:14 +08:00
|
|
|
|
Value out =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.create<AtenDivScalarOp>(loc, inputType, log1p, op.getBeta());
|
2022-07-18 03:00:29 +08:00
|
|
|
|
|
|
|
|
|
// Select where x * beta > threshold
|
|
|
|
|
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
|
|
|
|
|
rewriter.getI1Type());
|
|
|
|
|
Value condition = rewriter.create<AtenGtScalarOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
loc, boolResType, inputTimesBeta, op.getThreshold());
|
2022-07-18 03:00:29 +08:00
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), condition,
|
|
|
|
|
input, out);
|
2022-07-18 03:00:29 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-07-27 04:02:01 +08:00
|
|
|
|
// Decompose aten.std.dim to sqrt(var.dim(x))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenStdDimOp : public OpRewritePattern<AtenStdDimOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenStdDimOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-07-27 04:02:01 +08:00
|
|
|
|
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputTensorType.hasDtype() ||
|
|
|
|
|
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "aten.std.dim expects input tensor of floating-point type");
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value varDim = rewriter.create<AtenVarDimOp>(
|
|
|
|
|
op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(),
|
|
|
|
|
op.getKeepdim());
|
2022-07-27 04:02:01 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-22 13:02:40 +08:00
|
|
|
|
// Decompose aten.std.correction to sqrt(var.correction(x))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenStdCorrectionOp
|
|
|
|
|
: public OpRewritePattern<AtenStdCorrectionOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputTensorType.hasDtype() ||
|
|
|
|
|
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op,
|
|
|
|
|
"aten.std.correction expects input tensor of floating-point type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value varCorrection = rewriter.create<AtenVarCorrectionOp>(
|
|
|
|
|
op->getLoc(), op.getType(), self, op.getDim(), op.getCorrection(),
|
|
|
|
|
op.getKeepdim());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varCorrection);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-14 22:46:44 +08:00
|
|
|
|
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-02-09 04:57:23 +08:00
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2022-02-14 22:46:44 +08:00
|
|
|
|
|
|
|
|
|
// outputTensor = (input + 3) / 6.
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(3));
|
|
|
|
|
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(6));
|
|
|
|
|
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
|
|
|
|
Value outputTensor = rewriter.create<AtenDivScalarOp>(
|
|
|
|
|
loc, inputType, inputPlusThree, constantSix);
|
|
|
|
|
|
|
|
|
|
// result = max(0, min(1, (input+3)/6))
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value oneTensor = createRank0Tensor(rewriter, loc, inputType, constantOne);
|
2022-02-14 22:46:44 +08:00
|
|
|
|
Value minResult =
|
|
|
|
|
rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value zeroTensor =
|
|
|
|
|
createRank0Tensor(rewriter, loc, inputType, constantZero);
|
2022-02-14 22:46:44 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
|
|
|
|
|
minResult);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-09 04:57:23 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenHardtanhOp : public OpRewritePattern<AtenHardtanhOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenHardtanhOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-02-09 04:57:23 +08:00
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
2023-05-18 00:04:26 +08:00
|
|
|
|
auto resType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
|
|
|
|
}
|
2022-02-09 04:57:23 +08:00
|
|
|
|
|
|
|
|
|
// result = min(maxVal, max(minVal, x))
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());
|
2022-02-09 04:57:23 +08:00
|
|
|
|
Value maxResult =
|
|
|
|
|
rewriter.create<AtenMaximumOp>(loc, inputType, input, minVal);
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.getMaxVal());
|
2022-02-09 04:57:23 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), maxVal,
|
|
|
|
|
maxResult);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRandLikeOp : public OpRewritePattern<AtenRandLikeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandLikeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-05-13 07:00:59 +08:00
|
|
|
|
Type resultType = op.getType();
|
2022-02-26 00:35:04 +08:00
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
2022-05-13 07:00:59 +08:00
|
|
|
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
2022-02-26 00:35:04 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"only support floating-point type");
|
|
|
|
|
}
|
|
|
|
|
|
2022-05-13 07:00:59 +08:00
|
|
|
|
// Create a uniform random op with low and high set to 0.0 and 1.0,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// respectively.
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
2022-05-13 07:00:59 +08:00
|
|
|
|
Value zero =
|
2022-02-26 00:35:04 +08:00
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
2022-05-13 07:00:59 +08:00
|
|
|
|
Value one =
|
2022-02-26 00:35:04 +08:00
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
2022-09-23 10:24:36 +08:00
|
|
|
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
loc, resultType, input, zero, op.getDtype(), op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
|
2022-10-28 23:06:11 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
|
|
|
|
|
/*from=*/zero, /*to=*/one,
|
|
|
|
|
/*generator=*/none);
|
2022-02-26 00:35:04 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Bernoulli(x, p) = (randLike(float(x)) < p).cast(type(x)). Here,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// 1. p must be a float tensor.
|
|
|
|
|
// 2. The shape of p should be broadcastable to the shape of x.
|
|
|
|
|
// 3. Bernoulli(x, p) returns a tensor of the same type as that of x.
|
2022-02-09 04:57:23 +08:00
|
|
|
|
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
|
|
|
|
Operation *op, Location loc,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
Value input, Value prob,
|
|
|
|
|
Value &output) {
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
auto probType = prob.getType().cast<BaseTensorType>();
|
|
|
|
|
// Both the `input` and `prob` must be ranked tensors.
|
|
|
|
|
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
|
|
|
|
|
!probType.hasDtype()) {
|
2022-02-09 04:57:23 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-02-26 00:35:04 +08:00
|
|
|
|
op, "can't decompose bernoulli like ops without sizes or dtype");
|
|
|
|
|
}
|
|
|
|
|
// The `prob` is expected to be a float type tensor.
|
|
|
|
|
if (!probType.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "probabilities must be a float type tensor");
|
2022-02-09 04:57:23 +08:00
|
|
|
|
}
|
2022-02-04 19:43:25 +08:00
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Since the `aten.randLike` op expects float-type operand, create a
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// float-type tensor with the same shape as that of the `input`.
|
|
|
|
|
Value floatTensor =
|
|
|
|
|
convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value randomVal = rewriter.create<AtenRandLikeOp>(
|
|
|
|
|
loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
|
2022-02-26 00:35:04 +08:00
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Bernoulli(x, p) = randLike(float(x)) < p.
|
2022-02-26 00:35:04 +08:00
|
|
|
|
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
|
|
|
|
|
rewriter.getI1Type());
|
|
|
|
|
Value lessThanP =
|
|
|
|
|
rewriter.create<AtenLtTensorOp>(loc, boolResType, randomVal, prob);
|
|
|
|
|
|
|
|
|
|
// As the `output` is expected to be of the `input` type, convert the boolean
|
|
|
|
|
// tensor `lessThanP` to a `input` type tensor.
|
|
|
|
|
output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
|
2022-02-09 04:57:23 +08:00
|
|
|
|
return success();
|
2022-02-04 19:43:25 +08:00
|
|
|
|
}
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// aten.bernoulli(x) = randLike(x) < x. Here, the input x is a tensor
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// containing probabilities to be used for drawing the binary random number.
|
2022-02-04 19:43:25 +08:00
|
|
|
|
class DecomposeAtenBernoulliOp : public OpRewritePattern<AtenBernoulliOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenBernoulliOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
2022-02-04 19:43:25 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-20 12:59:19 +08:00
|
|
|
|
op, "The generator has to be None because only global default "
|
2022-02-04 19:43:25 +08:00
|
|
|
|
"generator is supported");
|
2022-02-26 00:35:04 +08:00
|
|
|
|
Value output;
|
|
|
|
|
if (failed(
|
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
|
rewriter.replaceOp(op, output);
|
2022-02-04 19:43:25 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)).
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// Since the input x can be an integer tensor, it's important to cast it to
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// float type before passing it to the `aten.randLike` op.
|
2023-02-16 01:06:29 +08:00
|
|
|
|
template <typename BernoulliLikeOp>
|
|
|
|
|
class DecomposeAtenBernoulliLikeOp : public OpRewritePattern<BernoulliLikeOp> {
|
2022-02-04 19:43:25 +08:00
|
|
|
|
public:
|
2023-02-16 01:06:29 +08:00
|
|
|
|
using OpRewritePattern<BernoulliLikeOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(BernoulliLikeOp op,
|
2022-02-04 19:43:25 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value p = op.getP();
|
2023-02-16 01:06:29 +08:00
|
|
|
|
if (!op.getGenerator().getType().template isa<Torch::NoneType>())
|
2022-02-26 00:35:04 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-20 12:59:19 +08:00
|
|
|
|
op, "The generator has to be None because only global default "
|
2022-02-26 00:35:04 +08:00
|
|
|
|
"generator is supported");
|
2022-02-04 19:43:25 +08:00
|
|
|
|
|
2022-02-26 00:35:04 +08:00
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
SmallVector<int64_t> empty;
|
2023-01-25 09:29:42 +08:00
|
|
|
|
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
|
2022-02-26 00:35:04 +08:00
|
|
|
|
rewriter.getF64Type());
|
|
|
|
|
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
|
|
|
|
|
Value output;
|
|
|
|
|
if (failed(
|
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
|
rewriter.replaceOp(op, output);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// aten.bernoulli.Tensor(x, p) = (randLike(float(x)) < p).cast(type(x)).
|
2022-02-26 00:35:04 +08:00
|
|
|
|
// Since the input x can be an integer tensor, it's important to cast it to
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// float type before passing it to the `aten.randLike` op.
|
2022-10-28 23:06:11 +08:00
|
|
|
|
class DecomposeAtenBernoulliTensorOp
|
|
|
|
|
: public OpRewritePattern<AtenBernoulliTensorOp> {
|
2022-02-26 00:35:04 +08:00
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2022-10-28 23:06:11 +08:00
|
|
|
|
LogicalResult matchAndRewrite(AtenBernoulliTensorOp op,
|
2022-02-26 00:35:04 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value prob = op.getP();
|
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
2022-02-04 19:43:25 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-12-20 12:59:19 +08:00
|
|
|
|
op, "The generator has to be None because only global default "
|
2022-02-04 19:43:25 +08:00
|
|
|
|
"generator is supported");
|
2022-02-26 00:35:04 +08:00
|
|
|
|
Value output;
|
|
|
|
|
if (failed(
|
|
|
|
|
decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "decomposeBernoulliLikeOp failed to decompose the op");
|
|
|
|
|
rewriter.replaceOp(op, output);
|
2022-02-04 19:43:25 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-12-28 12:33:18 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose exponential() to do inverse transform sampling.
|
|
|
|
|
// - https://en.wikipedia.org/wiki/Inverse_transform_sampling
|
|
|
|
|
// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
|
|
|
|
|
// exponential() = - ln(1 - uniform(0, 1)) / lambda.
|
|
|
|
|
class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenExponentialOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "The generator has to be None because only global default "
|
|
|
|
|
"generator is supported");
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Type resultType = op.getType();
|
|
|
|
|
|
|
|
|
|
// Create a uniform random op with low and high set to 0.0 and 1.0,
|
|
|
|
|
// respectively.
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
|
|
|
|
loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
|
|
|
|
|
Value x = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensor,
|
|
|
|
|
/*from=*/zero, /*to=*/one,
|
|
|
|
|
/*generator=*/none);
|
|
|
|
|
|
|
|
|
|
Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
|
|
|
|
|
Value oneMinusX =
|
|
|
|
|
rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
|
|
|
|
|
/*alpha=*/one);
|
|
|
|
|
Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
|
|
|
|
|
Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
|
|
|
|
|
negLambda);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2024-01-16 14:49:29 +08:00
|
|
|
|
// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
|
|
|
|
|
class DecomposeAtenNormalFunctionalOp
|
|
|
|
|
: public OpRewritePattern<AtenNormalFunctionalOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "The generator has to be None because only global default "
|
|
|
|
|
"generator is supported");
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Type resultType = op.getType();
|
|
|
|
|
Value std = op.getStd();
|
|
|
|
|
Value mean = op.getMean();
|
|
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value randN = rewriter.create<AtenRandnLikeOp>(
|
|
|
|
|
loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
|
|
|
|
|
Value stdRandN =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
|
2024-01-30 01:59:33 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN, mean,
|
|
|
|
|
/*alpha=*/one);
|
2024-01-16 14:49:29 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
template <typename OpTy, typename T1T2Op>
|
2021-11-25 06:01:48 +08:00
|
|
|
|
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
2022-02-15 21:14:32 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
2021-11-25 06:01:48 +08:00
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value tensor1 = op.getTensor1();
|
|
|
|
|
Value tensor2 = op.getTensor2();
|
|
|
|
|
Value value = op.getValue();
|
2021-11-25 06:01:48 +08:00
|
|
|
|
|
2022-02-15 21:14:32 +08:00
|
|
|
|
Value product =
|
|
|
|
|
rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
|
|
|
|
|
product, value);
|
2021-11-25 06:01:48 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
2021-12-10 21:36:19 +08:00
|
|
|
|
|
|
|
|
|
// Decomposes `aten.layer_norm` into `aten.native_layer_norm` and forwards only
// its normalized output (result 0). The auxiliary mean/rstd results are typed
// with the input's rank. The leading (un-normalized) dimensions keep the
// input's sizes, and each normalized dimension becomes 1.
class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
  using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto input = op.getInput().getType().cast<BaseTensorType>();
    if (!input.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = input.getSizes().size();
    Value normalizedShape = op.getNormalizedShape();

    // If `normalized_shape` is not a literal list, the number of normalized
    // dimensions is unknown. In that case every mean/rstd size is left
    // dynamic, rather than being derived from an empty element list (which
    // would claim that no dimension is normalized).
    std::vector<int64_t> meanVarSizes(inputRank, kUnknownSize);
    SmallVector<Value> normalizedShapeSizesTorchInt;
    if (getListConstructElements(normalizedShape,
                                 normalizedShapeSizesTorchInt)) {
      int64_t numNormalizedDims = normalizedShapeSizesTorchInt.size();
      if (numNormalizedDims > inputRank)
        return rewriter.notifyMatchFailure(
            op, "normalized_shape has more elements than the input rank");
      int64_t axis = inputRank - numNormalizedDims;
      for (int64_t i = 0; i < inputRank; i++)
        meanVarSizes[i] = i < axis ? input.getSizes()[i] : 1;
    }
    auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes),
                                                  input.getOptionalDtype());
    auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
        loc, op.getType(), meanVarType, meanVarType, op.getInput(),
        op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
    return success();
  }
};
|
2021-11-25 06:01:48 +08:00
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-02-19 22:23:48 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenInstanceNormOp
|
|
|
|
|
: public OpRewritePattern<AtenInstanceNormOp> {
|
|
|
|
|
using OpRewritePattern<AtenInstanceNormOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenInstanceNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto context = op.getContext();
|
|
|
|
|
|
|
|
|
|
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
|
|
|
|
int64_t inputRank = inputTy.getSizes().size();
|
|
|
|
|
SmallVector<int64_t> reducedShape(inputTy.getSizes());
|
2024-02-29 02:27:19 +08:00
|
|
|
|
SmallVector<int64_t> reduceDimInts;
|
|
|
|
|
SmallVector<Value> reduceDimVals;
|
|
|
|
|
for (int i = 2; i < inputRank; ++i) {
|
|
|
|
|
reducedShape[i] = 1;
|
|
|
|
|
reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
}
|
2024-02-19 22:23:48 +08:00
|
|
|
|
|
|
|
|
|
Type dtype = inputTy.getOptionalDtype();
|
|
|
|
|
Type reducedTy = ValueTensorType::get(op.getContext(),
|
|
|
|
|
llvm::ArrayRef(reducedShape), dtype);
|
|
|
|
|
|
|
|
|
|
auto sizeListType = ListType::get(IntType::get(context));
|
|
|
|
|
Value reduceDimList =
|
|
|
|
|
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
|
|
|
|
|
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
|
|
|
|
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
|
|
|
|
|
// mean(x)
|
|
|
|
|
Value inputMean = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);
|
|
|
|
|
|
|
|
|
|
// x - mean(x)
|
|
|
|
|
Value inputMeanExpanded =
|
|
|
|
|
rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
|
|
|
|
|
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
|
|
|
|
|
loc, inputTy, op.getInput(), inputMeanExpanded, one);
|
|
|
|
|
// (x - mean(x))^2
|
|
|
|
|
Value inputSubMeanSquare = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputTy, inputSubMean, inputSubMean);
|
|
|
|
|
|
|
|
|
|
Value variancesum = rewriter.create<AtenSumDimIntListOp>(
|
|
|
|
|
loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue,
|
|
|
|
|
/*dtype=*/none);
|
|
|
|
|
|
2024-02-29 02:27:19 +08:00
|
|
|
|
int64_t elemCount = 1;
|
|
|
|
|
for (int i = 2; i < inputRank; ++i)
|
|
|
|
|
elemCount *= inputTy.getSizes()[i];
|
|
|
|
|
|
2024-02-19 22:23:48 +08:00
|
|
|
|
Value hw = rewriter.create<Torch::ConstantIntOp>(
|
2024-02-29 02:27:19 +08:00
|
|
|
|
loc, rewriter.getI64IntegerAttr(elemCount));
|
2024-02-19 22:23:48 +08:00
|
|
|
|
Value inputVar =
|
|
|
|
|
rewriter.create<AtenDivScalarOp>(loc, reducedTy, variancesum, hw);
|
|
|
|
|
|
|
|
|
|
// rsqrt(var(x) + eps)
|
|
|
|
|
Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, reducedTy, inputVar, op.getEps(), one);
|
|
|
|
|
Value inputRsqrtVar =
|
|
|
|
|
rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
|
|
|
|
|
|
|
|
|
|
// (x - mean(x)) * rsqrt(var(x) + eps)
|
|
|
|
|
Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
|
|
|
|
|
loc, inputTy, inputRsqrtVar, op.getInput());
|
|
|
|
|
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
|
|
|
|
|
Value out = rewriter.create<TensorStaticInfoCastOp>(
|
|
|
|
|
loc, op.getResult().getType(), inputNormalized);
|
|
|
|
|
|
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
auto weightTy = weight.getType().cast<BaseTensorType>();
|
|
|
|
|
dtype = weightTy.getOptionalDtype();
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> weightShape(weightTy.getSizes());
|
|
|
|
|
SmallVector<int64_t> newWeightShape;
|
|
|
|
|
newWeightShape.push_back(1);
|
|
|
|
|
newWeightShape.append(weightShape);
|
|
|
|
|
|
|
|
|
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Type newWeightTy = ValueTensorType::get(
|
|
|
|
|
op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
|
|
|
|
|
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, zero);
|
|
|
|
|
|
2024-02-29 02:27:19 +08:00
|
|
|
|
while (static_cast<int64_t>(newWeightShape.size()) < inputRank) {
|
|
|
|
|
Value i = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(newWeightShape.size()));
|
|
|
|
|
newWeightShape.push_back(1);
|
|
|
|
|
newWeightTy = ValueTensorType::get(op.getContext(),
|
|
|
|
|
llvm::ArrayRef(newWeightShape), dtype);
|
|
|
|
|
weight = rewriter.create<AtenUnsqueezeOp>(loc, newWeightTy, weight, i);
|
|
|
|
|
}
|
2024-02-19 22:23:48 +08:00
|
|
|
|
|
|
|
|
|
Value weightExpanded =
|
|
|
|
|
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
|
|
|
|
|
|
|
|
|
|
Value bias = op.getBias();
|
|
|
|
|
auto biasTy = bias.getType().cast<BaseTensorType>();
|
|
|
|
|
dtype = biasTy.getOptionalDtype();
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> biasShape(biasTy.getSizes());
|
|
|
|
|
SmallVector<int64_t> newBiasShape;
|
|
|
|
|
newBiasShape.push_back(1);
|
|
|
|
|
newBiasShape.append(biasShape);
|
|
|
|
|
|
|
|
|
|
Type newBiasTy = ValueTensorType::get(op.getContext(),
|
|
|
|
|
llvm::ArrayRef(newBiasShape), dtype);
|
|
|
|
|
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, zero);
|
|
|
|
|
|
2024-02-29 02:27:19 +08:00
|
|
|
|
while (static_cast<int64_t>(newBiasShape.size()) < inputRank) {
|
|
|
|
|
Value i = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(newBiasShape.size()));
|
|
|
|
|
newBiasShape.push_back(1);
|
|
|
|
|
newBiasTy = ValueTensorType::get(op.getContext(),
|
|
|
|
|
llvm::ArrayRef(newBiasShape), dtype);
|
|
|
|
|
bias = rewriter.create<AtenUnsqueezeOp>(loc, newBiasTy, bias, i);
|
|
|
|
|
}
|
2024-02-19 22:23:48 +08:00
|
|
|
|
|
|
|
|
|
Value biasExpanded =
|
|
|
|
|
rewriter.create<AtenExpandAsOp>(loc, inputTy, bias, op.getInput());
|
|
|
|
|
|
|
|
|
|
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out,
|
|
|
|
|
weightExpanded);
|
|
|
|
|
out = rewriter.create<AtenAddTensorOp>(loc, out.getType(), out,
|
|
|
|
|
biasExpanded, one);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, out);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-09-02 09:29:22 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenNativeLayerNormOp
|
|
|
|
|
: public OpRewritePattern<AtenNativeLayerNormOp> {
|
|
|
|
|
using OpRewritePattern<AtenNativeLayerNormOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto context = op.getContext();
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
2022-09-02 09:29:22 +08:00
|
|
|
|
if (!inputTy.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "input tensor should have known sizes.");
|
|
|
|
|
int64_t inputRank = inputTy.getSizes().size();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value normalizedShape = op.getNormalizedShape();
|
2022-09-02 09:29:22 +08:00
|
|
|
|
SmallVector<Value> normalizedShapeSizesTorchInt;
|
|
|
|
|
getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
|
|
|
|
|
int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
|
2023-11-01 11:56:54 +08:00
|
|
|
|
auto reduceDimInts =
|
|
|
|
|
llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
|
2022-09-02 09:29:22 +08:00
|
|
|
|
auto reducedTy = op.getResult(1).getType();
|
|
|
|
|
auto sizeListType = ListType::get(IntType::get(context));
|
|
|
|
|
|
|
|
|
|
// build reduce dims
|
|
|
|
|
SmallVector<Value> reduceDimVals;
|
|
|
|
|
reduceDimVals.reserve(reduceDimInts.size());
|
|
|
|
|
std::transform(reduceDimInts.begin(), reduceDimInts.end(),
|
|
|
|
|
std::back_inserter(reduceDimVals), [&](int64_t d) {
|
|
|
|
|
return rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(d));
|
|
|
|
|
});
|
|
|
|
|
Value reduceDimList =
|
|
|
|
|
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
|
|
|
|
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
|
|
|
|
|
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
// mean(x)
|
|
|
|
|
Value inputMean = rewriter.create<AtenMeanDimOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none);
|
2022-09-02 09:29:22 +08:00
|
|
|
|
|
|
|
|
|
// x - mean(x)
|
|
|
|
|
Value inputMeanExpanded =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.getInput());
|
2022-09-02 09:29:22 +08:00
|
|
|
|
Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
loc, inputTy, op.getInput(), inputMeanExpanded, one);
|
2022-09-02 09:29:22 +08:00
|
|
|
|
// var(x) = mean((x - mean(x))^2)
|
|
|
|
|
Value inputZeroMeanSquare = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputTy, inputZeroMean, inputZeroMean);
|
|
|
|
|
Value inputVar = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none);
|
|
|
|
|
|
|
|
|
|
// rsqrt(var(x) + eps)
|
|
|
|
|
Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
loc, reducedTy, inputVar, op.getEps(), one);
|
2022-09-02 09:29:22 +08:00
|
|
|
|
Value inputRsqrtVar =
|
|
|
|
|
rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
|
|
|
|
|
|
|
|
|
|
// (x - mean(x)) * rsqrt(var(x) + eps)
|
|
|
|
|
Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
loc, inputTy, inputRsqrtVar, op.getInput());
|
2022-09-02 09:29:22 +08:00
|
|
|
|
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
|
|
|
|
|
Value out = rewriter.create<TensorStaticInfoCastOp>(
|
|
|
|
|
loc, op.getResult(0).getType(), inputNormalized);
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
Value bias = op.getBias();
|
2022-09-02 09:29:22 +08:00
|
|
|
|
if (!weight.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
|
|
|
|
|
}
|
|
|
|
|
if (!bias.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
out =
|
|
|
|
|
rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-14 03:01:10 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
|
2021-12-14 03:01:10 +08:00
|
|
|
|
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto sizeListType =
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
|
Value sizeList =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
|
2021-12-14 03:01:10 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
|
2021-12-14 03:01:10 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-23 21:22:45 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// The `aten.arange` op is converted to `aten.arange.startStep` op.
|
2021-12-23 21:22:45 +08:00
|
|
|
|
class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenArangeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
// The AtenArangeOp doesn't have a start and step value. Therefore we set
|
|
|
|
|
// them as default values 0 and 1, respectively.
|
|
|
|
|
Value start, step;
|
|
|
|
|
start = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
|
|
|
|
rewriter.getI64IntegerAttr(1));
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), start, op.getEnd(), step, op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
2021-12-23 21:22:45 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// The `aten.arange.start` op is converted to `aten.arange.startStep` op.
|
2021-12-23 21:22:45 +08:00
|
|
|
|
class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenArangeStartOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
// The AtenArangeStartOp doesn't have a step value. Therefore we set it as
|
|
|
|
|
// default value 1.
|
|
|
|
|
Value step;
|
|
|
|
|
step = rewriter.create<Torch::ConstantIntOp>(loc,
|
|
|
|
|
rewriter.getI64IntegerAttr(1));
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
2021-12-23 21:22:45 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
|
namespace {
|
2022-09-23 10:24:36 +08:00
|
|
|
|
// Decompose constant tensor full like ops.
|
2021-12-21 19:51:19 +08:00
|
|
|
|
template <typename OpTy, int fillVal>
|
|
|
|
|
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value constVal = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(fillVal));
|
2022-09-23 10:24:36 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenFullLikeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), op.getSelf(), constVal, op.getDtype(), op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
|
2021-12-21 19:51:19 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-12-13 11:05:12 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
|
|
|
|
|
using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenGroupNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
Value bias = op.getBias();
|
|
|
|
|
Value numGroups = op.getNumGroups();
|
|
|
|
|
Value eps = op.getEps();
|
|
|
|
|
|
|
|
|
|
Value cstZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value cstOne =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
|
|
|
|
|
|
|
|
|
Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
|
|
|
|
|
Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
|
|
|
|
|
Value numElements = rewriter.create<AtenNumelOp>(loc, input);
|
|
|
|
|
Value numElementsDivN =
|
|
|
|
|
rewriter.create<AtenFloordivIntOp>(loc, numElements, N);
|
|
|
|
|
Value HxW = rewriter.create<AtenFloordivIntOp>(loc, numElementsDivN, C);
|
|
|
|
|
|
|
|
|
|
AtenNativeGroupNormOp newOp = rewriter.create<AtenNativeGroupNormOp>(
|
|
|
|
|
loc, ArrayRef<Type>{op.getResult().getType(), baseType, baseType},
|
|
|
|
|
input, weight, bias, N, C, HxW, numGroups, eps);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, newOp.getResult0());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenNativeGroupNormOp
|
|
|
|
|
: public OpRewritePattern<AtenNativeGroupNormOp> {
|
|
|
|
|
using OpRewritePattern<AtenNativeGroupNormOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
Value bias = op.getBias();
|
|
|
|
|
Value numGroups = op.getGroup();
|
|
|
|
|
Value eps = op.getEps();
|
|
|
|
|
|
|
|
|
|
// Check the rank of the input/outputs tensor.
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
auto outputType = op.getResult0().getType().cast<BaseTensorType>();
|
|
|
|
|
auto meanType = op.getResult1().getType().cast<BaseTensorType>();
|
|
|
|
|
auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasSizes() || !outputType.hasSizes() ||
|
|
|
|
|
!meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "input/outputs tensor should have known sizes.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value cstZero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value cstOne =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value cstNegtiveOne =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
|
|
|
|
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
|
|
|
|
|
|
|
|
|
// GroupNorm requires the channel dimension (C) to be exactly divisible by
|
|
|
|
|
// the number of groups.
|
|
|
|
|
Value channel = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
|
|
|
|
|
Value remainder =
|
|
|
|
|
rewriter.create<AtenRemainderIntOp>(loc, channel, numGroups);
|
|
|
|
|
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, cstZero);
|
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
|
loc, eqOrNot,
|
|
|
|
|
rewriter.getStringAttr("the number of channels must be divisible by "
|
|
|
|
|
"the number of groups"));
|
|
|
|
|
|
|
|
|
|
// Reshape the input tensor to (N, numGroups, -1) to apply normalization.
|
|
|
|
|
SmallVector<Value> newShape;
|
|
|
|
|
newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
|
|
|
|
|
newShape.push_back(numGroups);
|
|
|
|
|
newShape.push_back(cstNegtiveOne);
|
|
|
|
|
Value reshapedInput = rewriter.create<AtenViewOp>(
|
|
|
|
|
loc, baseType, input,
|
|
|
|
|
rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(IntType::get(context)), newShape));
|
|
|
|
|
|
|
|
|
|
// Now we proceed with the normalization steps across the 'groupSize'
|
|
|
|
|
// Compute the mean and variance for each group
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
ArrayRef<Value>{cstNegtiveOne});
|
|
|
|
|
auto mean = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
|
|
|
|
|
/*dtype=*/none);
|
|
|
|
|
auto var = rewriter.create<AtenVarDimOp>(
|
|
|
|
|
loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
|
|
|
|
|
/*keepdim=*/cstTrue);
|
|
|
|
|
|
|
|
|
|
// Compute the normalized output: (input - mean) * rsqrt(var + eps)
|
|
|
|
|
auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
|
|
|
|
|
/*alpha=*/cstOne);
|
|
|
|
|
auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
|
|
|
|
|
auto inputSubMean = rewriter.create<AtenSubTensorOp>(
|
|
|
|
|
loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
|
|
|
|
|
auto normalizedOutput =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
|
|
|
|
|
|
|
|
|
|
// Reshape normalized output back to the original input shape
|
|
|
|
|
auto inputShape = rewriter.create<AtenSizeOp>(
|
|
|
|
|
loc, Torch::ListType::get(IntType::get(context)), input);
|
|
|
|
|
auto reshapedOutput = rewriter.create<AtenViewOp>(
|
|
|
|
|
loc, inputType, normalizedOutput, /*shape=*/inputShape);
|
|
|
|
|
|
|
|
|
|
// Apply weight and bias if they are not None
|
|
|
|
|
// Reshape weight and bias to C,1,1,...
|
|
|
|
|
SmallVector<Value> viewShape = {channel};
|
|
|
|
|
for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
|
|
|
|
|
viewShape.push_back(cstOne);
|
|
|
|
|
}
|
|
|
|
|
Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), viewShape);
|
|
|
|
|
|
|
|
|
|
Value groupNormOutput = reshapedOutput;
|
|
|
|
|
if (!weight.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
auto weightReshaped = rewriter.create<AtenViewOp>(
|
|
|
|
|
loc, baseType, weight, /*shape=*/viewShapeSizeList);
|
|
|
|
|
groupNormOutput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputType, groupNormOutput, weightReshaped);
|
|
|
|
|
}
|
|
|
|
|
if (!bias.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
auto biasReshaped = rewriter.create<AtenViewOp>(
|
|
|
|
|
loc, baseType, bias, /*shape=*/viewShapeSizeList);
|
|
|
|
|
groupNormOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, inputType, groupNormOutput, biasReshaped,
|
|
|
|
|
/*alpha=*/cstOne);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value squeezedMean =
|
|
|
|
|
rewriter.create<AtenSqueezeDimOp>(loc, meanType, mean, cstNegtiveOne);
|
|
|
|
|
Value squeezedRsqrtVar = rewriter.create<AtenSqueezeDimOp>(
|
|
|
|
|
loc, rsqrtVarType, invStd, cstNegtiveOne);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(
|
|
|
|
|
op, ArrayRef<Value>{groupNormOutput, squeezedMean, squeezedRsqrtVar});
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-08 00:08:10 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenNativeBatchNormOp
|
|
|
|
|
: public OpRewritePattern<AtenNativeBatchNormOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNativeBatchNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
Value bias = op.getBias();
|
|
|
|
|
Value runningMean = op.getRunningMean();
|
|
|
|
|
Value runningVar = op.getRunningVar();
|
|
|
|
|
Value eps = op.getEps();
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
// TODO: Add support for `training` mode.
|
|
|
|
|
bool training = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training)) ||
|
2022-02-08 00:08:10 +08:00
|
|
|
|
training)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: training mode is not supported");
|
|
|
|
|
|
|
|
|
|
// Rank of the input tensor must be greater than or equal to 2. The shape of
|
|
|
|
|
// the `input` is supposed to be (N, C, D?, H?, W?).
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeInputRank || *maybeInputRank < 2)
|
2022-02-08 00:08:10 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "input must have rank greater than or equal to 2");
|
2022-12-13 00:56:28 +08:00
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
// In the inference mode, the `runningMean` and `runningVar` must not be
|
|
|
|
|
// None.
|
|
|
|
|
if (runningMean.getType().isa<Torch::NoneType>() ||
|
|
|
|
|
runningVar.getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "running stats must not be None in inference mode");
|
|
|
|
|
|
|
|
|
|
// Rank of `runningMean` and `runningVar` must be exactly 1.
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> runningMeanRank = getTensorRank(runningMean);
|
|
|
|
|
std::optional<unsigned> runningVarRank = getTensorRank(runningVar);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!runningMeanRank || !runningVarRank || *runningMeanRank != 1 ||
|
|
|
|
|
*runningVarRank != 1)
|
2022-02-08 00:08:10 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, "expected runningMean and runningVar to be rank 1");
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value numFeatures = rewriter.create<AtenSizeIntOp>(loc, input, /*dim=*/one);
|
2022-02-25 03:41:55 +08:00
|
|
|
|
// TODO: Add Runtime Asserts to check the shape of weight, bias,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// runningMean and runningVar to be (numFeatures).
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
// The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
|
|
|
|
|
// to make it broadcast-compatible with (N, C, D?, H?, W?).
|
|
|
|
|
// 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
// 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
SmallVector<Value> runningStatsShape(inputRank, one);
|
|
|
|
|
runningStatsShape[1] = numFeatures;
|
|
|
|
|
Value runningStatsSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), runningStatsShape);
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
2023-11-01 11:56:54 +08:00
|
|
|
|
runningStatsShapeInt[1] =
|
|
|
|
|
runningMean.getType().cast<BaseTensorType>().getSizes()[0];
|
2023-01-04 06:19:18 +08:00
|
|
|
|
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
2022-02-08 00:08:10 +08:00
|
|
|
|
Type reshapeType = ValueTensorType::get(
|
2023-01-25 09:29:42 +08:00
|
|
|
|
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
runningVar = rewriter.create<AtenViewOp>(loc, reshapeType, runningVar,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
|
|
|
|
|
// normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
|
|
|
|
|
Value inputSubMean = rewriter.create<AtenSubTensorOp>(
|
|
|
|
|
loc, input.getType(), input, runningMean, /*alpha=*/one);
|
|
|
|
|
Value varEps = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, runningVar.getType(), runningVar, eps, /*alpha=*/one);
|
|
|
|
|
Value invStd = rewriter.create<AtenRsqrtOp>(loc, varEps.getType(), varEps);
|
|
|
|
|
Value normalizedInput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, inputSubMean.getType(), inputSubMean, invStd);
|
|
|
|
|
|
|
|
|
|
// The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it
|
|
|
|
|
// broadcast-compatible with (N, C, D?, H?, W?).
|
|
|
|
|
// 1. weight = weight.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
// 2. bias = bias.view(1, C, 1?, 1?, 1?)
|
|
|
|
|
// 3. output = normalizedInput * weight + bias
|
|
|
|
|
Value batchNormOutput = normalizedInput;
|
|
|
|
|
if (!weight.getType().isa<Torch::NoneType>()) {
|
2022-02-25 03:41:55 +08:00
|
|
|
|
// Rank of `weight` must be exactly 1.
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> weightRank = getTensorRank(weight);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!weightRank || *weightRank != 1)
|
2022-02-08 00:08:10 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
|
|
|
|
|
weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
batchNormOutput = rewriter.create<AtenMulTensorOp>(
|
|
|
|
|
loc, batchNormOutput.getType(), batchNormOutput, weight);
|
|
|
|
|
}
|
|
|
|
|
if (!bias.getType().isa<Torch::NoneType>()) {
|
2022-02-25 03:41:55 +08:00
|
|
|
|
// Rank of `bias` must be exactly 1.
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> biasRank = getTensorRank(bias);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!biasRank || *biasRank != 1)
|
2022-02-08 00:08:10 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
|
|
|
|
bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
|
|
|
|
|
runningStatsSizeList);
|
|
|
|
|
batchNormOutput = rewriter.create<AtenAddTensorOp>(
|
|
|
|
|
loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The `mean` and `invstd` outputs are empty tensors in inference mode.
|
|
|
|
|
Value zeroList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(zero.getType()), zero);
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value emptyMeanTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
|
loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
|
2022-02-08 00:08:10 +08:00
|
|
|
|
Value emptyInvStdTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
|
loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
|
2022-02-08 00:08:10 +08:00
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op,
|
|
|
|
|
{batchNormOutput, emptyMeanTensor, emptyInvStdTensor});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompse `Aten_UnsafeViewOp` into `AtenViewOp`. UnsafeView() differs from
|
2022-02-10 16:11:05 +08:00
|
|
|
|
// view() in that the returned tensor isn't treated as a view for the purposes
|
|
|
|
|
// of automatic differentiation. It's only safe to use if the `self` tensor is
|
|
|
|
|
// temporary. For example, the viewed tensor here (a + b) is discarded
|
|
|
|
|
// immediately after viewing:
|
|
|
|
|
//
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// res = UnsafeView(a + b, size);
|
2022-02-10 16:11:05 +08:00
|
|
|
|
//
|
|
|
|
|
// This is a hack because in-place operations on tensors treated like views
|
|
|
|
|
// can be much more expensive than the same operations on non-view tensors.
|
|
|
|
|
|
|
|
|
|
// Refer to
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_UnsafeViewOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
op.getSize());
|
2022-02-10 16:11:05 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// In PyTorch, ReshapeAlias just uses an already computed stride.
|
2022-03-29 12:54:28 +08:00
|
|
|
|
// See
|
|
|
|
|
// https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153
|
|
|
|
|
// Note that this is the same decomposition as in AOTAutograd
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/Src/aotAutograd.py#L104
|
2022-03-29 12:54:28 +08:00
|
|
|
|
namespace {
|
2022-05-13 20:06:24 +08:00
|
|
|
|
class DecomposeAten_ReshapeAliasOp
|
|
|
|
|
: public OpRewritePattern<Aten_ReshapeAliasOp> {
|
2022-03-29 12:54:28 +08:00
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
op.getSize());
|
2022-03-29 12:54:28 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-02-28 14:14:40 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose constant tensor like ops.
|
|
|
|
|
template <typename OpTy, typename NewOpTy>
|
|
|
|
|
class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
|
|
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(OpTy op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value dtype = op.getDtype();
|
2022-03-25 00:40:21 +08:00
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
BaseTensorType tensorType =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getSelf().getType().template cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!tensorType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input tensor to have a dtype");
|
|
|
|
|
}
|
2022-03-25 00:40:21 +08:00
|
|
|
|
dtype =
|
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
|
|
|
|
}
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), op.getSize(), dtype,
|
|
|
|
|
op.getLayout(), op.getDevice(),
|
|
|
|
|
op.getPinMemory());
|
2022-02-28 14:14:40 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-03 21:41:14 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.full` op into `aten.broadcastTo`
|
2022-03-03 21:41:14 +08:00
|
|
|
|
class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenFullOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-09-23 10:24:36 +08:00
|
|
|
|
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!outTy.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
2022-09-23 10:24:36 +08:00
|
|
|
|
SmallVector<int64_t> empty;
|
|
|
|
|
auto dtype =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
2023-01-25 09:29:42 +08:00
|
|
|
|
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
2022-09-23 10:24:36 +08:00
|
|
|
|
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getFillValue());
|
2022-09-23 10:24:36 +08:00
|
|
|
|
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), fillVal,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getSize());
|
2022-03-03 21:41:14 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-19 00:29:04 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops.
|
|
|
|
|
class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLinearOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getInput();
|
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
Value bias = op.getBias();
|
2022-05-19 00:29:04 +08:00
|
|
|
|
|
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input to be rank 2 or greater");
|
|
|
|
|
|
|
|
|
|
BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
|
|
|
|
|
// `weight` must be a rank 2 matrix.
|
|
|
|
|
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> transposeShape =
|
|
|
|
|
llvm::to_vector(llvm::reverse(weightType.getSizes()));
|
|
|
|
|
Type transposeType = weightType.getWithSizesAndDtype(
|
2023-01-25 09:29:42 +08:00
|
|
|
|
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
|
2022-05-19 00:29:04 +08:00
|
|
|
|
Value transposeWeight =
|
|
|
|
|
rewriter.create<AtenTOp>(loc, transposeType, weight);
|
|
|
|
|
|
|
|
|
|
Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
|
|
|
|
|
transposeWeight);
|
|
|
|
|
if (bias.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
rewriter.replaceOp(op, matmul);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
|
|
|
|
|
|
|
|
|
Value alpha =
|
|
|
|
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getBias(), alpha);
|
2022-05-19 00:29:04 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-10-12 05:03:10 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.mish` op into `aten.tanh` and `aten.softplus` ops.
|
|
|
|
|
// Mish(x) = x * Tanh(Softplus(x))
|
|
|
|
|
class DecomposeAtenMishOp : public OpRewritePattern<AtenMishOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMishOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-10-12 05:03:10 +08:00
|
|
|
|
Type type = op.getType();
|
|
|
|
|
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasDtype())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "Dtype not present");
|
|
|
|
|
|
|
|
|
|
Type dType = inputType.getDtype();
|
|
|
|
|
// Form default Value tensors for `beta` and `threshold` operands
|
|
|
|
|
// of `aten.softplus` op.
|
|
|
|
|
Value beta = getConstantWithGivenDtypeAndValue(rewriter, loc, 1.0, dType);
|
|
|
|
|
Value threshold =
|
|
|
|
|
getConstantWithGivenDtypeAndValue(rewriter, loc, 20.0, dType);
|
|
|
|
|
Value softplusOp =
|
|
|
|
|
rewriter.create<AtenSoftplusOp>(loc, type, input, beta, threshold);
|
|
|
|
|
Value tanhOp = rewriter.create<AtenTanhOp>(loc, type, softplusOp);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, type, input, tanhOp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-03 22:25:22 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.fullLike` op into `aten.emptyLike` and `aten.fill` ops.
|
2022-03-03 22:25:22 +08:00
|
|
|
|
class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenFullLikeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-09-23 10:24:36 +08:00
|
|
|
|
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!outTy.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
2022-09-23 10:24:36 +08:00
|
|
|
|
SmallVector<int64_t> empty;
|
|
|
|
|
auto dtype =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
2023-01-25 09:29:42 +08:00
|
|
|
|
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
2022-09-23 10:24:36 +08:00
|
|
|
|
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getLoc(), tensorType, op.getFillValue());
|
2022-09-23 10:24:36 +08:00
|
|
|
|
fillVal =
|
|
|
|
|
convertTensorToDtype(rewriter, op.getLoc(), fillVal, outTy.getDtype());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), fillVal,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getSelf());
|
2022-03-03 22:25:22 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-09-12 22:29:08 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.new_full` op into `aten.full` op.
|
|
|
|
|
class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNewFullOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value dtype = op.getDtype();
|
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
|
|
|
|
if (!tensorType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input tensor to have a dtype");
|
|
|
|
|
}
|
|
|
|
|
dtype =
|
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenFullOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), op.getSize(), op.getFillValue(), dtype,
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
2023-09-12 22:29:08 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-10 23:18:08 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op.
|
2022-03-10 23:18:08 +08:00
|
|
|
|
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIndexPutOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
2022-10-28 23:06:11 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
|
|
|
|
op.getAccumulate(),
|
2022-03-10 23:18:08 +08:00
|
|
|
|
/*unsafe=*/cstFalse);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-14 16:12:37 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenExpandAsOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
auto sizeListType =
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
|
Value sizeList =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
|
2023-11-01 11:56:54 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
|
|
|
|
|
op.getSelf(), sizeList);
|
2022-03-14 16:12:37 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-17 21:35:17 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.ToCopy` op into `valsem.aten.copy` op.
|
2022-03-17 21:35:17 +08:00
|
|
|
|
class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-01-04 06:19:18 +08:00
|
|
|
|
auto resultType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resultType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
|
|
|
|
Type resultDtype = resultType.getDtype();
|
2022-10-04 21:05:59 +08:00
|
|
|
|
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
|
|
|
|
|
resultDtype);
|
2022-09-23 10:24:36 +08:00
|
|
|
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory(),
|
|
|
|
|
op.getMemoryFormat());
|
2022-10-28 23:06:11 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op.getSelf(), op.getNonBlocking());
|
2022-03-17 21:35:17 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-22 10:13:59 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.copy` op into `aten.to.dtype` and `aten.expand_as`.
|
|
|
|
|
class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenCopyOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-01-04 06:19:18 +08:00
|
|
|
|
auto resultType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!resultType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
2023-07-18 15:52:33 +08:00
|
|
|
|
auto srcTy = op.getSrc().getType().cast<BaseTensorType>();
|
|
|
|
|
if (!srcTy.hasSizes() || !srcTy.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected src type to have a known rank and dtype");
|
|
|
|
|
}
|
2023-01-04 06:19:18 +08:00
|
|
|
|
Type resultDtype = resultType.getDtype();
|
2022-12-22 10:13:59 +08:00
|
|
|
|
Value srcToDtype =
|
|
|
|
|
convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
|
|
|
|
|
op.getSelf());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-25 00:40:21 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.newEmpty` op into `aten.empty.memoryFormat` op.
|
2022-03-25 00:40:21 +08:00
|
|
|
|
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value dtype = op.getDtype();
|
2022-03-25 00:40:21 +08:00
|
|
|
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!tensorType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input tensor to have a dtype");
|
|
|
|
|
}
|
2022-03-25 00:40:21 +08:00
|
|
|
|
dtype =
|
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
|
|
|
|
|
op.getPinMemory(), /*memoryFormat=*/noneVal);
|
2022-03-25 00:40:21 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-03-24 15:12:59 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.indexPut.hackedTwin` op into `valsem.aten.indexPutImpl`
|
2022-03-24 15:12:59 +08:00
|
|
|
|
// op.
|
|
|
|
|
class DecomposeAtenIndexPutHackedTwinOp
|
|
|
|
|
: public OpRewritePattern<AtenIndexPutHackedTwinOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
2022-10-28 23:06:11 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
|
|
|
|
op.getAccumulate(),
|
2022-03-24 15:12:59 +08:00
|
|
|
|
/*unsafe=*/cstFalse);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-07-14 15:26:54 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl`
|
|
|
|
|
// op.
|
|
|
|
|
class DecomposeAten_UnsafeIndexPutHackedTwinOp
|
|
|
|
|
: public OpRewritePattern<Aten_UnsafeIndexPutHackedTwinOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
|
|
|
|
op.getAccumulate(),
|
|
|
|
|
/*unsafe=*/cstFalse);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-26 20:18:09 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.pad` op into `aten.constantPadNd` op.
|
2022-04-26 20:18:09 +08:00
|
|
|
|
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenPadOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value value = op.getValue();
|
2022-04-26 20:18:09 +08:00
|
|
|
|
if (value.getType().isa<Torch::OptionalType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "optional type not supported");
|
|
|
|
|
if (value.getType().isa<Torch::NoneType>())
|
|
|
|
|
value = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
op.getLoc(), rewriter.getF64FloatAttr(0));
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), op.getSelf(), op.getPad(), value);
|
2022-04-26 20:18:09 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-04-27 19:07:40 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.to.dtypeLayout` op into `aten.to.dtype` op.
|
2022-04-27 19:07:40 +08:00
|
|
|
|
class DecomposeAtenToDtypeLayoutOp
|
|
|
|
|
: public OpRewritePattern<AtenToDtypeLayoutOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// TODO: Add support for pinMemory arg equal to `True`.
|
|
|
|
|
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
2022-04-27 19:07:40 +08:00
|
|
|
|
bool pinMemory;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
2022-04-27 19:07:40 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, "unimplemented: pinMemory must be a constant");
|
2022-04-27 19:07:40 +08:00
|
|
|
|
else if (pinMemory)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, "unimplemented: pinMemory is expected to be false");
|
2022-04-27 19:07:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
2023-09-25 22:00:19 +08:00
|
|
|
|
// TODO: Add support for device arg other than cpu.
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
2023-09-25 22:00:19 +08:00
|
|
|
|
std::string device;
|
|
|
|
|
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: device must be a constant str");
|
|
|
|
|
else if (device != "cpu")
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: device is expected to be cpu");
|
2022-04-27 19:07:40 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for non-strided layout.
|
|
|
|
|
// torch.layout is by default strided i.e. 0.
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
2022-04-27 19:07:40 +08:00
|
|
|
|
int64_t tensorLayout;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
2022-04-27 19:07:40 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: layout must be a constant");
|
|
|
|
|
else if (tensorLayout != torch_upstream::Layout::Strided)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: layout is expected to be strided");
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-01 11:56:54 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
|
|
|
|
|
op.getCopy(), op.getMemoryFormat());
|
2022-04-27 19:07:40 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-04-10 22:26:48 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.to.prim_Device` op into `aten.to.dtype` op.
|
|
|
|
|
class DecomposeAtenToPrimDeviceOp
|
|
|
|
|
: public OpRewritePattern<AtenToPrimDeviceOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenToPrimDeviceOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
// Device information isn't relevant to torch-mlir, so we can drop that info
|
|
|
|
|
// here.
|
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
|
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
|
|
|
|
|
Value dtype = op.getDtype();
|
|
|
|
|
if (dtype.getType().template isa<Torch::NoneType>()) {
|
|
|
|
|
dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
dtype, op.getNonBlocking(),
|
|
|
|
|
op.getCopy(), constNone);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-11 07:24:02 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.to.device` op into `aten.to.dtype` op.
|
|
|
|
|
class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenToDeviceOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
// Device information isn't relevant to torch-mlir, so we can drop that info
|
|
|
|
|
// here.
|
2023-11-01 11:56:54 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
|
|
|
|
|
op.getCopy(), op.getMemoryFormat());
|
2022-08-11 07:24:02 +08:00
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
[Torch Dialect] add support for adaptive_avgpool_1d (#2342)
* [MLIR][TORCH] Fix aten.cumsum lowering for int32 input (#2351)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
update PyTorch version to 2.1.0.dev20230729 (#2354)
- torch version: 2.1.0.dev20230729
- torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
- torchvision version: 0.16.0.dev20230729
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230730 (#2356)
- torch version: 2.1.0.dev20230730
- torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
- torchvision version: 0.16.0.dev20230730
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230731 (#2359)
- torch version: 2.1.0.dev20230731
- torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
- torchvision version: 0.16.0.dev20230731
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support
* SW-95317 Propagate Lazy->Jit->MLIR scope name.
* Enhance location information based on op names
Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.
* Update locations logic; updated debug-info.py test
* Use {scope}/{op_name} format to track names by default
---------
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
build: update llvm tag to 41895843
Summary of changes:
- Update tags
llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
update PyTorch version to 2.1.0.dev20230802 (#2366)
- torch version: 2.1.0.dev20230802
- torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
- torchvision version: 0.16.0.dev20230802
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Change Python version from 3.10 to 3.11 in installation instructions (#2370)
Add CITATION file (#2371)
Add packaging as an install dependency (#2369)
Needed by `torch_mlir._version`. Resolves #2368.
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
update PyTorch version to 2.1.0.dev20230803 (#2372)
- torch version: 2.1.0.dev20230803
- torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
- torchvision version: 0.16.0.dev20230803
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Prevent failed stable CI job from cancelling nightly jobs (#2373)
The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
update
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update
update:
update
parent 22e88d523b1970b2e904eb5421d49d987a3d255e
author jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114110 +0800
committer jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114119 +0800
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
update PyTorch version to 2.1.0.dev20230729 (#2354)
- torch version: 2.1.0.dev20230729
- torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
- torchvision version: 0.16.0.dev20230729
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230730 (#2356)
- torch version: 2.1.0.dev20230730
- torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
- torchvision version: 0.16.0.dev20230730
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230731 (#2359)
- torch version: 2.1.0.dev20230731
- torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
- torchvision version: 0.16.0.dev20230731
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support
* SW-95317 Propagate Lazy->Jit->MLIR scope name.
* Enhance location information based on op names
Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.
* Update locations logic; updated debug-info.py test
* Use {scope}/{op_name} format to track names by default
---------
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
build: update llvm tag to 41895843
Summary of changes:
- Update tags
llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
update PyTorch version to 2.1.0.dev20230802 (#2366)
- torch version: 2.1.0.dev20230802
- torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
- torchvision version: 0.16.0.dev20230802
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Change Python version from 3.10 to 3.11 in installation instructions (#2370)
Add CITATION file (#2371)
Add packaging as an install dependency (#2369)
Needed by `torch_mlir._version`. Resolves #2368.
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
update PyTorch version to 2.1.0.dev20230803 (#2372)
- torch version: 2.1.0.dev20230803
- torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
- torchvision version: 0.16.0.dev20230803
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Prevent failed stable CI job from cancelling nightly jobs (#2373)
The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
update
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update
update:
add support for adaptive_pool_id
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update:
* update
---------
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-08-05 07:48:09 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
|
|
|
|
|
|
|
|
|
|
// The logic of this decomposition is totally same with
|
|
|
|
|
// the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two
|
|
|
|
|
// cases are supported:
|
|
|
|
|
// 1. inputSize = outputSize
|
|
|
|
|
// 2. outputSize = 1
|
|
|
|
|
class DecomposeAtenAdaptiveAvgPool1dOp
|
|
|
|
|
: public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
|
|
|
|
|
using OpRewritePattern<AtenAdaptiveAvgPool1dOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(input);
|
|
|
|
|
if (!maybeRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned rank = *maybeRank;
|
|
|
|
|
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 1));
|
|
|
|
|
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
|
|
|
|
|
|
|
|
|
|
Value outputShape = op.getOutputSize();
|
|
|
|
|
SmallVector<Value> outputShapeSizesTorchInt;
|
|
|
|
|
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
|
|
|
|
Value outputSize = outputShapeSizesTorchInt[0];
|
|
|
|
|
|
|
|
|
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
|
|
|
|
|
|
int64_t outputSizeInt;
|
|
|
|
|
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "the output size of adaptive_pool_1d must be a constant int");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<Value, 1> kernelSize;
|
|
|
|
|
if (outputSizeInt == 1) {
|
|
|
|
|
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
|
|
|
|
kernelSize.push_back(
|
|
|
|
|
inputShape[rank - 1] == kUnknownSize
|
|
|
|
|
? inputSize
|
|
|
|
|
: rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
|
|
|
|
|
} else {
|
2023-09-30 07:45:48 +08:00
|
|
|
|
if (!isAssumingStrictSymbolicShapes(rewriter)) {
|
|
|
|
|
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
|
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
|
loc, cond,
|
|
|
|
|
"unimplemented: only support cases where input and output size are "
|
|
|
|
|
"equal for non-unit output size");
|
|
|
|
|
}
|
[Torch Dialect] add support for adaptive_avgpool_1d (#2342)
* [MLIR][TORCH] Fix aten.cumsum lowering for int32 input (#2351)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
update PyTorch version to 2.1.0.dev20230729 (#2354)
- torch version: 2.1.0.dev20230729
- torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
- torchvision version: 0.16.0.dev20230729
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230730 (#2356)
- torch version: 2.1.0.dev20230730
- torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
- torchvision version: 0.16.0.dev20230730
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230731 (#2359)
- torch version: 2.1.0.dev20230731
- torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
- torchvision version: 0.16.0.dev20230731
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support
* SW-95317 Propagate Lazy->Jit->MLIR scope name.
* Enhance location information based on op names
Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.
* Update locations logic; updated debug-info.py test
* Use {scope}/{op_name} format to track names by default
---------
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
build: update llvm tag to 41895843
Summary of changes:
- Update tags
llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
update PyTorch version to 2.1.0.dev20230802 (#2366)
- torch version: 2.1.0.dev20230802
- torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
- torchvision version: 0.16.0.dev20230802
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Change Python version from 3.10 to 3.11 in installation instructions (#2370)
Add CITATION file (#2371)
Add packaging as an install dependency (#2369)
Needed by `torch_mlir._version`. Resolves #2368.
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
update PyTorch version to 2.1.0.dev20230803 (#2372)
- torch version: 2.1.0.dev20230803
- torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
- torchvision version: 0.16.0.dev20230803
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Prevent failed stable CI job from cancelling nightly jobs (#2373)
The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
update
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update
update:
update
parent 22e88d523b1970b2e904eb5421d49d987a3d255e
author jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114110 +0800
committer jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114119 +0800
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
update PyTorch version to 2.1.0.dev20230729 (#2354)
- torch version: 2.1.0.dev20230729
- torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
- torchvision version: 0.16.0.dev20230729
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230730 (#2356)
- torch version: 2.1.0.dev20230730
- torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
- torchvision version: 0.16.0.dev20230730
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230731 (#2359)
- torch version: 2.1.0.dev20230731
- torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
- torchvision version: 0.16.0.dev20230731
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support
* SW-95317 Propagate Lazy->Jit->MLIR scope name.
* Enhance location information based on op names
Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.
* Update locations logic; updated debug-info.py test
* Use {scope}/{op_name} format to track names by default
---------
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
build: update llvm tag to 41895843
Summary of changes:
- Update tags
llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
update PyTorch version to 2.1.0.dev20230802 (#2366)
- torch version: 2.1.0.dev20230802
- torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
- torchvision version: 0.16.0.dev20230802
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Change Python version from 3.10 to 3.11 in installation instructions (#2370)
Add CITATION file (#2371)
Add packaging as an install dependency (#2369)
Needed by `torch_mlir._version`. Resolves #2368.
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
update PyTorch version to 2.1.0.dev20230803 (#2372)
- torch version: 2.1.0.dev20230803
- torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
- torchvision version: 0.16.0.dev20230803
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Prevent failed stable CI job from cancelling nightly jobs (#2373)
The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
update
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update
update:
add support for adaptive_pool_id
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update:
* update
---------
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-08-05 07:48:09 +08:00
|
|
|
|
kernelSize.push_back(constantOne);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
|
|
|
|
Value strideList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
|
ValueRange{constantOne});
|
|
|
|
|
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
|
ValueRange{constantZero});
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
|
|
|
|
|
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
|
|
|
|
|
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
|
2022-05-13 20:06:24 +08:00
|
|
|
|
//
|
|
|
|
|
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// output size the kernelSize, stride and padding is calculated as follows:
|
2022-05-13 20:06:24 +08:00
|
|
|
|
// strideH = inH // outH
|
|
|
|
|
// strideW = inH // outH
|
2024-04-11 17:02:59 +08:00
|
|
|
|
// kernelH = inH - [(outH - 1) * strideH] = strideH
|
|
|
|
|
// kernelW = inW - [(outW - 1) * strideW] = strideW
|
2022-05-13 20:06:24 +08:00
|
|
|
|
// paddingH = 0, paddingW = 0
|
|
|
|
|
//
|
|
|
|
|
class DecomposeAtenAdaptiveAvgPool2dOp
|
|
|
|
|
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value input = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(input);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned rank = *maybeRank;
|
2022-05-13 20:06:24 +08:00
|
|
|
|
SmallVector<Value, 2> inputHW;
|
|
|
|
|
Value dimH = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 2));
|
|
|
|
|
inputHW.push_back(
|
|
|
|
|
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
|
|
|
|
|
Value dimW = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(rank - 1));
|
|
|
|
|
inputHW.push_back(
|
|
|
|
|
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
|
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value outputShape = op.getOutputSize();
|
2022-05-13 20:06:24 +08:00
|
|
|
|
SmallVector<Value> outputShapeSizesTorchInt;
|
|
|
|
|
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for cases other than:
|
2024-04-11 17:02:59 +08:00
|
|
|
|
// inH % outH != 0 or inW % outW != 0
|
2022-05-13 20:06:24 +08:00
|
|
|
|
|
|
|
|
|
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
SmallVector<Value, 2> kernelSize;
|
|
|
|
|
|
|
|
|
|
for (unsigned i = 0; i < inputHW.size(); i++) {
|
2024-04-11 17:02:59 +08:00
|
|
|
|
Value remainder = rewriter.create<AtenRemainderIntOp>(
|
|
|
|
|
loc, inputHW[i], outputShapeSizesTorchInt[i]);
|
|
|
|
|
Value cond = rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
|
|
|
|
|
rewriter.create<RuntimeAssertOp>(loc, cond,
|
|
|
|
|
"unimplemented: only support cases "
|
|
|
|
|
"input size is an integer multiple of "
|
|
|
|
|
"output size");
|
|
|
|
|
Value stride = rewriter.create<AtenFloordivIntOp>(
|
|
|
|
|
loc, inputHW[i], outputShapeSizesTorchInt[i]);
|
|
|
|
|
Value kernelSizeValue = stride;
|
|
|
|
|
kernelSize.push_back(kernelSizeValue);
|
2022-05-13 20:06:24 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
2024-04-11 17:02:59 +08:00
|
|
|
|
Value strideList = kernelSizeList;
|
2022-05-13 20:06:24 +08:00
|
|
|
|
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
|
|
|
|
ValueRange{constantZero, constantZero});
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
|
|
|
|
|
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
|
|
|
|
|
/*divisorOverride=*/constantNone);
|
2022-05-13 20:06:24 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-03 15:41:13 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.clampMin` op into `aten.clamp` op.
|
2022-06-03 15:41:13 +08:00
|
|
|
|
class DecomposeAtenClampMinOp : public OpRewritePattern<AtenClampMinOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMinOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenClampOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
op.getMin(), /*max=*/constantNone);
|
2022-06-03 15:41:13 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-12-05 13:55:51 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.clamp_min.Tensor` op into `aten.clamp.Tensor` op.
|
|
|
|
|
class DecomposeAtenClampMinTensorOp
|
|
|
|
|
: public OpRewritePattern<AtenClampMinTensorOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMinTensorOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenClampTensorOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-03 15:41:13 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.clampMax` op into `aten.clamp` op.
|
2022-06-03 15:41:13 +08:00
|
|
|
|
class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenClampMaxOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenClampOp>(op, op.getType(), op.getSelf(),
|
|
|
|
|
/*min=*/constantNone, op.getMax());
|
2022-06-03 15:41:13 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-11-08 15:28:30 +08:00
|
|
|
|
namespace {
|
2023-11-23 04:31:06 +08:00
|
|
|
|
class DecomposeAtenCosineSimilarityOp
|
2023-11-08 15:28:30 +08:00
|
|
|
|
: public OpRewritePattern<AtenCosineSimilarityOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value x1 = op.getX1();
|
|
|
|
|
Value x2 = op.getX2();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
|
|
|
|
|
// Broadcast x1 and x2 to the same shape
|
|
|
|
|
SmallVector<int64_t> indexBroadcastShapeInt;
|
|
|
|
|
SmallVector<Value> indexBroadcastShapeValue;
|
|
|
|
|
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
|
|
|
|
|
indexBroadcastShapeValue);
|
|
|
|
|
Type dtype = x1.getType().cast<BaseTensorType>().getOptionalDtype();
|
|
|
|
|
Type broadcastType = ValueTensorType::get(
|
|
|
|
|
op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
|
|
|
|
|
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
indexBroadcastShapeValue);
|
|
|
|
|
x1 = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, x1,
|
|
|
|
|
indexBroadcastShapeTorchList);
|
|
|
|
|
x2 = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, x2,
|
|
|
|
|
indexBroadcastShapeTorchList);
|
|
|
|
|
|
|
|
|
|
// Compute the mul of A and B
|
2023-11-23 04:31:06 +08:00
|
|
|
|
Value dotProduct =
|
2023-11-08 15:28:30 +08:00
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
|
|
|
|
ValueRange{dim});
|
|
|
|
|
Value sumDotProduct = rewriter.create<Torch::AtenSumDimIntListOp>(
|
|
|
|
|
loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
|
|
|
|
|
/*keepdim=*/cstFalse,
|
|
|
|
|
/*dtype=*/cstNone);
|
2023-11-23 04:31:06 +08:00
|
|
|
|
|
2023-11-08 15:28:30 +08:00
|
|
|
|
// Compute the norm of A and B
|
2023-11-23 04:31:06 +08:00
|
|
|
|
Value ord = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr(2.0));
|
2023-11-08 15:28:30 +08:00
|
|
|
|
Value normA = rewriter.create<AtenLinalgVectorNormOp>(
|
|
|
|
|
loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
|
|
|
|
|
/*dtype=*/cstNone);
|
|
|
|
|
Value normB = rewriter.create<AtenLinalgVectorNormOp>(
|
|
|
|
|
loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
|
|
|
|
|
/*dtype=*/cstNone);
|
2023-11-23 04:31:06 +08:00
|
|
|
|
|
2023-11-08 15:28:30 +08:00
|
|
|
|
// Compute the product of the norms
|
|
|
|
|
Value normProduct =
|
|
|
|
|
rewriter.create<AtenMulTensorOp>(loc, op.getType(), normA, normB);
|
|
|
|
|
Value normProductClamp = rewriter.create<AtenClampOp>(
|
|
|
|
|
loc, op.getType(), normProduct, op.getEps(), /*max=*/cstNone);
|
|
|
|
|
// Compute the final cosine similarity by division
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivTensorOp>(
|
|
|
|
|
op, op.getType(), sumDotProduct, normProductClamp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-30 16:08:54 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
|
|
|
|
|
// `aten.add.Tensor` op.
|
|
|
|
|
class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenBaddbmmOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value bmm = rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(),
|
|
|
|
|
op.getBatch2());
|
2022-05-30 16:08:54 +08:00
|
|
|
|
Value alphaTimesBmm =
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
|
|
|
|
|
Value input = op.getSelf();
|
2022-05-30 16:08:54 +08:00
|
|
|
|
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
BaseTensorType resultType =
|
|
|
|
|
op->getResult(0).getType().cast<BaseTensorType>();
|
|
|
|
|
if (inputType.hasDtype() && resultType.hasDtype() &&
|
|
|
|
|
inputType.getDtype() != resultType.getDtype()) {
|
|
|
|
|
input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), alphaTimesBmm, op.getSelf(), op.getBeta());
|
2022-05-30 16:08:54 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-09 14:09:28 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.floorDivide` op into `aten.div.TensorMode` op.
|
2022-06-09 14:09:28 +08:00
|
|
|
|
class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenFloorDivideOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// https://pytorch.org/docs/stable/generated/torch.floorDivide.html
|
|
|
|
|
// PyTorch aten.floorDivide is a misnomer because it actually rounds
|
2022-08-06 23:38:06 +08:00
|
|
|
|
// the quotient towards zero instead of taking its floor.
|
2022-06-09 14:09:28 +08:00
|
|
|
|
Value cstStrFloor =
|
2022-08-06 23:38:06 +08:00
|
|
|
|
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
|
2022-06-09 14:09:28 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), op.getSelf(), op.getOther(),
|
|
|
|
|
/*roundingMode=*/cstStrFloor);
|
2022-06-09 14:09:28 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-06-03 20:38:59 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.numpyT` op into `aten.permute` op.
|
2022-06-03 20:38:59 +08:00
|
|
|
|
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNumpyTOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeInputRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-06-03 20:38:59 +08:00
|
|
|
|
|
|
|
|
|
SmallVector<Value> dimListElements;
|
2022-12-13 00:56:28 +08:00
|
|
|
|
SmallVector<int> dimListInts(llvm::reverse(
|
|
|
|
|
llvm::iota_range<int>(0, inputRank, /*inclusive=*/false)));
|
|
|
|
|
for (int dimListInt : dimListInts) {
|
2022-06-03 20:38:59 +08:00
|
|
|
|
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
2022-12-13 00:56:28 +08:00
|
|
|
|
loc, rewriter.getI64IntegerAttr(dimListInt)));
|
|
|
|
|
}
|
2022-06-03 20:38:59 +08:00
|
|
|
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
|
|
|
|
dimListElements);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
|
template <typename OpTy>
|
|
|
|
|
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
2023-02-28 10:32:21 +08:00
|
|
|
|
bool unbiased, double correction) {
|
2022-07-22 20:42:14 +08:00
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value dimList = op.getDim();
|
|
|
|
|
Value keepDim = op.getKeepdim();
|
2022-07-22 20:42:14 +08:00
|
|
|
|
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
|
|
|
|
Type outputType = op.getType();
|
2024-04-11 21:47:35 +08:00
|
|
|
|
BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!outputTensorType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"expected result type to have a dtype");
|
|
|
|
|
}
|
2022-12-08 01:51:37 +08:00
|
|
|
|
Type newOutputType = outputTensorType.getWithSizesAndDtype(
|
2022-07-22 20:42:14 +08:00
|
|
|
|
outputTensorType.getSizes(), rewriter.getF64Type());
|
|
|
|
|
if (!inputTensorTy.hasDtype() ||
|
|
|
|
|
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "support floating-point type input only");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Upcasting the input tensor to `F64` dtype for higher precision during the
|
|
|
|
|
// computation of the result.
|
|
|
|
|
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
|
|
|
|
|
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
|
|
|
|
|
inputTensorTy = self.getType().cast<BaseTensorType>();
|
|
|
|
|
}
|
|
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
2022-12-13 00:56:28 +08:00
|
|
|
|
if (!maybeInputRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
2022-07-22 20:42:14 +08:00
|
|
|
|
SmallVector<Value> dimListElements;
|
|
|
|
|
bool isNoneOrEmpty = true;
|
|
|
|
|
if (!dimList.getType().template isa<Torch::NoneType>()) {
|
|
|
|
|
if (!getListConstructElements(dimList, dimListElements))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expect dimList to be constructed from list construct");
|
|
|
|
|
if (!dimListElements.empty() || inputRank == 0)
|
|
|
|
|
isNoneOrEmpty = false;
|
|
|
|
|
}
|
|
|
|
|
if (isNoneOrEmpty) {
|
|
|
|
|
for (unsigned i = 0; i < inputRank; i++)
|
|
|
|
|
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
dimList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
dimListElements);
|
|
|
|
|
}
|
|
|
|
|
Type meanDimResultType = inputTensorTy;
|
|
|
|
|
for (unsigned i = 0; i < dimListElements.size(); i++)
|
|
|
|
|
meanDimResultType = computeReductionType(
|
2024-04-11 21:47:35 +08:00
|
|
|
|
rewriter, op, cast<BaseTensorType>(meanDimResultType),
|
2022-07-22 20:42:14 +08:00
|
|
|
|
dimListElements[i],
|
|
|
|
|
/*keepDim=*/true);
|
|
|
|
|
|
|
|
|
|
Value constantNone = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value constantTrue = rewriter.create<ConstantBoolOp>(loc, true);
|
|
|
|
|
Value meanAlongDims = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue,
|
|
|
|
|
/*dtype=*/constantNone);
|
|
|
|
|
Value subMean =
|
|
|
|
|
createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims);
|
|
|
|
|
Value square = rewriter.create<AtenSquareOp>(loc, inputTensorTy, subMean);
|
|
|
|
|
|
|
|
|
|
if (!unbiased) {
|
2022-12-08 01:51:37 +08:00
|
|
|
|
Value result = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
|
2022-07-22 20:42:14 +08:00
|
|
|
|
result = convertTensorToDtype(rewriter, loc, result,
|
|
|
|
|
outputTensorType.getDtype());
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
// Divide the square sum by productDimSize - correction.
|
|
|
|
|
Value squareSum = rewriter.create<AtenSumDimIntListOp>(
|
2022-12-08 01:51:37 +08:00
|
|
|
|
loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone);
|
2022-07-22 20:42:14 +08:00
|
|
|
|
|
|
|
|
|
// `productDimSize` is product of sizes of dimensions to be reduced.
|
|
|
|
|
Value constantOne =
|
|
|
|
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value productDimSize = constantOne;
|
|
|
|
|
for (Value dim : dimListElements) {
|
|
|
|
|
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
|
|
|
|
productDimSize =
|
|
|
|
|
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
|
|
|
|
}
|
2023-02-28 10:32:21 +08:00
|
|
|
|
productDimSize = rewriter.create<AtenFloatScalarOp>(loc, productDimSize);
|
|
|
|
|
constantOne = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr(1.0));
|
|
|
|
|
Value cstCorrection = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr(correction));
|
2022-07-22 20:42:14 +08:00
|
|
|
|
// The `correction` value should be less than or equal to `productDimSize +
|
|
|
|
|
// 1`.
|
2023-09-30 07:45:48 +08:00
|
|
|
|
if (!isAssumingStrictSymbolicShapes(rewriter)) {
|
|
|
|
|
Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
|
|
|
|
|
loc, productDimSize.getType(), productDimSize, constantOne);
|
|
|
|
|
Value cond = rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne,
|
|
|
|
|
cstCorrection);
|
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
|
loc, cond,
|
|
|
|
|
"correction value should be less than or equal to productDimSize + 1");
|
|
|
|
|
}
|
2022-07-22 20:42:14 +08:00
|
|
|
|
Value productDimSizeSubCorrection =
|
2023-02-28 10:32:21 +08:00
|
|
|
|
rewriter.create<AtenSubFloatOp>(loc, productDimSize, cstCorrection);
|
2022-12-08 01:51:37 +08:00
|
|
|
|
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
|
|
|
|
|
productDimSizeSubCorrection);
|
2022-07-22 20:42:14 +08:00
|
|
|
|
result =
|
|
|
|
|
convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2022-06-29 15:23:57 +08:00
|
|
|
|
// Decompose aten.var(x, dims) into:
|
|
|
|
|
// sub = aten.sub(x, aten.mean(x, dims))
|
|
|
|
|
// square = aten.square(sub)
|
|
|
|
|
// For Unbiased case:
|
|
|
|
|
// out = aten.sum(square, dims) / (productDimSize-1)
|
|
|
|
|
// For Biased case:
|
|
|
|
|
// out = aten.mean(square, dims)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenVarDimOp : public OpRewritePattern<AtenVarDimOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenVarDimOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
bool unbiased;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getUnbiased(), m_TorchConstantBool(&unbiased))) {
|
2022-06-29 15:23:57 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Only support constant unbiased for aten.var");
|
|
|
|
|
}
|
2023-02-28 10:32:21 +08:00
|
|
|
|
double correction = unbiased ? 1.0 : 0.0;
|
2022-07-22 20:42:14 +08:00
|
|
|
|
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
|
|
|
|
|
correction)))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
2022-06-29 15:23:57 +08:00
|
|
|
|
|
2022-07-22 20:42:14 +08:00
|
|
|
|
// Decompose aten.var(x, dims) into:
|
|
|
|
|
// sub = aten.sub(x, aten.mean(x, dims))
|
|
|
|
|
// square = aten.square(sub)
|
|
|
|
|
// out = aten.sum(square, dims) / (productDimSize - correction)
|
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenVarCorrectionOp
|
|
|
|
|
: public OpRewritePattern<AtenVarCorrectionOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-02-28 10:32:21 +08:00
|
|
|
|
int64_t correctionValInt;
|
|
|
|
|
double correctionValFloat = 1.0;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
|
2023-02-28 10:32:21 +08:00
|
|
|
|
if (op.getCorrection().getType().isa<Torch::FloatType>()) {
|
|
|
|
|
if (!matchPattern(op.getCorrection(),
|
|
|
|
|
m_TorchConstantFloat(&correctionValFloat)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Only support constant int or float correction value for "
|
|
|
|
|
"aten.var");
|
|
|
|
|
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
|
|
|
|
|
if (!matchPattern(op.getCorrection(),
|
|
|
|
|
m_TorchConstantInt(&correctionValInt)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Only support constant int or float correction value for "
|
|
|
|
|
"aten.var");
|
|
|
|
|
correctionValFloat = (double)correctionValInt;
|
|
|
|
|
} else {
|
2022-07-22 20:42:14 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-02-28 10:32:21 +08:00
|
|
|
|
op, "unimplemented: correction value should be only constant int "
|
|
|
|
|
"or float for aten.var");
|
|
|
|
|
}
|
2022-06-29 15:23:57 +08:00
|
|
|
|
}
|
2023-02-28 10:32:21 +08:00
|
|
|
|
|
|
|
|
|
bool unbiased = correctionValFloat == 0.0 ? false : true;
|
2022-07-22 20:42:14 +08:00
|
|
|
|
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
|
2023-02-28 10:32:21 +08:00
|
|
|
|
correctionValFloat)))
|
2022-07-22 20:42:14 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
|
2022-06-29 15:23:57 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-05-10 21:15:59 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose the `aten.selectScatter` operation into `aten.sliceScatter` op.
|
2022-05-10 21:15:59 +08:00
|
|
|
|
class DecomposeAtenSelectScatterOp
|
|
|
|
|
: public OpRewritePattern<AtenSelectScatterOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenSelectScatterOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value start = op.getIndex();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value src = op.getSrc();
|
2022-05-10 21:15:59 +08:00
|
|
|
|
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value startPlusOne =
|
|
|
|
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
|
|
|
|
|
2023-02-03 10:20:47 +08:00
|
|
|
|
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
|
|
|
|
|
if (failed(unsqueezedInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"cannot generate unsqueeze tensor op");
|
2022-05-10 21:15:59 +08:00
|
|
|
|
}
|
2023-02-03 10:20:47 +08:00
|
|
|
|
src = *unsqueezedInfo;
|
2022-05-10 21:15:59 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
|
2022-05-10 21:15:59 +08:00
|
|
|
|
/*step=*/one);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-08-09 06:56:49 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAten_EmbeddingBagOp
|
|
|
|
|
: public OpRewritePattern<Aten_EmbeddingBagOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
Value weight = op.getWeight();
|
|
|
|
|
Value indices = op.getIndices();
|
|
|
|
|
Value offsets = op.getOffsets();
|
|
|
|
|
Value scaleGradByFreq = op.getScaleGradByFreq();
|
|
|
|
|
Value mode = op.getMode();
|
|
|
|
|
Value sparse = op.getSparse();
|
|
|
|
|
Value perSampleWeights = op.getPerSampleWeights();
|
|
|
|
|
Value includeLastOffset = op.getIncludeLastOffset();
|
|
|
|
|
Value paddingIdx = op.getPaddingIdx();
|
2022-08-09 06:56:49 +08:00
|
|
|
|
|
|
|
|
|
auto resultType0 = op->getResult(0).getType();
|
|
|
|
|
auto resultType1 = op->getResult(1).getType();
|
|
|
|
|
auto resultType2 = op->getResult(2).getType();
|
|
|
|
|
auto resultType3 = op->getResult(3).getType();
|
|
|
|
|
|
2024-01-30 01:23:05 +08:00
|
|
|
|
llvm::SmallVector<Type> returnTypes{resultType0, resultType1, resultType2,
|
|
|
|
|
resultType3};
|
2022-08-09 06:56:49 +08:00
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>(
|
|
|
|
|
op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,
|
|
|
|
|
sparse, perSampleWeights, includeLastOffset, paddingIdx);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-09-06 22:07:17 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.liftFreshCopy` op into `aten.clone` op.
|
2022-09-06 22:07:17 +08:00
|
|
|
|
class DecomposeAtenLiftFreshCopyOp
|
|
|
|
|
: public OpRewritePattern<AtenLiftFreshCopyOp> {
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLiftFreshCopyOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value constantNone = rewriter.create<ConstantNoneOp>(op.getLoc());
|
2022-12-08 04:20:41 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenCloneOp>(op, op.getType(), op.getSelf(),
|
2022-09-06 22:07:17 +08:00
|
|
|
|
/*memoryFormat=*/constantNone);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-10-20 19:02:09 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMseLossOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
// The `reduction` arg would have only three valid values.
|
|
|
|
|
// 0 means no reduction.
|
|
|
|
|
// 1 means mean reduction.
|
|
|
|
|
// 2 means sum reduction.
|
|
|
|
|
int64_t reductionType;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reductionType)))
|
2022-10-20 19:02:09 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a constant integer value for reduction");
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
BaseTensorType resultType = op.getType().cast<BaseTensorType>();
|
2022-12-08 04:20:41 +08:00
|
|
|
|
BaseTensorType inputType = op.getSelf().getType().cast<BaseTensorType>();
|
2022-10-20 19:02:09 +08:00
|
|
|
|
if (!inputType.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected the input tensor to have sizes");
|
|
|
|
|
BaseTensorType subType =
|
|
|
|
|
inputType
|
2023-01-25 09:29:42 +08:00
|
|
|
|
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
2023-01-04 06:19:18 +08:00
|
|
|
|
resultType.getOptionalDtype())
|
2022-10-20 19:02:09 +08:00
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value sub =
|
|
|
|
|
createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
|
2022-10-20 19:02:09 +08:00
|
|
|
|
Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
|
|
|
|
|
if (reductionType == torch_upstream::Reduction::None) {
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
if (reductionType == torch_upstream::Reduction::Mean)
|
|
|
|
|
result = rewriter.create<AtenMeanDimOp>(loc, resultType, result,
|
|
|
|
|
/*dim=*/cstNone,
|
|
|
|
|
/*keepdim=*/cstFalse,
|
|
|
|
|
/*dtype=*/cstNone);
|
|
|
|
|
else
|
|
|
|
|
result = rewriter.create<AtenSumDimIntListOp>(
|
|
|
|
|
loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse,
|
|
|
|
|
/*dtype=*/cstNone);
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-02-21 12:08:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
|
|
|
|
|
class DecomposeAtenNormScalarOptDimOp
|
|
|
|
|
: public OpRewritePattern<AtenNormScalarOptDimOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
Value ord = op.getP();
|
|
|
|
|
if (ord.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
ord = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr(2.0));
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenLinalgVectorNormOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
|
|
|
|
|
/*dtype=*/none);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-11-06 20:44:05 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandintLowOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Type resultType = op.getType();
|
2024-04-11 21:47:35 +08:00
|
|
|
|
BaseTensorType resultTensorType = cast<BaseTensorType>(resultType);
|
2023-01-04 06:19:18 +08:00
|
|
|
|
if (!resultTensorType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
2022-11-06 20:44:05 +08:00
|
|
|
|
|
|
|
|
|
int64_t cstLow, cstHigh;
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow)))
|
2022-11-06 20:44:05 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: low must be a constant integer");
|
2022-12-08 04:20:41 +08:00
|
|
|
|
if (!matchPattern(op.getHigh(), m_TorchConstantInt(&cstHigh)))
|
2022-11-06 20:44:05 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: high must be a constant integer");
|
|
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
|
Value low = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)cstLow));
|
|
|
|
|
Value high = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)cstHigh));
|
|
|
|
|
|
|
|
|
|
BaseTensorType floatResultType =
|
|
|
|
|
resultTensorType
|
|
|
|
|
.getWithSizesAndDtype(resultTensorType.getSizes(),
|
|
|
|
|
rewriter.getF32Type())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
loc, floatResultType, op.getSize(), /*dtype=*/none,
|
|
|
|
|
/*layout=*/op.getLayout(),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
|
|
|
|
|
/*memoryFormat=*/none);
|
2022-11-06 20:44:05 +08:00
|
|
|
|
|
|
|
|
|
Value result =
|
|
|
|
|
rewriter.create<AtenUniformOp>(loc, floatResultType, emptyTensor,
|
|
|
|
|
/*from=*/low,
|
|
|
|
|
/*to=*/high,
|
|
|
|
|
/*generator=*/none);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
|
|
|
|
|
op, resultType, result,
|
|
|
|
|
getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*nonBlocking=*/cstFalse, /*copy=*/cstFalse, /*memoryFormat=*/none);
|
2022-11-06 20:44:05 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-04-04 17:31:21 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRandintOp : public OpRewritePattern<AtenRandintOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandintOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Type resultType = op.getType();
|
|
|
|
|
|
|
|
|
|
Value low = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
2023-11-01 11:56:54 +08:00
|
|
|
|
|
2023-04-04 17:31:21 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
|
|
|
|
|
2023-04-04 17:31:21 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-11-15 22:39:40 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `aten.varMean.correction` op into `aten.var.correction` and
|
2022-11-15 22:39:40 +08:00
|
|
|
|
// `aten.mean.dim` op.
|
|
|
|
|
class DecomposeAtenVarMeanCorrectionOp
|
|
|
|
|
: public OpRewritePattern<AtenVarMeanCorrectionOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenVarMeanCorrectionOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value var = rewriter.create<AtenVarCorrectionOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
|
|
|
|
|
op.getKeepdim());
|
|
|
|
|
Value mean = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
|
|
|
|
|
/*dtype=*/noneVal);
|
2022-11-15 22:39:40 +08:00
|
|
|
|
rewriter.replaceOp(op, {var, mean});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-11-21 16:38:47 +08:00
|
|
|
|
namespace {
|
2022-12-08 04:20:41 +08:00
|
|
|
|
// Decompose `prims.convertElementType` op into `aten.to.dtype` op.
|
2022-11-21 16:38:47 +08:00
|
|
|
|
class DecomposePrimsConvertElementTypeOp
|
|
|
|
|
: public OpRewritePattern<PrimsConvertElementTypeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(PrimsConvertElementTypeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), op.getA(), op.getDtype(), /*nonBlocking=*/cstFalse,
|
|
|
|
|
/*copy=*/cstFalse, /*memoryFormat=*/cstNone);
|
2022-11-21 16:38:47 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-01-11 14:01:45 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `prims.var` op into `aten.var.correction` op.
|
|
|
|
|
class DecomposePrimsVarOp : public OpRewritePattern<PrimsVarOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(PrimsVarOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented non-None dtype for prims::var op");
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenVarCorrectionOp>(
|
|
|
|
|
op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(),
|
|
|
|
|
/*keepdim=*/cstFalse);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `prims.sqrt` op into `aten.sqrt` op.
|
|
|
|
|
class DecomposePrimsSqrtOp : public OpRewritePattern<PrimsSqrtOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(PrimsSqrtOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), op.getSelf());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-11-14 15:08:13 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// The op is decomposed using the Box-Muller transform.
|
|
|
|
|
// Refer: https://en.wikipedia.org/wiki/Box-Muller_transform
|
|
|
|
|
class DecomposeAtenRandnGeneratorOp
|
|
|
|
|
: public OpRewritePattern<AtenRandnGeneratorOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2023-04-19 23:25:39 +08:00
|
|
|
|
auto resultType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
if (!resultType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
2022-11-14 15:08:13 +08:00
|
|
|
|
|
2023-04-19 23:25:39 +08:00
|
|
|
|
Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
|
2022-11-14 15:08:13 +08:00
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value low = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)0.0));
|
|
|
|
|
Value high = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)1.0));
|
|
|
|
|
Value cstMinusTwo = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)-2.0));
|
|
|
|
|
Value cstTwoPie = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
|
|
|
|
|
|
|
|
|
|
Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
|
2023-04-19 23:25:39 +08:00
|
|
|
|
loc, resultType, op.getSize(), /*dtype=*/dtype,
|
|
|
|
|
/*layout=*/op.getLayout(),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
|
2022-11-14 15:08:13 +08:00
|
|
|
|
/*memory_format=*/none);
|
|
|
|
|
Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
|
2023-04-19 23:25:39 +08:00
|
|
|
|
loc, resultType, op.getSize(), /*dtype=*/dtype,
|
|
|
|
|
/*layout=*/op.getLayout(),
|
2022-12-08 04:20:41 +08:00
|
|
|
|
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
|
2022-11-14 15:08:13 +08:00
|
|
|
|
/*memory_format=*/none);
|
|
|
|
|
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Value uOne =
|
|
|
|
|
rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
|
|
|
|
|
/*from=*/low,
|
|
|
|
|
/*to=*/high,
|
|
|
|
|
/*generator=*/op.getGenerator());
|
|
|
|
|
Value uTwo =
|
|
|
|
|
rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
|
|
|
|
|
/*from=*/low,
|
|
|
|
|
/*to=*/high,
|
|
|
|
|
/*generator=*/op.getGenerator());
|
2022-11-14 15:08:13 +08:00
|
|
|
|
|
|
|
|
|
Value logUOne = rewriter.create<AtenLogOp>(loc, resultType, uOne);
|
|
|
|
|
Value minusTwoLogUOne =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resultType, logUOne, cstMinusTwo);
|
|
|
|
|
Value r = rewriter.create<AtenSqrtOp>(loc, resultType, minusTwoLogUOne);
|
|
|
|
|
Value theta =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, resultType, uTwo, cstTwoPie);
|
|
|
|
|
Value cosTheta = rewriter.create<AtenCosOp>(loc, resultType, theta);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, op.getType(), r, cosTheta);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.randn` op into `aten.randn.generator` op.
|
|
|
|
|
class DecomposeAtenRandnOp : public OpRewritePattern<AtenRandnOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandnOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
|
op, op.getType(), op.getSize(), /*generator=*/none, op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
2022-11-14 15:08:13 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-01-16 19:40:21 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.randn_like` op into `aten.randn.generator` op.
|
|
|
|
|
class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
|
|
|
|
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
int64_t memoryFormat;
|
|
|
|
|
if (!matchPattern(op.getMemoryFormat(),
|
|
|
|
|
m_TorchConstantInt(&memoryFormat)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: the memory format should be specified in "
|
|
|
|
|
"an integer constant");
|
|
|
|
|
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
|
|
|
|
memoryFormat != torch_upstream::MemoryFormat::Preserve)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: only none, contiguous and preserve "
|
|
|
|
|
"memory_format is supported");
|
|
|
|
|
}
|
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
|
|
|
|
auto sizeListType =
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
|
|
|
|
Value sizeList =
|
|
|
|
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
|
|
|
|
|
op, op.getType(), sizeList, /*generator=*/none, op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-09-02 02:13:58 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenRandOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto resultType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
|
|
|
|
|
if (!resultType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected result type to have a dtype");
|
|
|
|
|
}
|
|
|
|
|
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
Value low = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)0.0));
|
|
|
|
|
Value high = rewriter.create<Torch::ConstantFloatOp>(
|
|
|
|
|
loc, rewriter.getF64FloatAttr((double)1.0));
|
|
|
|
|
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
|
|
|
|
loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
|
|
|
|
|
/*layout=*/op.getLayout(),
|
|
|
|
|
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
|
|
|
|
|
/*memory_format=*/noneVal);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
|
|
|
|
|
/*from=*/low,
|
|
|
|
|
/*to=*/high,
|
|
|
|
|
/*generator=*/noneVal);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-14 08:28:33 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern<AtenLinspaceOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLinspaceOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = getContext();
|
|
|
|
|
|
|
|
|
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
|
Value zero =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
|
|
|
|
|
Value addStart;
|
|
|
|
|
int64_t steps;
|
|
|
|
|
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
|
|
|
|
|
// specically handle steps == 1
|
|
|
|
|
Value arange = rewriter.create<AtenArangeStartOp>(
|
|
|
|
|
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory());
|
|
|
|
|
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
|
|
|
|
|
op.getStart(), one);
|
|
|
|
|
} else {
|
|
|
|
|
// handle steps != 1 or dynamic steps
|
|
|
|
|
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
|
|
|
|
|
rewriter.create<RuntimeAssertOp>(
|
|
|
|
|
loc, neOrNot,
|
|
|
|
|
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
|
|
|
|
|
// create arange: [0, ..., steps - 1]
|
|
|
|
|
Value arange = rewriter.create<AtenArangeStartOp>(
|
|
|
|
|
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory());
|
|
|
|
|
// calculate (end - start) / (steps - 1)
|
|
|
|
|
Value sub;
|
|
|
|
|
if (op.getEnd().getType().isa<Torch::FloatType>() ||
|
|
|
|
|
op.getStart().getType().isa<Torch::FloatType>()) {
|
|
|
|
|
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
|
|
|
|
|
op.getEnd(), op.getStart());
|
|
|
|
|
} else {
|
|
|
|
|
sub = rewriter.create<AtenSubIntOp>(loc, op.getEnd(), op.getStart());
|
|
|
|
|
}
|
|
|
|
|
Value div = rewriter.create<AtenDivOp>(
|
|
|
|
|
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
|
|
|
|
|
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
|
|
|
|
|
Value mulScalar =
|
|
|
|
|
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
|
|
|
|
|
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
|
|
|
|
|
op.getStart(), one);
|
|
|
|
|
}
|
|
|
|
|
// to dtype
|
|
|
|
|
Value result;
|
|
|
|
|
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
|
|
|
|
result = rewriter.create<AtenToDtypeOp>(
|
|
|
|
|
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
|
|
|
|
/*copy=*/falseVal, /*memory_format=*/none);
|
|
|
|
|
} else {
|
|
|
|
|
Value f32Type = rewriter.create<ConstantIntOp>(
|
|
|
|
|
loc, (int)torch_upstream::ScalarType::Float);
|
|
|
|
|
result = rewriter.create<AtenToDtypeOp>(
|
|
|
|
|
loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal,
|
|
|
|
|
/*copy=*/falseVal, /*memory_format=*/none);
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-09 23:22:26 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenVarMeanOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value var = rewriter.create<AtenVarDimOp>(loc, op.getType(0), op.getSelf(),
|
|
|
|
|
/*dim=*/noneVal, op.getUnbiased(),
|
|
|
|
|
/*keepdim=*/falseVal);
|
|
|
|
|
Value mean = rewriter.create<AtenMeanOp>(loc, op.getType(0), op.getSelf(),
|
|
|
|
|
/*dtype=*/noneVal);
|
|
|
|
|
rewriter.replaceOp(op, {var, mean});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-12-29 22:52:23 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenNewEmptyStridedOp
|
|
|
|
|
: public OpRewritePattern<AtenNewEmptyStridedOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value opSize = op.getSize();
|
|
|
|
|
Value opStride = op.getStride();
|
2022-12-29 22:52:23 +08:00
|
|
|
|
|
2023-11-01 11:56:54 +08:00
|
|
|
|
if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
|
2022-12-29 22:52:23 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, "Unable to determine if stride is default");
|
2022-12-29 22:52:23 +08:00
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
|
|
|
|
|
op.getLayout(), op.getDevice(), op.getPinMemory());
|
2023-11-01 11:56:54 +08:00
|
|
|
|
|
2022-12-29 22:52:23 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-09-14 01:04:31 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenEmptyStridedOp
|
|
|
|
|
: public OpRewritePattern<AtenEmptyStridedOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
2023-11-01 11:56:54 +08:00
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value opSize = op.getSize();
|
|
|
|
|
Value opStride = op.getStride();
|
|
|
|
|
|
|
|
|
|
if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
|
2023-09-14 01:04:31 +08:00
|
|
|
|
return rewriter.notifyMatchFailure(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, "Unable to determine if stride is default");
|
2023-09-14 01:04:31 +08:00
|
|
|
|
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
2023-11-01 11:56:54 +08:00
|
|
|
|
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
|
2023-09-14 01:04:31 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-11-16 13:57:58 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(PrimsSqueezeOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getA();
|
|
|
|
|
SmallVector<int64_t> dimensions;
|
|
|
|
|
if (!matchPattern(op.getDimensions(),
|
|
|
|
|
m_TorchListOfConstantInts(dimensions)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "all dimensions must be constant ints");
|
|
|
|
|
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
std::sort(dimensions.rbegin(), dimensions.rend());
|
2022-11-16 13:57:58 +08:00
|
|
|
|
|
|
|
|
|
if (dimensions.size() == 0) {
|
|
|
|
|
rewriter.replaceOp(op, input);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
Value result = input;
|
|
|
|
|
for (unsigned i = 0; i < dimensions.size(); i++) {
|
|
|
|
|
auto squeezeTensorInfo =
|
|
|
|
|
squeezeTensor(rewriter, op, loc, dimensions[i], result);
|
|
|
|
|
if (failed(squeezeTensorInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"cannot generate unsqueeze tensor");
|
|
|
|
|
}
|
|
|
|
|
result = *squeezeTensorInfo;
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-01-02 22:34:39 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenMovedimIntOp : public OpRewritePattern<AtenMovedimIntOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenMovedimIntOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
|
|
|
|
if (!maybeInputRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected input tensor to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned inputRank = *maybeInputRank;
|
|
|
|
|
if (inputRank <= 1) {
|
|
|
|
|
rewriter.replaceOp(op, input);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t srcDimInt, dstDimInt;
|
|
|
|
|
if (matchPattern(op.getSource(), m_TorchConstantInt(&srcDimInt))) {
|
|
|
|
|
srcDimInt = toPositiveDim(srcDimInt, inputRank);
|
|
|
|
|
if (!isValidDim(srcDimInt, inputRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "source is not a valid dim");
|
|
|
|
|
} else {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "source is not a constant int");
|
|
|
|
|
}
|
|
|
|
|
if (matchPattern(op.getDestination(), m_TorchConstantInt(&dstDimInt))) {
|
|
|
|
|
dstDimInt = toPositiveDim(dstDimInt, inputRank);
|
|
|
|
|
if (!isValidDim(dstDimInt, inputRank))
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"destination is not a valid dim");
|
|
|
|
|
} else {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"destination is not a constant int");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> dimsOrder =
|
|
|
|
|
computeDimsOrderForMoveDim(srcDimInt, dstDimInt, inputRank);
|
|
|
|
|
SmallVector<Value> cstDimsOrder;
|
|
|
|
|
for (int64_t dim : dimsOrder)
|
|
|
|
|
cstDimsOrder.push_back(rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(dim)));
|
|
|
|
|
Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
|
|
|
|
cstDimsOrder);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), input,
|
|
|
|
|
permuteDimsOrder);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-03-15 16:00:03 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenCrossEntropyLossOp
|
|
|
|
|
: public OpRewritePattern<AtenCrossEntropyLossOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenCrossEntropyLossOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value target = op.getTarget();
|
|
|
|
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
|
|
|
|
if (!maybeRank)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: unranked input tensor");
|
|
|
|
|
unsigned selfRank = maybeRank.value();
|
|
|
|
|
maybeRank = getTensorRank(target);
|
|
|
|
|
if (!maybeRank)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: unranked target tensor");
|
|
|
|
|
unsigned targetRank = maybeRank.value();
|
|
|
|
|
|
|
|
|
|
// When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
|
|
|
|
|
// of the form [minibatch] the cross entropy loss decomposes to the
|
|
|
|
|
// combination of softmax and nll loss as follows:
|
|
|
|
|
// cross_entropy_loss = NLLLoss(LogSoftmax(input, dim=1), target)
|
|
|
|
|
// Currently, we only support the above-mentioned case.
|
|
|
|
|
if (selfRank != 2 || targetRank != 1) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op,
|
|
|
|
|
"unimplemented: only support cases with 2-d input and 1-d target");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: Add support for label_smoothing value other than 0.0 (default
|
|
|
|
|
// value).
|
|
|
|
|
double labelSmoothing;
|
|
|
|
|
if (!matchPattern(op.getLabelSmoothing(),
|
|
|
|
|
m_TorchConstantFloat(&labelSmoothing))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Only support constant float label_smoothing value");
|
|
|
|
|
} else if (labelSmoothing != 0.0) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"unimplemented: only support default "
|
|
|
|
|
"value of 0.0 for label_smoothing");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value logSoftmax = rewriter.create<AtenLogSoftmaxIntOp>(
|
|
|
|
|
loc, self.getType(), self, dim, /*dtype=*/noneVal);
|
|
|
|
|
Value nllLoss =
|
|
|
|
|
rewriter
|
|
|
|
|
.create<AtenNllLossForwardOp>(
|
|
|
|
|
loc, op.getType(), target.getType(), logSoftmax, target,
|
|
|
|
|
op.getWeight(), op.getReduction(), op.getIgnoreIndex())
|
|
|
|
|
->getResult(0);
|
|
|
|
|
rewriter.replaceOp(op, nllLoss);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-04-11 16:02:28 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|
|
|
|
using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenOneHotOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto context = op.getContext();
|
|
|
|
|
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasSizes())
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "input tensor should have known sizes.");
|
|
|
|
|
int64_t inputRank = inputType.getSizes().size();
|
|
|
|
|
int64_t numClasses;
|
|
|
|
|
if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: num_classes must be constant");
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
|
|
|
|
|
// arange tensor
|
|
|
|
|
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
|
|
|
|
auto arangeType =
|
|
|
|
|
ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type);
|
|
|
|
|
Value arangeTensor = rewriter.create<AtenArangeOp>(
|
|
|
|
|
loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none);
|
|
|
|
|
|
|
|
|
|
// unsqueeze input
|
|
|
|
|
llvm::SmallVector<int64_t> unsqueezeShape(inputType.getSizes());
|
|
|
|
|
unsqueezeShape.push_back(1);
|
|
|
|
|
auto unsqueezeType =
|
|
|
|
|
ValueTensorType::get(context, unsqueezeShape, si64Type);
|
|
|
|
|
Value unsqueezeTensor = rewriter.create<AtenUnsqueezeOp>(
|
|
|
|
|
loc, unsqueezeType, input,
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc,
|
|
|
|
|
rewriter.getI64IntegerAttr(inputRank)));
|
|
|
|
|
|
|
|
|
|
// compare
|
|
|
|
|
auto eqType = ValueTensorType::get(
|
|
|
|
|
context, op.getType().cast<BaseTensorType>().getSizes(),
|
|
|
|
|
IntegerType::get(context, 1));
|
|
|
|
|
Value eqTensor = rewriter.create<AtenEqTensorOp>(
|
|
|
|
|
loc, eqType, unsqueezeTensor, arangeTensor);
|
|
|
|
|
|
|
|
|
|
// convert to si64
|
2023-06-23 16:11:33 +08:00
|
|
|
|
Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
|
2023-04-11 16:02:28 +08:00
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-04-26 15:14:06 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.var_mean.dim` op into `aten.var.dim` and
|
|
|
|
|
// `aten.mean.dim` op.
|
|
|
|
|
class DecomposeAtenVarMeanDimOp : public OpRewritePattern<AtenVarMeanDimOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenVarMeanDimOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value var = rewriter.create<AtenVarDimOp>(loc, op.getType(0), op.getSelf(),
|
|
|
|
|
op.getDim(), op.getUnbiased(),
|
|
|
|
|
op.getKeepdim());
|
|
|
|
|
Value mean = rewriter.create<AtenMeanDimOp>(
|
|
|
|
|
loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
|
|
|
|
|
/*dtype=*/noneVal);
|
|
|
|
|
rewriter.replaceOp(op, {var, mean});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-06-01 11:38:50 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// decompose aten.scalar_tensor to prim.NumToTensor.Scalar and
|
|
|
|
|
// aten.to.dtype_layout
|
|
|
|
|
class DecomposeAtenScalarTensor : public OpRewritePattern<AtenScalarTensorOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenScalarTensorOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
auto resultTy = op.getResult().getType().cast<BaseTensorType>();
|
|
|
|
|
auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
|
|
|
|
|
Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
|
|
|
|
op.getLoc(),
|
|
|
|
|
resultTy.getWithSizesAndDtype(resultTy.getOptionalSizes(), scalarTy),
|
|
|
|
|
op.getS());
|
|
|
|
|
|
|
|
|
|
Value cstNone = rewriter.create<ConstantNoneOp>(op.getLoc());
|
|
|
|
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
|
|
|
|
Value dtype =
|
|
|
|
|
getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype());
|
|
|
|
|
Value toDTypeLayout = rewriter.create<AtenToDtypeLayoutOp>(
|
|
|
|
|
op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(),
|
|
|
|
|
op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse,
|
|
|
|
|
/*copy=*/cstFalse, /*memory_format=*/cstNone);
|
|
|
|
|
rewriter.replaceOp(op, toDTypeLayout);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-05-02 21:29:00 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op.
|
|
|
|
|
class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTopkOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
auto context = op.getContext();
|
|
|
|
|
|
|
|
|
|
bool sorted;
|
|
|
|
|
if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Expected a constant boolean value for sorted");
|
|
|
|
|
if (!sorted)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "unimplemented: sorted value arg must be set to True");
|
|
|
|
|
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value dim = op.getDim();
|
|
|
|
|
auto selfType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
auto sortIndicesType = selfType.getWithSizesAndDtype(
|
|
|
|
|
selfType.getOptionalSizes(),
|
|
|
|
|
IntegerType::get(context, 64, IntegerType::Signed));
|
|
|
|
|
auto sortOpResult = rewriter.create<AtenSortOp>(
|
|
|
|
|
loc, self.getType(), sortIndicesType, self, dim,
|
|
|
|
|
/*descending=*/op.getLargest());
|
|
|
|
|
Value start = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(0));
|
|
|
|
|
Value step = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
Value resultValue = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
|
loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start,
|
|
|
|
|
/*end=*/op.getK(), step);
|
|
|
|
|
Value resultIndices = rewriter.create<AtenSliceTensorOp>(
|
|
|
|
|
loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start,
|
|
|
|
|
/*end=*/op.getK(), step);
|
|
|
|
|
rewriter.replaceOp(op, {resultValue, resultIndices});
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2022-10-16 05:46:06 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
|
|
|
|
|
class DecomposeAtenScatterValueOp
|
|
|
|
|
: public OpRewritePattern<AtenScatterValueOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenScatterValueOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = op.getContext();
|
|
|
|
|
Value self = op.getSelf();
|
|
|
|
|
Value index = op.getIndex();
|
|
|
|
|
std::optional<unsigned> maybeIndexRank = getTensorRank(index);
|
|
|
|
|
if (!maybeIndexRank) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "expected index tensor to have a rank");
|
|
|
|
|
}
|
|
|
|
|
unsigned indexRank = *maybeIndexRank;
|
|
|
|
|
SmallVector<Value> sizes;
|
|
|
|
|
for (int64_t i = 0; i < indexRank; ++i) {
|
|
|
|
|
Value dim =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
|
|
|
|
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, index, /*dim=*/dim));
|
|
|
|
|
}
|
|
|
|
|
Value sizeList = rewriter.create<PrimListConstructOp>(
|
|
|
|
|
loc, ListType::get(IntType::get(context)), sizes);
|
|
|
|
|
|
|
|
|
|
auto selfType = self.getType().cast<BaseTensorType>();
|
|
|
|
|
auto indexType = index.getType().cast<BaseTensorType>();
|
|
|
|
|
BaseTensorType srcType =
|
|
|
|
|
selfType
|
|
|
|
|
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
|
|
|
|
selfType.getOptionalDtype())
|
|
|
|
|
.cast<BaseTensorType>();
|
|
|
|
|
Value src =
|
|
|
|
|
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
|
|
|
|
|
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
|
|
|
|
|
op.getDim(), index, src);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-06-11 02:45:35 +08:00
|
|
|
|
namespace {
|
2024-04-08 20:05:42 +08:00
|
|
|
|
// Decompose `aten.sgn` op into comparisons and aten.where.
|
|
|
|
|
class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
|
2023-06-11 02:45:35 +08:00
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
2024-04-08 20:05:42 +08:00
|
|
|
|
LogicalResult matchAndRewrite(AtenSgnOp op,
|
2023-06-11 02:45:35 +08:00
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
2024-04-08 20:05:42 +08:00
|
|
|
|
auto outType = op.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!outType.hasDtype()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"expected result type to have dtype");
|
|
|
|
|
}
|
|
|
|
|
// TODO: support complex type in future.
|
|
|
|
|
if (outType.getDtype().isa<mlir::ComplexType>()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
|
"doesn't support complex type now");
|
|
|
|
|
}
|
2023-06-11 02:45:35 +08:00
|
|
|
|
|
|
|
|
|
auto zero =
|
2024-04-08 20:05:42 +08:00
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
2023-06-11 02:45:35 +08:00
|
|
|
|
auto one =
|
2024-04-08 20:05:42 +08:00
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
2023-06-11 02:45:35 +08:00
|
|
|
|
auto minusOne =
|
2024-04-08 20:05:42 +08:00
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
2023-06-11 02:45:35 +08:00
|
|
|
|
|
|
|
|
|
auto compTy = outType.getWithSizesAndDtype(outType.getOptionalSizes(),
|
|
|
|
|
rewriter.getI1Type());
|
|
|
|
|
|
|
|
|
|
auto greater =
|
|
|
|
|
rewriter.create<AtenGtScalarOp>(loc, compTy, op.getSelf(), zero);
|
2024-04-08 20:05:42 +08:00
|
|
|
|
auto less =
|
|
|
|
|
rewriter.create<AtenLtScalarOp>(loc, compTy, op.getSelf(), zero);
|
2023-06-11 02:45:35 +08:00
|
|
|
|
|
|
|
|
|
// Pseudo code:
|
2024-04-08 20:05:42 +08:00
|
|
|
|
// if (in > 0)
|
2023-06-11 02:45:35 +08:00
|
|
|
|
// return 1
|
2024-04-08 20:05:42 +08:00
|
|
|
|
// else if (in < 0)
|
2023-06-11 02:45:35 +08:00
|
|
|
|
// return -1
|
2024-04-08 20:05:42 +08:00
|
|
|
|
// else
|
|
|
|
|
// return 0
|
|
|
|
|
// note: return 0 if nan/0.0/-0.0
|
|
|
|
|
// return 1 if inf
|
|
|
|
|
// return -1 if -inf
|
2023-06-11 02:45:35 +08:00
|
|
|
|
auto selectGreater =
|
|
|
|
|
rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
|
|
|
|
|
|
2024-04-08 20:05:42 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<AtenWhereScalarSelfOp>(op, outType, less,
|
|
|
|
|
minusOne, selectGreater);
|
2023-06-11 02:45:35 +08:00
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-07-20 09:51:58 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
|
|
|
|
|
// `torch.to.dtype`.
|
|
|
|
|
class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTypeAsOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto input = op.getSelf();
|
|
|
|
|
auto other = op.getOther();
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
|
|
|
|
|
Value targetDtype = rewriter.create<Torch::PrimDtypeOp>(loc, other);
|
|
|
|
|
Value nonBlocking = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value copy = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
|
|
|
Value memoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
|
|
|
|
|
op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-08-15 19:36:08 +08:00
|
|
|
|
// AtenIndexTensorOp
|
|
|
|
|
namespace {
|
|
|
|
|
// The goal of this pattern is to eliminate none index in aten.Index.Tensor's
|
|
|
|
|
// `indices` param for the ease of various backend. The detailed steps are:
|
|
|
|
|
// 1. reorder input tensor so that the non-none index appears at adjacent
|
|
|
|
|
// positions.
|
|
|
|
|
// 2. manually generate index tensor with some ops like iota, to replace the
|
|
|
|
|
// none index in `indices`
|
|
|
|
|
// 3. replace the old aten.Index.Tensor with a new
|
|
|
|
|
// aten.Index.Tensor_hacked_twin.
|
|
|
|
|
class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
|
|
|
|
|
// TODO: It might be better to use aten.view op instead of mulitple
|
|
|
|
|
// aten.unsqueeze. But currently, torch-to-linalg pass has limited support for
|
|
|
|
|
// view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op will
|
|
|
|
|
// cause relevant e2e tests fail.
|
|
|
|
|
static FailureOr<Value>
|
|
|
|
|
unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter,
|
|
|
|
|
Value input, int count) {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(-1));
|
|
|
|
|
Value result = input;
|
|
|
|
|
while (count--) {
|
|
|
|
|
auto unsqzTensorInfo =
|
|
|
|
|
unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
|
|
|
|
|
if (failed(unsqzTensorInfo)) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
result = *unsqzTensorInfo;
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Value createIndexToReplaceNone(Operation *op,
|
|
|
|
|
PatternRewriter &rewriter, Value input,
|
|
|
|
|
int dimInt, int64_t dimSize) {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
MLIRContext *context = op->getContext();
|
|
|
|
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
|
|
|
auto int64Dtype = getDtypeIntValueForType(
|
|
|
|
|
rewriter, loc,
|
|
|
|
|
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
|
|
|
|
|
|
|
|
|
auto resultType = ValueTensorType::get(
|
|
|
|
|
context, {dimSize},
|
|
|
|
|
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
|
|
|
|
auto dim = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
loc, rewriter.getI64IntegerAttr(dimInt));
|
|
|
|
|
auto end = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
|
|
|
|
auto v = rewriter.create<Torch::AtenArangeOp>(
|
|
|
|
|
loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
|
|
|
|
|
/*device=*/none, /*pin_memory=*/none);
|
|
|
|
|
return v;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Rewrites `aten.index.Tensor` into `aten.index.Tensor_hacked_twin`, whose
// index list contains only tensors.
//  - If every index is already a tensor, the list is rebuilt and the op is
//    replaced directly.
//  - Otherwise, each `None` index is replaced by an index tensor produced by
//    `createIndexToReplaceNone` (an `aten.arange` over that input dim), and
//    every index is passed through `unsqueezeTensorAtTrailingDim` (defined
//    outside this view; presumably it appends the requested number of
//    trailing unit dims so all indices broadcast to the output rank -- TODO
//    confirm).
//  - When the tensor indices are not adjacent, the input is first permuted
//    with `aten.permute` so that all indexed dims come first.
LogicalResult matchAndRewrite(AtenIndexTensorOp op,
                              PatternRewriter &rewriter) const override {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  SmallVector<Value> indices;
  if (!getListConstructElements(op.getIndices(), indices))
    return rewriter.notifyMatchFailure(op,
                                       "failed to get elements of `indices`");

  auto input = op.getSelf();
  auto inputType = input.getType().cast<BaseTensorType>();
  if (!inputType.hasSizes()) {
    return rewriter.notifyMatchFailure(
        op, "only input with shape information is supported");
  }
  auto inputSizes = inputType.getSizes();
  int64_t inputRank = inputSizes.size();
  auto outputType = op.getType().cast<BaseTensorType>();
  if (!outputType.hasSizes()) {
    return rewriter.notifyMatchFailure(
        op, "only output with shape information is supported");
  }
  auto outputRank = outputType.getSizes().size();

  auto isTensor = [](Value v) {
    return v.getType().isa<Torch::BaseTensorType>();
  };

  // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
  if (llvm::all_of(indices, isTensor)) {
    if (indices.size() == 0) {
      return rewriter.notifyMatchFailure(
          op, "the indices is empty, it should be folded as a nop");
    }
    // By default, we regard the first index type as the list element type.
    auto indexElemType = indices[0]
                             .getType()
                             .template cast<BaseTensorType>()
                             .getWithSizesAndDtype(std::nullopt, nullptr);
    auto newIndex = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(indexElemType), indices);
    rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(op, op.getType(),
                                                             input, newIndex);
    return success();
  }

  // Every bookkeeping structure below (`indexUsed`, `dims`, `listElements`)
  // is sized by the input rank. Without this check, indices beyond the rank
  // would be silently dropped and a wrong rewrite produced. Bail out before
  // any IR is created.
  if (static_cast<int64_t>(indices.size()) > inputRank) {
    return rewriter.notifyMatchFailure(
        op, "the number of indices exceeds the rank of the input tensor");
  }

  // indexUsed[d] is true iff input dim `d` is indexed by a tensor (as opposed
  // to `None` or an omitted trailing index).
  SmallVector<bool> indexUsed =
      llvm::to_vector(llvm::map_range(indices, isTensor));
  for (int64_t i = indices.size(); i < inputRank; ++i)
    indexUsed.emplace_back(false);
  // Tensor indices are "consecutive" when they occupy one contiguous run of
  // dims.
  bool indexIsConsecutive = true;
  int64_t firstUsedIndex = -1;
  for (size_t i = 0; i < indices.size(); ++i) {
    if (indexUsed[i] && firstUsedIndex == -1) {
      firstUsedIndex = i;
    } else if (indexUsed[i] && !indexUsed[i - 1]) {
      indexIsConsecutive = false;
      break;
    }
  }

  // use aten.permute to reorder the input
  Value newInput;
  // `dims` stores the mapping from new index to the old index of input
  // tensor.
  SmallVector<int64_t> dims;
  if (!indexIsConsecutive) {
    // Move all tensor-indexed dims to the front (in their original order),
    // followed by the remaining dims.
    SmallVector<Value> dimValues;
    SmallVector<int64_t> permutedSizes;
    for (int i = 0; i < inputRank; i++) {
      if (indexUsed[i]) {
        dims.emplace_back(i);
        dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(i)));
        permutedSizes.emplace_back(inputSizes[i]);
      }
    }
    for (int i = 0; i < inputRank; i++) {
      if (!indexUsed[i]) {
        dims.emplace_back(i);
        dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(i)));
        permutedSizes.emplace_back(inputSizes[i]);
      }
    }
    auto dimValueList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)), dimValues);
    newInput = rewriter.create<Torch::AtenPermuteOp>(
        loc,
        inputType.getWithSizesAndDtype(permutedSizes,
                                       inputType.getOptionalDtype()),
        input, dimValueList);
  } else {
    newInput = input;
    for (int i = 0; i < inputRank; i++) {
      dims.emplace_back(i);
    }
  }

  // manually generate new indices.
  SmallVector<Value> listElements(inputRank);

  int64_t trailingDimCnt = 0;
  int64_t i;
  // handle trailing none index.
  for (i = inputRank - 1; i >= 0; --i) {
    int64_t oldI = dims[i];
    if (indexUsed[oldI])
      break;
    Value v =
        createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
    auto vInfo =
        unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt);
    if (failed(vInfo)) {
      return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
    }
    listElements[i] = *vInfo;
    trailingDimCnt++;
  }
  // handle non-none index in between.
  for (; i >= 0; --i) {
    int64_t oldI = dims[i];
    if (!indexUsed[oldI])
      break;
    auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI],
                                              trailingDimCnt);
    if (failed(vInfo)) {
      return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
    }
    listElements[i] = *vInfo;
  }

  // handle possible leading none dimensions.
  for (; i >= 0; --i) {
    int64_t oldI = dims[i];
    if (indexUsed[oldI]) {
      return rewriter.notifyMatchFailure(
          op, "the indices are still unconsecutive after reordering input "
              "tensor");
    }
    Value v =
        createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
    // A leading dim's index must be followed by enough unit dims to cover
    // every later output dim.
    auto vInfo =
        unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i);
    if (failed(vInfo)) {
      return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
    }
    listElements[i] = *vInfo;
  }

  auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
  auto newIndexList = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(listElemType), listElements);
  rewriter.replaceOpWithNewOp<Torch::AtenIndexTensorHackedTwinOp>(
      op, op.getType(), newInput, newIndexList);
  return success();
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-08-04 09:05:34 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Unconditionally decompose `aten.tile` into `aten.repeat`.
|
|
|
|
|
class DecomposeAtenTileOp : public OpRewritePattern<AtenTileOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenTileOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
auto input = op.getSelf();
|
|
|
|
|
auto repeats = op.getDims();
|
|
|
|
|
SmallVector<Value> dimsElements;
|
|
|
|
|
if (!getListConstructElements(repeats, dimsElements)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to get elements of `dims` param");
|
|
|
|
|
}
|
|
|
|
|
auto dimsSize = dimsElements.size();
|
|
|
|
|
auto inputType = input.getType().cast<BaseTensorType>();
|
|
|
|
|
if (!inputType.hasSizes()) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only support input tensor with shape information");
|
|
|
|
|
}
|
|
|
|
|
auto inputRank = inputType.getSizes().size();
|
|
|
|
|
if (dimsSize < inputRank) {
|
|
|
|
|
auto constantOne = rewriter.create<Torch::ConstantIntOp>(
|
|
|
|
|
op.getLoc(), rewriter.getI64IntegerAttr(1));
|
|
|
|
|
for (auto i = dimsSize; i < inputRank; ++i) {
|
|
|
|
|
dimsElements.insert(dimsElements.begin(), constantOne);
|
|
|
|
|
}
|
|
|
|
|
repeats = rewriter.create<Torch::PrimListConstructOp>(
|
|
|
|
|
op.getLoc(),
|
|
|
|
|
Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
|
|
|
|
dimsElements);
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenRepeatOp>(op, op.getType(), input,
|
|
|
|
|
repeats);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-11-05 11:38:36 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Unconditionally decompose `aten.reshape_as` into `aten.size` +
|
|
|
|
|
// `aten.reshape`.
|
|
|
|
|
class DecomposeAtenReshapeAsOp : public OpRewritePattern<AtenReshapeAsOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenReshapeAsOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
MLIRContext *context = op->getContext();
|
|
|
|
|
Value input = op.getSelf();
|
|
|
|
|
Value other = op.getOther();
|
|
|
|
|
|
|
|
|
|
auto otherShape = rewriter.create<Torch::AtenSizeOp>(
|
|
|
|
|
loc, Torch::ListType::get(Torch::IntType::get(context)), other);
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(op, op.getType(), input,
|
|
|
|
|
otherShape);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-06 08:31:01 +08:00
|
|
|
|
namespace {
|
|
|
|
|
// Decompose AtenLinalgNormOp to AtenLinalgVectorNormOp only
|
|
|
|
|
class DecomposeAtenLinalgNormOp : public OpRewritePattern<AtenLinalgNormOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenLinalgNormOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
SmallVector<Value> dimList;
|
|
|
|
|
if (!getListConstructElements(op.getDim(), dimList)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "dim should comes from a PrimListConstructOp");
|
|
|
|
|
}
|
|
|
|
|
if (dimList.size() != 1) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "Unimplemented: only dim size of 1 is supported");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// default ord value is 2 for vector_norm
|
|
|
|
|
auto ord = op.getOrd();
|
|
|
|
|
if (ord.getType().isa<Torch::NoneType>()) {
|
|
|
|
|
ord = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
|
|
|
|
|
}
|
|
|
|
|
rewriter.replaceOpWithNewOp<Torch::AtenLinalgVectorNormOp>(
|
|
|
|
|
op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
|
|
|
|
|
op.getDtype());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2024-03-15 08:53:29 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeAtenFakeQuantizePerTensorAffineOp
|
|
|
|
|
: public OpRewritePattern<AtenFakeQuantizePerTensorAffineOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpRewritePattern<AtenFakeQuantizePerTensorAffineOp>::OpRewritePattern;
|
|
|
|
|
LogicalResult matchAndRewrite(AtenFakeQuantizePerTensorAffineOp op,
|
|
|
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
MLIRContext *context = getContext();
|
|
|
|
|
|
|
|
|
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
|
|
|
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
|
|
|
|
Value one =
|
|
|
|
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
|
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
|
|
|
|
|
|
|
|
|
// input/scale
|
|
|
|
|
Value divScale = rewriter.create<AtenDivScalarOp>(
|
|
|
|
|
loc, op.getType(), op.getSelf(), op.getScale());
|
|
|
|
|
// std::nearby_int(input/scale)
|
|
|
|
|
Value round = rewriter.create<AtenRoundOp>(loc, op.getType(), divScale);
|
|
|
|
|
// std::nearby_int(input/scale) + zero_point
|
|
|
|
|
Value addZeroPoint = rewriter.create<AtenAddScalarOp>(
|
|
|
|
|
loc, op.getType(), round, op.getZeroPoint(), one);
|
|
|
|
|
// max(quant_min, std::nearby_int(input/scale) + zero_point)
|
|
|
|
|
Value max = rewriter.create<AtenMaximumOp>(
|
|
|
|
|
loc, op.getType(), addZeroPoint,
|
|
|
|
|
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMin(),
|
|
|
|
|
/*dtype=*/none,
|
|
|
|
|
/*device=*/none,
|
|
|
|
|
/*requires_grad=*/falseVal));
|
|
|
|
|
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
|
|
|
|
|
Value min = rewriter.create<AtenMinimumOp>(
|
|
|
|
|
loc, op.getType(), max,
|
|
|
|
|
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMax(),
|
|
|
|
|
/*dtype=*/none, /*device=*/none,
|
|
|
|
|
/*requires_grad=*/falseVal));
|
|
|
|
|
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
|
|
|
|
|
// - zero_point
|
|
|
|
|
Value subZeroPoint = rewriter.create<AtenSubScalarOp>(
|
|
|
|
|
loc, op.getType(), min, op.getZeroPoint(), one);
|
|
|
|
|
// (min(quant_max, max(quant_min, std::nearby_int(input/scale) +
|
|
|
|
|
// zero_point)) - zero_point) * scale
|
|
|
|
|
Value result = rewriter.create<AtenMulScalarOp>(
|
|
|
|
|
loc, op.getType(), subZeroPoint, op.getScale());
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2021-10-16 06:23:59 +08:00
|
|
|
|
namespace {
|
|
|
|
|
class DecomposeComplexOpsPass
|
|
|
|
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
2022-12-09 01:26:38 +08:00
|
|
|
|
private:
|
|
|
|
|
llvm::StringSet<> legalOpsSet;
|
|
|
|
|
|
|
|
|
|
template <typename DecomposePattern>
|
|
|
|
|
void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) {
|
|
|
|
|
MLIRContext *context = &getContext();
|
2022-12-20 18:17:27 +08:00
|
|
|
|
std::optional<OperationName> opName =
|
|
|
|
|
DecomposePattern(context).getRootKind();
|
2022-12-09 01:26:38 +08:00
|
|
|
|
// Because the `DecomposeComplexOpsPass` uses a greedy algorithm
|
|
|
|
|
// to apply patterns, only patterns that we for sure know we want to run
|
|
|
|
|
// must be added. This restricts the set of patterns allowed in this file to
|
|
|
|
|
// patterns that apply to a single op. In other words, patterns that match
|
|
|
|
|
// on `Operation *` are not allowed, since there is no way of telling if
|
|
|
|
|
// that pattern will match on an op in the `legalOpsSet` or not.
|
|
|
|
|
assert(opName && "All decomposition patterns must target a single op");
|
2023-03-25 10:50:01 +08:00
|
|
|
|
if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
|
2022-12-09 01:26:38 +08:00
|
|
|
|
patterns.add<DecomposePattern>(context);
|
|
|
|
|
}
|
|
|
|
|
|
2022-08-18 07:23:52 +08:00
|
|
|
|
public:
|
|
|
|
|
DecomposeComplexOpsPass() = default;
|
|
|
|
|
DecomposeComplexOpsPass(ArrayRef<std::string> legalOps) {
|
|
|
|
|
this->legalOps = legalOps;
|
|
|
|
|
}
|
2021-10-16 06:23:59 +08:00
|
|
|
|
void runOnOperation() override {
|
|
|
|
|
MLIRContext *context = &getContext();
|
|
|
|
|
RewritePatternSet patterns(context);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
// The strings in the `legalOps` ArrayRef don't exist during the call to the
|
|
|
|
|
// constructor `DecomposeComplexOpsPass`, so the creation of the
|
|
|
|
|
// `legalOpsSet` must be delayed to when `runOnOperation` gets called.
|
|
|
|
|
legalOpsSet.clear();
|
|
|
|
|
legalOpsSet.insert(legalOps.begin(), legalOps.end());
|
|
|
|
|
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
|
2023-03-11 09:25:25 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
|
2023-10-31 15:36:16 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenUnflattenIntOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
2024-01-16 14:29:34 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
|
2023-02-11 05:58:39 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
|
|
|
|
patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
2024-02-28 14:48:07 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
2024-03-14 03:17:22 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
|
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:
```
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
```
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
For dynamic cases, for example if the above is replaced with
```
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
```
There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 21:52:44 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
|
|
|
|
|
patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
|
2024-02-19 22:23:48 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
|
2023-12-13 11:05:12 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
|
|
|
|
patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
2024-03-14 08:28:33 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
|
2024-01-30 01:59:33 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
|
2023-02-16 01:06:29 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAtenBernoulliLikeOp<ValsemVariantAtenBernoulliFloatOp>>(
|
2022-12-09 01:26:38 +08:00
|
|
|
|
patterns);
|
2023-02-16 01:06:29 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
|
2023-12-28 12:33:18 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
2023-11-02 02:23:28 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
2023-06-07 10:06:27 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
|
2023-11-04 22:26:01 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
|
2024-01-16 14:29:34 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
2024-03-29 08:05:00 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
2023-12-10 12:30:37 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
2024-02-10 00:00:24 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
|
|
|
|
|
patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<
|
|
|
|
|
DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
|
2023-09-12 22:29:08 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
2022-12-22 10:13:59 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
|
2023-06-27 14:19:33 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
|
2023-11-01 11:56:54 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
|
|
|
|
|
patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
2024-04-10 22:26:48 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
|
[Torch Dialect] add support for adaptive_avgpool_1d (#2342)
* [MLIR][TORCH] Fix aten.cumsum lowering for int32 input (#2351)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
update PyTorch version to 2.1.0.dev20230729 (#2354)
- torch version: 2.1.0.dev20230729
- torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
- torchvision version: 0.16.0.dev20230729
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230730 (#2356)
- torch version: 2.1.0.dev20230730
- torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
- torchvision version: 0.16.0.dev20230730
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230731 (#2359)
- torch version: 2.1.0.dev20230731
- torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
- torchvision version: 0.16.0.dev20230731
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support
* SW-95317 Propagate Lazy->Jit->MLIR scope name.
* Enhance location information based on op names
Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.
* Update locations logic; updated debug-info.py test
* Use {scope}/{op_name} format to track names by default
---------
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
build: update llvm tag to 41895843
Summary of changes:
- Update tags
llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
update PyTorch version to 2.1.0.dev20230802 (#2366)
- torch version: 2.1.0.dev20230802
- torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
- torchvision version: 0.16.0.dev20230802
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Change Python version from 3.10 to 3.11 in installation instructions (#2370)
Add CITATION file (#2371)
Add packaging as an install dependency (#2369)
Needed by `torch_mlir._version`. Resolves #2368.
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
update PyTorch version to 2.1.0.dev20230803 (#2372)
- torch version: 2.1.0.dev20230803
- torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
- torchvision version: 0.16.0.dev20230803
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Prevent failed stable CI job from cancelling nightly jobs (#2373)
The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
update
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update
update:
update
parent 22e88d523b1970b2e904eb5421d49d987a3d255e
author jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114110 +0800
committer jianzhe.xiao <jianzhe.xiao@bytedance.com> 1691114119 +0800
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op (#2340)
[Stablehlo] Add converter to stablehlo for aten.(Int,Float,Bool).Tensor op and configure crashing e2e sets for stablehlo backend.
update PyTorch version to 2.1.0.dev20230729 (#2354)
- torch version: 2.1.0.dev20230729
- torch commit hash: b638df0afb83572724032c824c64e481bb4499a0
- torchvision version: 0.16.0.dev20230729
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230730 (#2356)
- torch version: 2.1.0.dev20230730
- torch commit hash: 0ff243ff350268cc98fe03fa6364375ee2824742
- torchvision version: 0.16.0.dev20230730
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
update PyTorch version to 2.1.0.dev20230731 (#2359)
- torch version: 2.1.0.dev20230731
- torch commit hash: 6298ac688f8caafe30d71ff2ea2e20fbb32065c7
- torchvision version: 0.16.0.dev20230731
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
LTC->MLIR Debug Info support (#1922)
* LTC->MLIR Debug Info support
* SW-95317 Propagate Lazy->Jit->MLIR scope name.
* Enhance location information based on op names
Currently, the location information attached to the ops just considers
the filename, line number and column number. Attaching operation name
would help identify the type of computation by just looking at the
profile of execution.
* Update locations logic; updated debug-info.py test
* Use {scope}/{op_name} format to track names by default
---------
Co-authored-by: Gleb Kazantaev <gleb.kazantaev@cerebras.net>
Co-authored-by: Mark Browning <mark@cerebras.net>
Co-authored-by: Vimal Patel <vimal@polymagelabs.com>
build: update llvm tag to 41895843
Summary of changes:
- Update tags
llvm: 41895843b5915bb78e9d02aa711fa10f7174db43
mhlo: 4726d31f7025da66de0dea709bd56c462edb83c2
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
update PyTorch version to 2.1.0.dev20230802 (#2366)
- torch version: 2.1.0.dev20230802
- torch commit hash: c89b16917755c2abbef7b6420e340baf9ae8089e
- torchvision version: 0.16.0.dev20230802
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Change Python version from 3.10 to 3.11 in installation instructions (#2370)
Add CITATION file (#2371)
Add packaging as an install dependency (#2369)
Needed by `torch_mlir._version`. Resolves #2368.
[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)
* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
update PyTorch version to 2.1.0.dev20230803 (#2372)
- torch version: 2.1.0.dev20230803
- torch commit hash: f89c73be3a3e8274d025ac46a33a780853841c9e
- torchvision version: 0.16.0.dev20230803
Co-authored-by: Roll PyTorch Action <torch-mlir@users.noreply.github.com>
Prevent failed stable CI job from cancelling nightly jobs (#2373)
The CI jobs that use stable PyTorch are currently not required to pass
in order for a patch to get merged in `main`. This commit makes sure
that if a CI job for stable PyTorch fails, it does not cancel the
other required jobs.
[Torch Dialect] emit aten.tile op and decompose it into aten.repeat (#2355)
update
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update
update:
add support for adaptive_pool_id
update xfail sets
update xfail_sets
update
fix xfail_sets
update:
update:
* update
---------
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-08-05 07:48:09 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
2023-12-05 13:55:51 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
|
2023-11-08 15:28:30 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
2022-12-22 13:02:40 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
2024-03-06 07:01:21 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitSizesOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
|
2023-07-20 16:46:44 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
|
2023-10-26 10:36:18 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
|
2023-02-21 12:08:29 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
|
2023-04-04 17:31:21 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
|
2023-01-11 14:01:45 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsVarOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
|
2023-09-02 02:13:58 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandOp>(patterns);
|
2022-12-09 01:26:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
2023-01-16 19:40:21 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
|
2024-01-16 14:49:29 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
|
2022-12-09 23:22:26 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
2023-08-24 00:04:14 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
|
2024-03-15 08:53:29 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenFakeQuantizePerTensorAffineOp>(
|
|
|
|
|
patterns);
|
2023-12-14 12:28:08 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
|
2023-01-04 00:30:16 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
2024-02-01 01:39:38 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
|
2022-12-29 22:52:23 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
2023-09-14 01:04:31 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
|
2023-02-03 10:20:47 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
2024-03-06 07:01:21 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimTolistOp>(patterns);
|
2022-11-16 13:57:58 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
|
2023-01-02 22:34:39 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
2023-04-11 16:02:28 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
2023-03-15 16:00:03 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
2023-04-26 15:14:06 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
2023-05-02 21:29:00 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
2023-06-01 11:38:50 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
|
2022-10-16 05:46:06 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
2024-04-08 20:05:42 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSgnOp>(patterns);
|
2023-07-20 09:51:58 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
|
2023-08-04 09:05:34 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
|
2023-11-05 11:38:36 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
2023-08-15 19:36:08 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
|
[Torch Dialect] Decompose AtenTriuOp (#2561)
decompose like:
```
import torch
def my_triu(x, diag):
rows = torch.ops.aten.size(x, -2)
cols = torch.ops.aten.size(x, -1)
row_indices = torch.ops.aten.arange(rows).unsqueeze(1)
col_indices = torch.ops.aten.arange(cols).unsqueeze(0)
cond = torch.ops.aten.ge(
col_indices, torch.ops.aten.add(row_indices, diag))
return torch.ops.aten.where(cond, x, 0)
x = torch.rand(5, 7)
assert torch.allclose(my_triu(x, 0), torch.triu(x, 0))
assert torch.allclose(my_triu(x, 1), torch.triu(x, 1))
assert torch.allclose(my_triu(x, 2), torch.triu(x, 2))
assert torch.allclose(my_triu(x, -1), torch.triu(x, -1))
```
---------
Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>
2023-11-29 10:35:26 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
2024-03-06 08:31:01 +08:00
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
|
2024-01-24 13:30:03 +08:00
|
|
|
|
// More specific conv ops
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
|
|
|
|
|
addPatternIfTargetOpIsIllegal<DecomposeAtenConv3dOp>(patterns);
|
|
|
|
|
|
2022-12-09 01:26:38 +08:00
|
|
|
|
GreedyRewriteConfig config;
|
|
|
|
|
config.useTopDownTraversal = true;
|
2023-01-11 07:07:19 +08:00
|
|
|
|
config.maxIterations = GreedyRewriteConfig::kNoLimit;
|
2022-12-09 01:26:38 +08:00
|
|
|
|
|
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
|
|
|
|
config))) {
|
2021-10-16 06:23:59 +08:00
|
|
|
|
return signalPassFailure();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
2022-08-19 08:01:54 +08:00
|
|
|
|
|
2022-04-27 03:27:51 +08:00
|
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
2022-08-19 08:01:54 +08:00
|
|
|
|
mlir::torch::Torch::createDecomposeComplexOpsPass(
|
|
|
|
|
ArrayRef<std::string> legalOps) {
|
|
|
|
|
return std::make_unique<DecomposeComplexOpsPass>(legalOps);
|
2023-01-21 02:40:13 +08:00
|
|
|
|
}
|