[Torch] support decompose aten.einsum with ellipsis slicing (#3056)

pull/3073/head
Xinyu Yang 2024-03-28 03:42:10 +08:00 committed by GitHub
parent 5f325749f9
commit e6e7689a24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 166 additions and 11 deletions

View File

@ -21,6 +21,7 @@
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSet.h"
#include <cstdint> #include <cstdint>
#include <set>
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -158,6 +159,105 @@ static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
return dimsOrder; return dimsOrder;
} }
static bool
rewriteEquationWithEllipsisSlicing(std::string &equation,
SmallVector<int64_t> &inputRanks) {
// split equation into input and result
size_t arrowPos = equation.find("->");
if (arrowPos == std::string::npos) {
return false;
}
std::string inputStr = equation.substr(0, arrowPos);
std::string resultStr = equation.substr(arrowPos + 2);
// split input into tokens
SmallVector<std::string> inputTokens;
size_t start = 0;
size_t end = 0;
std::set<char> usedTokens;
while (end < inputStr.size()) {
end = inputStr.find(",", start);
if (end == std::string::npos) {
end = inputStr.size();
}
std::string token = inputStr.substr(start, end - start);
inputTokens.push_back(token);
start = end + 1;
}
if (inputTokens.size() != inputRanks.size()) {
return false;
}
// find the rank which ellipsis represents, and max ellipsis rank because a
// tensor can be broadcasted
SmallVector<int64_t> ellipsisRanks;
int maxEllipsisRank = 0;
for (const auto &[token, inputRank] : llvm::zip(inputTokens, inputRanks)) {
int explictRank = 0;
for (auto c : token) {
if (std::isalpha(c)) {
usedTokens.insert(c);
explictRank++;
} else if (c == '.' || c == ' ') {
continue;
} else {
return false;
}
}
int ellipsisRank = inputRank - explictRank;
if (ellipsisRank > maxEllipsisRank) {
maxEllipsisRank = ellipsisRank;
}
if (ellipsisRank < 0) {
return false;
}
ellipsisRanks.push_back(inputRank - explictRank);
}
auto isTokenUsed = [&usedTokens](char c) {
return usedTokens.find(c) != usedTokens.end();
};
std::string ellipsisToken;
int usedCount = 0;
// Iterate over the alphabet to create a new token for ellipsis
for (char c = 'a'; c <= 'z'; ++c) {
if (!isTokenUsed(c)) {
ellipsisToken.push_back(c);
usedCount++;
if (usedCount == maxEllipsisRank) {
break;
}
}
}
// replace ellipsis with ellipsisToken
for (size_t i = 0; i < inputTokens.size(); i++) {
size_t ellipsisPos = inputTokens[i].find("...");
if (ellipsisPos == std::string::npos) {
continue;
}
if (ellipsisRanks[i] == maxEllipsisRank) {
inputTokens[i].replace(ellipsisPos, 3, ellipsisToken);
} else if (ellipsisRanks[i] == 0) {
inputTokens[i].replace(ellipsisPos, 3, "");
} else {
inputTokens[i].replace(
ellipsisPos, 3,
ellipsisToken.substr(ellipsisToken.size() - ellipsisRanks[i]));
}
}
// replace ellipsis in result
size_t ellipsisPos = resultStr.find("...");
if (ellipsisPos != std::string::npos) {
resultStr.replace(ellipsisPos, 3, ellipsisToken);
}
// join input and result
equation = llvm::join(inputTokens, ",") + " -> " + resultStr;
return true;
}
static bool parseEquation(const std::string &equation, static bool parseEquation(const std::string &equation,
SmallVector<SmallVector<char>> &inputTokens, SmallVector<SmallVector<char>> &inputTokens,
SmallVector<char> &resultTokens) { SmallVector<char> &resultTokens) {
@ -1135,16 +1235,6 @@ public:
LogicalResult matchAndRewrite(AtenEinsumOp op, LogicalResult matchAndRewrite(AtenEinsumOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
std::string equation;
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
}
SmallVector<char> resultTokens;
SmallVector<SmallVector<char>> inputTokens;
if (!parseEquation(equation, inputTokens, resultTokens)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}
SmallVector<Value> inputTensors; SmallVector<Value> inputTensors;
if (!getListConstructElements(op.getTensors(), inputTensors)) { if (!getListConstructElements(op.getTensors(), inputTensors)) {
@ -1164,6 +1254,30 @@ public:
"all input tensors should have sizes"); "all input tensors should have sizes");
} }
std::string equation;
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
}
// if "..." in equation, modify it
if (equation.find("...") != std::string::npos) {
SmallVector<int64_t> inputRanks;
for (Value tensor : inputTensors) {
auto type = tensor.getType().cast<BaseTensorType>();
inputRanks.push_back(type.getSizes().size());
}
if (!rewriteEquationWithEllipsisSlicing(equation, inputRanks)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}
}
SmallVector<char> resultTokens;
SmallVector<SmallVector<char>> inputTokens;
if (!parseEquation(equation, inputTokens, resultTokens)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}
SmallVector<char> lhsTokens = inputTokens[0]; SmallVector<char> lhsTokens = inputTokens[0];
Value lhs = inputTensors[0]; Value lhs = inputTensors[0];
Value result; Value result;

View File

@ -467,6 +467,8 @@ STABLEHLO_PASS_SET = {
"EinsumStaticContractRhsModule_basic", "EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic", "EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic", "EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"ElementwiseAbsFloatModule_basic", "ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic", "ElementwiseAbsIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@ -954,6 +956,8 @@ TOSA_PASS_SET = {
"EinsumStaticContractRhsModule_basic", "EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic", "EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic", "EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"ElementwiseAbsFloatModule_basic", "ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic", "ElementwiseAbsIntModule_basic",
"ElementwiseAddModule_basic", "ElementwiseAddModule_basic",
@ -1923,6 +1927,8 @@ ONNX_XFAIL_SET = {
"EinsumStaticContractRhsModule_basic", "EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic", "EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic", "EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
# Failure - onnx_lowering: onnx.MaxPool # Failure - onnx_lowering: onnx.MaxPool
"MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic",

View File

@ -1100,3 +1100,38 @@ class EinsumStaticContractRhsModule(torch.nn.Module):
@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule()) @register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils): def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))
class EinsumStaticWithEllipsisSlicingModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 6], torch.float32, True),
([3, 6, 5], torch.float32, True),
])
def forward(self, tensor1, tensor2):
return torch.ops.aten.einsum('...mn,...nd->...md', [tensor1, tensor2])
@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingModule())
def EinsumStaticWithEllipsisSlicingModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 6), tu.rand(3, 6, 5))
class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 6, 4, 5], torch.float32, True),
([6, 5], torch.float32, True),
])
def forward(self, tensor1, tensor2):
# should be abnd,bd -> abn
return torch.ops.aten.einsum('...nd,...d->...n', [tensor1, tensor2])
@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule())
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))