mirror of https://github.com/llvm/torch-mlir
[Torch] support decompose aten.einsum with ellipsis slicing (#3056)
parent
5f325749f9
commit
e6e7689a24
|
@ -21,6 +21,7 @@
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/StringSet.h"
|
#include "llvm/ADT/StringSet.h"
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -158,6 +159,105 @@ static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
|
||||||
return dimsOrder;
|
return dimsOrder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool
|
||||||
|
rewriteEquationWithEllipsisSlicing(std::string &equation,
|
||||||
|
SmallVector<int64_t> &inputRanks) {
|
||||||
|
// split equation into input and result
|
||||||
|
size_t arrowPos = equation.find("->");
|
||||||
|
if (arrowPos == std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string inputStr = equation.substr(0, arrowPos);
|
||||||
|
std::string resultStr = equation.substr(arrowPos + 2);
|
||||||
|
|
||||||
|
// split input into tokens
|
||||||
|
SmallVector<std::string> inputTokens;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t end = 0;
|
||||||
|
std::set<char> usedTokens;
|
||||||
|
while (end < inputStr.size()) {
|
||||||
|
end = inputStr.find(",", start);
|
||||||
|
if (end == std::string::npos) {
|
||||||
|
end = inputStr.size();
|
||||||
|
}
|
||||||
|
std::string token = inputStr.substr(start, end - start);
|
||||||
|
inputTokens.push_back(token);
|
||||||
|
start = end + 1;
|
||||||
|
}
|
||||||
|
if (inputTokens.size() != inputRanks.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the rank which ellipsis represents, and max ellipsis rank because a
|
||||||
|
// tensor can be broadcasted
|
||||||
|
SmallVector<int64_t> ellipsisRanks;
|
||||||
|
int maxEllipsisRank = 0;
|
||||||
|
for (const auto &[token, inputRank] : llvm::zip(inputTokens, inputRanks)) {
|
||||||
|
int explictRank = 0;
|
||||||
|
for (auto c : token) {
|
||||||
|
if (std::isalpha(c)) {
|
||||||
|
usedTokens.insert(c);
|
||||||
|
explictRank++;
|
||||||
|
} else if (c == '.' || c == ' ') {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int ellipsisRank = inputRank - explictRank;
|
||||||
|
if (ellipsisRank > maxEllipsisRank) {
|
||||||
|
maxEllipsisRank = ellipsisRank;
|
||||||
|
}
|
||||||
|
if (ellipsisRank < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ellipsisRanks.push_back(inputRank - explictRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto isTokenUsed = [&usedTokens](char c) {
|
||||||
|
return usedTokens.find(c) != usedTokens.end();
|
||||||
|
};
|
||||||
|
std::string ellipsisToken;
|
||||||
|
int usedCount = 0;
|
||||||
|
// Iterate over the alphabet to create a new token for ellipsis
|
||||||
|
for (char c = 'a'; c <= 'z'; ++c) {
|
||||||
|
if (!isTokenUsed(c)) {
|
||||||
|
ellipsisToken.push_back(c);
|
||||||
|
usedCount++;
|
||||||
|
if (usedCount == maxEllipsisRank) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replace ellipsis with ellipsisToken
|
||||||
|
for (size_t i = 0; i < inputTokens.size(); i++) {
|
||||||
|
size_t ellipsisPos = inputTokens[i].find("...");
|
||||||
|
if (ellipsisPos == std::string::npos) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (ellipsisRanks[i] == maxEllipsisRank) {
|
||||||
|
inputTokens[i].replace(ellipsisPos, 3, ellipsisToken);
|
||||||
|
} else if (ellipsisRanks[i] == 0) {
|
||||||
|
inputTokens[i].replace(ellipsisPos, 3, "");
|
||||||
|
} else {
|
||||||
|
inputTokens[i].replace(
|
||||||
|
ellipsisPos, 3,
|
||||||
|
ellipsisToken.substr(ellipsisToken.size() - ellipsisRanks[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// replace ellipsis in result
|
||||||
|
size_t ellipsisPos = resultStr.find("...");
|
||||||
|
if (ellipsisPos != std::string::npos) {
|
||||||
|
resultStr.replace(ellipsisPos, 3, ellipsisToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
// join input and result
|
||||||
|
equation = llvm::join(inputTokens, ",") + " -> " + resultStr;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static bool parseEquation(const std::string &equation,
|
static bool parseEquation(const std::string &equation,
|
||||||
SmallVector<SmallVector<char>> &inputTokens,
|
SmallVector<SmallVector<char>> &inputTokens,
|
||||||
SmallVector<char> &resultTokens) {
|
SmallVector<char> &resultTokens) {
|
||||||
|
@ -1135,16 +1235,6 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenEinsumOp op,
|
LogicalResult matchAndRewrite(AtenEinsumOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
std::string equation;
|
|
||||||
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
|
|
||||||
}
|
|
||||||
SmallVector<char> resultTokens;
|
|
||||||
SmallVector<SmallVector<char>> inputTokens;
|
|
||||||
if (!parseEquation(equation, inputTokens, resultTokens)) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "Unexpected character in equations encountered");
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> inputTensors;
|
SmallVector<Value> inputTensors;
|
||||||
if (!getListConstructElements(op.getTensors(), inputTensors)) {
|
if (!getListConstructElements(op.getTensors(), inputTensors)) {
|
||||||
|
@ -1164,6 +1254,30 @@ public:
|
||||||
"all input tensors should have sizes");
|
"all input tensors should have sizes");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string equation;
|
||||||
|
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
|
||||||
|
}
|
||||||
|
// if "..." in equation, modify it
|
||||||
|
if (equation.find("...") != std::string::npos) {
|
||||||
|
SmallVector<int64_t> inputRanks;
|
||||||
|
for (Value tensor : inputTensors) {
|
||||||
|
auto type = tensor.getType().cast<BaseTensorType>();
|
||||||
|
inputRanks.push_back(type.getSizes().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!rewriteEquationWithEllipsisSlicing(equation, inputRanks)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unexpected character in equations encountered");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SmallVector<char> resultTokens;
|
||||||
|
SmallVector<SmallVector<char>> inputTokens;
|
||||||
|
if (!parseEquation(equation, inputTokens, resultTokens)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unexpected character in equations encountered");
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<char> lhsTokens = inputTokens[0];
|
SmallVector<char> lhsTokens = inputTokens[0];
|
||||||
Value lhs = inputTensors[0];
|
Value lhs = inputTensors[0];
|
||||||
Value result;
|
Value result;
|
||||||
|
|
|
@ -467,6 +467,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"EinsumStaticContractRhsModule_basic",
|
"EinsumStaticContractRhsModule_basic",
|
||||||
"EinsumStaticFourDimensionModule_basic",
|
"EinsumStaticFourDimensionModule_basic",
|
||||||
"EinsumStaticModule_basic",
|
"EinsumStaticModule_basic",
|
||||||
|
"EinsumStaticWithEllipsisSlicingModule_basic",
|
||||||
|
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
|
||||||
"ElementwiseAbsFloatModule_basic",
|
"ElementwiseAbsFloatModule_basic",
|
||||||
"ElementwiseAbsIntModule_basic",
|
"ElementwiseAbsIntModule_basic",
|
||||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||||
|
@ -954,6 +956,8 @@ TOSA_PASS_SET = {
|
||||||
"EinsumStaticContractRhsModule_basic",
|
"EinsumStaticContractRhsModule_basic",
|
||||||
"EinsumStaticFourDimensionModule_basic",
|
"EinsumStaticFourDimensionModule_basic",
|
||||||
"EinsumStaticModule_basic",
|
"EinsumStaticModule_basic",
|
||||||
|
"EinsumStaticWithEllipsisSlicingModule_basic",
|
||||||
|
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
|
||||||
"ElementwiseAbsFloatModule_basic",
|
"ElementwiseAbsFloatModule_basic",
|
||||||
"ElementwiseAbsIntModule_basic",
|
"ElementwiseAbsIntModule_basic",
|
||||||
"ElementwiseAddModule_basic",
|
"ElementwiseAddModule_basic",
|
||||||
|
@ -1923,6 +1927,8 @@ ONNX_XFAIL_SET = {
|
||||||
"EinsumStaticContractRhsModule_basic",
|
"EinsumStaticContractRhsModule_basic",
|
||||||
"EinsumStaticFourDimensionModule_basic",
|
"EinsumStaticFourDimensionModule_basic",
|
||||||
"EinsumStaticModule_basic",
|
"EinsumStaticModule_basic",
|
||||||
|
"EinsumStaticWithEllipsisSlicingModule_basic",
|
||||||
|
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.MaxPool
|
# Failure - onnx_lowering: onnx.MaxPool
|
||||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||||
|
|
|
@ -1100,3 +1100,38 @@ class EinsumStaticContractRhsModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
|
@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
|
||||||
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
|
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))
|
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))
|
||||||
|
|
||||||
|
class EinsumStaticWithEllipsisSlicingModule(torch.nn.Module):
    # Exercises the aten.einsum decomposition when the equation uses "..."
    # (ellipsis) slicing and both operands share the same leading batch dims.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 6], torch.float32, True),
        ([3, 6, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        # '...mn,...nd->...md' is a batched matmul: the "..." dims (here a
        # single dim of size 3) are carried through to the result unchanged.
        return torch.ops.aten.einsum('...mn,...nd->...md', [tensor1, tensor2])


@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingModule())
def EinsumStaticWithEllipsisSlicingModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 6), tu.rand(3, 6, 5))
|
||||||
|
|
||||||
|
class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
    # Exercises the aten.einsum decomposition when "..." covers a different
    # number of dims per operand, so the lower-rank operand is broadcast
    # against the leading dims of the higher-rank one.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 4, 5], torch.float32, True),
        ([6, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        # should be abnd,bd -> abn
        return torch.ops.aten.einsum('...nd,...d->...n', [tensor1, tensor2])


@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule())
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))
|
Loading…
Reference in New Issue