From e6e7689a24ca16453d59a71b7d330903cecce8d7 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Thu, 28 Mar 2024 03:42:10 +0800 Subject: [PATCH] [Torch] support decompose aten.einsum with ellipsis slicing (#3056) --- .../Torch/Transforms/DecomposeComplexOps.cpp | 134 ++++++++++++++++-- projects/pt1/e2e_testing/xfail_sets.py | 6 + .../test_suite/reshape_like.py | 37 ++++- 3 files changed, 166 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2a0ed9428..e25c6808e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include +#include using namespace mlir; using namespace mlir::torch; @@ -158,6 +159,105 @@ static SmallVector computeDimsOrderForMoveDim(int64_t srcDimInt, return dimsOrder; } +static bool +rewriteEquationWithEllipsisSlicing(std::string &equation, + SmallVector &inputRanks) { + // split equation into input and result + size_t arrowPos = equation.find("->"); + if (arrowPos == std::string::npos) { + return false; + } + std::string inputStr = equation.substr(0, arrowPos); + std::string resultStr = equation.substr(arrowPos + 2); + + // split input into tokens + SmallVector inputTokens; + size_t start = 0; + size_t end = 0; + std::set usedTokens; + while (end < inputStr.size()) { + end = inputStr.find(",", start); + if (end == std::string::npos) { + end = inputStr.size(); + } + std::string token = inputStr.substr(start, end - start); + inputTokens.push_back(token); + start = end + 1; + } + if (inputTokens.size() != inputRanks.size()) { + return false; + } + + // find the rank which ellipsis represents, and max ellipsis rank because a + // tensor can be broadcasted + SmallVector ellipsisRanks; + int maxEllipsisRank = 0; + for (const auto &[token, inputRank] : llvm::zip(inputTokens, inputRanks)) { + int 
explictRank = 0; + for (auto c : token) { + if (std::isalpha(c)) { + usedTokens.insert(c); + explictRank++; + } else if (c == '.' || c == ' ') { + continue; + } else { + return false; + } + } + int ellipsisRank = inputRank - explictRank; + if (ellipsisRank > maxEllipsisRank) { + maxEllipsisRank = ellipsisRank; + } + if (ellipsisRank < 0) { + return false; + } + ellipsisRanks.push_back(inputRank - explictRank); + } + + auto isTokenUsed = [&usedTokens](char c) { + return usedTokens.find(c) != usedTokens.end(); + }; + std::string ellipsisToken; + int usedCount = 0; + // Iterate over the alphabet to create a new token for ellipsis + for (char c = 'a'; c <= 'z'; ++c) { + if (!isTokenUsed(c)) { + ellipsisToken.push_back(c); + usedCount++; + if (usedCount == maxEllipsisRank) { + break; + } + } + } + + // replace ellipsis with ellipsisToken + for (size_t i = 0; i < inputTokens.size(); i++) { + size_t ellipsisPos = inputTokens[i].find("..."); + if (ellipsisPos == std::string::npos) { + continue; + } + if (ellipsisRanks[i] == maxEllipsisRank) { + inputTokens[i].replace(ellipsisPos, 3, ellipsisToken); + } else if (ellipsisRanks[i] == 0) { + inputTokens[i].replace(ellipsisPos, 3, ""); + } else { + inputTokens[i].replace( + ellipsisPos, 3, + ellipsisToken.substr(ellipsisToken.size() - ellipsisRanks[i])); + } + } + + // replace ellipsis in result + size_t ellipsisPos = resultStr.find("..."); + if (ellipsisPos != std::string::npos) { + resultStr.replace(ellipsisPos, 3, ellipsisToken); + } + + // join input and result + equation = llvm::join(inputTokens, ",") + " -> " + resultStr; + return true; +} + static bool parseEquation(const std::string &equation, SmallVector> &inputTokens, SmallVector &resultTokens) { @@ -1135,16 +1235,6 @@ public: LogicalResult matchAndRewrite(AtenEinsumOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - std::string equation; - if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) { - return 
rewriter.notifyMatchFailure(op, "Unsupported value of equation"); - } - SmallVector resultTokens; - SmallVector> inputTokens; - if (!parseEquation(equation, inputTokens, resultTokens)) { - return rewriter.notifyMatchFailure( - op, "Unexpected character in equations encountered"); - } SmallVector inputTensors; if (!getListConstructElements(op.getTensors(), inputTensors)) { @@ -1164,6 +1254,30 @@ public: "all input tensors should have sizes"); } + std::string equation; + if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) { + return rewriter.notifyMatchFailure(op, "Unsupported value of equation"); + } + // if "..." in equation, modify it + if (equation.find("...") != std::string::npos) { + SmallVector inputRanks; + for (Value tensor : inputTensors) { + auto type = tensor.getType().cast(); + inputRanks.push_back(type.getSizes().size()); + } + + if (!rewriteEquationWithEllipsisSlicing(equation, inputRanks)) { + return rewriter.notifyMatchFailure( + op, "Unexpected character in equations encountered"); + } + } + SmallVector resultTokens; + SmallVector> inputTokens; + if (!parseEquation(equation, inputTokens, resultTokens)) { + return rewriter.notifyMatchFailure( + op, "Unexpected character in equations encountered"); + } + SmallVector lhsTokens = inputTokens[0]; Value lhs = inputTensors[0]; Value result; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 64e352cf6..26314b1eb 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -467,6 +467,8 @@ STABLEHLO_PASS_SET = { "EinsumStaticContractRhsModule_basic", "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", + "EinsumStaticWithEllipsisSlicingModule_basic", + "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", "ElementwiseAbsFloatModule_basic", "ElementwiseAbsIntModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", @@ -954,6 +956,8 @@ TOSA_PASS_SET = { 
class EinsumStaticWithEllipsisSlicingModule(torch.nn.Module):
    """Einsum with '...' on both operands and on the result.

    Both operands are rank 3 with two explicit labels each, so every
    ellipsis covers exactly one leading dimension.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 6], torch.float32, True),
        ([3, 6, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        operands = [tensor1, tensor2]
        return torch.ops.aten.einsum('...mn,...nd->...md', operands)


@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingModule())
def EinsumStaticWithEllipsisSlicingModule_basic(module, tu: TestUtils):
    lhs = tu.rand(3, 4, 6)
    rhs = tu.rand(3, 6, 5)
    module.forward(lhs, rhs)


class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
    """Einsum where the two operands' ellipses cover different numbers of dims.

    The rank-4 operand's '...' covers two dimensions and the rank-2 operand's
    '...' covers one, so the shorter ellipsis is matched against the trailing
    ellipsis dimension of the longer one.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 4, 5], torch.float32, True),
        ([6, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        # Expected explicit form of the equation: abnd,bd -> abn
        operands = [tensor1, tensor2]
        return torch.ops.aten.einsum('...nd,...d->...n', operands)


@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule())
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
    lhs = tu.rand(2, 6, 4, 5)
    rhs = tu.rand(6, 5)
    module.forward(lhs, rhs)