[torch] Support diagonal `einsum.Diagonal` (#3618)

The einsum lowering was missing support for duplicate indices within an operand of
the equation. Repeating an index in one operand amounts to a diagonalization
along each duplicated pair of dimensions.
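A small PyTorch illustration of the semantics being added (shapes chosen arbitrarily): a subscript repeated within one operand selects the diagonal over that pair of dimensions.

```python
import torch

a = torch.rand(4, 4)
# "ii->i" reads the main diagonal of a matrix.
assert torch.allclose(torch.einsum("ii->i", a), torch.diagonal(a))

b = torch.rand(3, 3, 5, 5)
# "iijj->ij": each duplicated pair is diagonalized independently;
# torch.diagonal appends the shared dimension at the end of the shape.
expected = torch.diagonal(torch.diagonal(b, dim1=0, dim2=1), dim1=0, dim2=1)
assert torch.allclose(torch.einsum("iijj->ij", b), expected)
```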
Rob Suderman 2024-08-13 09:38:43 -07:00 committed by GitHub
parent d11d6f6fea
commit 9ab93436c4
2 changed files with 112 additions and 7 deletions


@@ -304,6 +304,84 @@ static bool parseEquation(const std::string &equation,
  return true;
}

static bool
diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter,
                                   std::string &equation,
                                   SmallVector<Value> &inputTensors) {
  SmallVector<char> resultTokens;
  SmallVector<SmallVector<char>> inputTokens;
  if (!parseEquation(equation, inputTokens, resultTokens)) {
    return false;
  }

  for (size_t i = 0, d = inputTokens.size(); i < d; ++i) {
    SmallVector<char> inputStr = inputTokens[i];
    Value input = inputTensors[i];
    for (size_t d0 = 0; d0 < inputStr.size(); ++d0) {
      char id = inputStr[d0];
      size_t d1;
      for (d1 = d0 + 1; d1 < inputStr.size(); d1++) {
        if (id == inputStr[d1])
          break;
      }

      // No duplicate found so we can continue.
      if (d1 == inputStr.size())
        continue;

      // Remove the ID and move to the end:
      for (size_t i = d0 + 1; i < d1; ++i)
        inputStr[i - 1] = inputStr[i];
      for (size_t i = d1 + 1, s = inputStr.size(); i < s; ++i)
        inputStr[i - 2] = inputStr[i];
      inputStr[inputStr.size() - 2] = id;
      inputStr.resize(inputStr.size() - 1);

      auto inputTy = cast<ValueTensorType>(input.getType());
      llvm::SmallVector<int64_t> newShape;
      for (size_t i = 0, s = inputTy.getSizes().size(); i < s; ++i) {
        if (i == d0 || i == d1)
          continue;
        newShape.push_back(inputTy.getSizes()[i]);
      }
      newShape.push_back(inputTy.getSizes()[d0]);
      inputTy = rewriter.getType<ValueTensorType>(newShape, inputTy.getDtype());

      Value zero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Value d0Val = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(d0));
      Value d1Val = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(d1));

      input = rewriter.create<AtenDiagonalOp>(loc, inputTy, /*input=*/input,
                                              /*offset=*/zero, /*dim1=*/d0Val,
                                              /*dim2=*/d1Val);

      // Frontmost token will have changed:
      d0--;
    }

    inputTokens[i] = inputStr;
    inputTensors[i] = input;
  }

  llvm::SmallVector<std::string> inputStrings;
  for (auto inputStr : inputTokens)
    inputStrings.emplace_back(inputStr.begin(), inputStr.end());

  std::string resultString(resultTokens.begin(), resultTokens.end());
  equation = llvm::join(inputStrings, ",") + "->" + resultString;
  return true;
}
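The helper above scans each operand's subscripts for a repeated label; when it finds one it emits an `aten.diagonal` over the two matching dimensions (which moves the shared dimension to the end of the shape), drops both occurrences from the subscript string, appends the surviving label at the end, and finally re-joins the rewritten per-operand strings into a new equation. A rough Python rendering of that per-operand rewrite, where the function name and structure are illustrative and not part of the patch:

```python
import torch

def diagonalize_operand(subscripts: str, t: torch.Tensor):
    """Mirror of the rewrite: for every repeated subscript label, take
    torch.diagonal over the two matching dims and move the label to the end."""
    labels = list(subscripts)
    d0 = 0
    while d0 < len(labels):
        try:
            d1 = labels.index(labels[d0], d0 + 1)
        except ValueError:
            d0 += 1  # no duplicate for this label, move on
            continue
        label = labels.pop(d1)
        labels.pop(d0)
        labels.append(label)  # the diagonal dimension ends up last
        t = torch.diagonal(t, offset=0, dim1=d0, dim2=d1)
        # do not advance d0: the remaining labels shifted left, re-check this slot
    return "".join(labels), t

# "iijj" over a (5, 5, 4, 4) operand becomes "ij" over a (5, 4) diagonal.
subs, diag = diagonalize_operand("iijj", torch.rand(5, 5, 4, 4))
print(subs, tuple(diag.shape))  # -> ij (5, 4)
```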
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] =>
// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
@@ -523,12 +601,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
  generateOutDimShapeMap(lhsOtherDims);
  generateOutDimShapeMap(rhsOtherDims);

  if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 &&
      rhsOtherDims.size() == 0) {
    return rewriter.notifyMatchFailure(
        loc, "Hadamard product is currently not supported");
  }

  // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims]
  lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims,
                               contractingDims, lhsOtherDims, lhsReduceDims,
@@ -548,7 +620,12 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
  // perform matmul
  auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
  result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);

  if (contractingDims.size() != 0) {
    result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
  } else {
    result = rewriter.create<Torch::AtenMulTensorOp>(loc, outType, lhs, rhs);
  }
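With diagonalization in place, equations with no contracting dimensions can now reach this point, so the earlier "Hadamard product is currently not supported" bail-out is removed and the lowering emits an elementwise `aten.mul.Tensor` instead of a matmul. A quick check of the equivalence that branch relies on (illustrative shapes):

```python
import torch

a, b = torch.rand(3, 4), torch.rand(3, 4)
# No index is summed away, so the einsum degenerates to a Hadamard product.
assert torch.allclose(torch.einsum("ij,ij->ij", a, b), a * b)
```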
  // generate ideal result dims.
  generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
@@ -1777,6 +1854,13 @@ public:
            op, "Unexpected character in equations encountered");
      }
    }

    if (!diagonalizeInputAndRewriteEquation(op.getLoc(), rewriter, equation,
                                            inputTensors)) {
      return rewriter.notifyMatchFailure(op,
                                         "Failed to handle diagonalization");
    }

    SmallVector<char> resultTokens;
    SmallVector<SmallVector<char>> inputTokens;
    if (!parseEquation(equation, inputTokens, resultTokens)) {


@@ -1303,6 +1303,27 @@ def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6))


class EinsumStaticDiagonalDimensionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([5, 5, 4, 4], torch.float32, True),
            ([5, 4, 5, 4], torch.float32, True),
        ]
    )
    def forward(self, tensor1, tensor2):
        return torch.ops.aten.einsum("iijj,ijij->ji", [tensor1, tensor2])


@register_test_case(module_factory=lambda: EinsumStaticDiagonalDimensionModule())
def EinsumStaticDiagonalDimensionModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 5, 4, 4), tu.rand(5, 4, 5, 4))
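For reference, the equation exercised by the new test can be cross-checked numerically against an explicit `torch.diagonal` formulation; this snippet is illustrative and not part of the test suite:

```python
import torch

t1, t2 = torch.rand(5, 5, 4, 4), torch.rand(5, 4, 5, 4)
# d1[i, j] == t1[i, i, j, j] and d2[i, j] == t2[i, j, i, j]
d1 = torch.diagonal(torch.diagonal(t1, dim1=0, dim2=1), dim1=0, dim2=1)
d2 = torch.diagonal(torch.diagonal(t2, dim1=0, dim2=2), dim1=0, dim2=1)
# Only an elementwise multiply and a transpose remain after diagonalization.
assert torch.allclose(torch.einsum("iijj,ijij->ji", t1, t2), (d1 * d2).T)
```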
class EinsumStaticContractRhsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()