mirror of https://github.com/llvm/torch-mlir
[torch] Support diagonal `einsum.Diagonal` (#3618)
The einsum lowering was missing the behavior for duplicate indices in the equation. A duplicated pair of indices within an operand amounts to a diagonalization along those two dimensions.
parent d11d6f6fea
commit 9ab93436c4
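For reference, a repeated index within a single einsum operand denotes a diagonal along those dimensions; that is the behavior this commit teaches the lowering. A minimal PyTorch sketch of the semantics (the tensors and shapes below are illustrative, not taken from the patch):

import torch

a = torch.rand(4, 4)
# "ii->i" reads only the entries a[i, i], i.e. the main diagonal.
assert torch.allclose(torch.einsum("ii->i", a), torch.diagonal(a))

b = torch.rand(5, 5, 3)
# A repeated index pair diagonalizes those two dimensions; torch.diagonal
# appends the kept dimension last, hence the transpose to match "iij->ij".
assert torch.allclose(
    torch.einsum("iij->ij", b),
    torch.diagonal(b, dim1=0, dim2=1).transpose(0, 1),
)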
@@ -304,6 +304,84 @@ static bool parseEquation(const std::string &equation,
   return true;
 }

+static bool
+diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter,
+                                   std::string &equation,
+                                   SmallVector<Value> &inputTensors) {
+  SmallVector<char> resultTokens;
+  SmallVector<SmallVector<char>> inputTokens;
+
+  if (!parseEquation(equation, inputTokens, resultTokens)) {
+    return false;
+  }
+
+  for (size_t i = 0, d = inputTokens.size(); i < d; ++i) {
+    SmallVector<char> inputStr = inputTokens[i];
+    Value input = inputTensors[i];
+
+    for (size_t d0 = 0; d0 < inputStr.size(); ++d0) {
+      char id = inputStr[d0];
+
+      size_t d1;
+      for (d1 = d0 + 1; d1 < inputStr.size(); d1++) {
+        if (id == inputStr[d1])
+          break;
+      }
+
+      // No duplicate found so we can continue.
+      if (d1 == inputStr.size())
+        continue;
+
+      // Remove the ID and move to the end:
+      for (size_t i = d0 + 1; i < d1; ++i)
+        inputStr[i - 1] = inputStr[i];
+      for (size_t i = d1 + 1, s = inputStr.size(); i < s; ++i)
+        inputStr[i - 2] = inputStr[i];
+
+      inputStr[inputStr.size() - 2] = id;
+      inputStr.resize(inputStr.size() - 1);
+
+      auto inputTy = cast<ValueTensorType>(input.getType());
+      llvm::SmallVector<int64_t> newShape;
+      for (size_t i = 0, s = inputTy.getSizes().size(); i < s; ++i) {
+        if (i == d0 || i == d1)
+          continue;
+        newShape.push_back(inputTy.getSizes()[i]);
+      }
+      newShape.push_back(inputTy.getSizes()[d0]);
+
+      inputTy = rewriter.getType<ValueTensorType>(newShape, inputTy.getDtype());
+
+      Value zero = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(0));
+
+      Value d0Val = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(d0));
+      Value d1Val = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(d1));
+
+      input = rewriter.create<AtenDiagonalOp>(loc, inputTy, /*input=*/input,
+                                              /*offset=*/zero, /*dim1=*/d0Val,
+                                              /*dim2=*/d1Val);
+
+      // Frontmost token will have changed:
+      d0--;
+    }
+
+    inputTokens[i] = inputStr;
+    inputTensors[i] = input;
+  }
+
+  llvm::SmallVector<std::string> inputStrings;
+  for (auto inputStr : inputTokens)
+    inputStrings.emplace_back(inputStr.begin(), inputStr.end());
+
+  std::string resultString(resultTokens.begin(), resultTokens.end());
+
+  equation = llvm::join(inputStrings, ",") + "->" + resultString;
+  return true;
+}
+
 // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] =>
 // [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
 static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
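To make the rewrite above concrete, here is a small Python paraphrase (my own sketch, not code from the repo) of what diagonalizeInputAndRewriteEquation does to one operand: every duplicated index pair is collapsed with torch.diagonal (aten.diagonal), which drops the two dimensions and appends the surviving one at the end, while the operand's token string is updated in the same way. The helper name and the example tensor are illustrative.

import torch

def diagonalize_operand(tokens, tensor):
    # Paraphrase of the C++ loop above: collapse each duplicated index with
    # torch.diagonal, which removes dims d0/d1 and appends the diagonal
    # dimension last; update the token string to match.
    tokens = list(tokens)
    d0 = 0
    while d0 < len(tokens):
        d1 = next((d for d in range(d0 + 1, len(tokens)) if tokens[d] == tokens[d0]), None)
        if d1 is None:
            d0 += 1  # no duplicate for this index, move on
            continue
        tensor = torch.diagonal(tensor, offset=0, dim1=d0, dim2=d1)
        tokens = [t for j, t in enumerate(tokens) if j not in (d0, d1)] + [tokens[d0]]
        # stay at d0: a different token has shifted into this position
    return "".join(tokens), tensor

t1 = torch.rand(5, 5, 4, 4)
eq, diag = diagonalize_operand("iijj", t1)
print(eq, tuple(diag.shape))  # ij (5, 4)
assert torch.allclose(torch.einsum("iijj->ij", t1), diag)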
@@ -523,12 +601,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
   generateOutDimShapeMap(lhsOtherDims);
   generateOutDimShapeMap(rhsOtherDims);

-  if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 &&
-      rhsOtherDims.size() == 0) {
-    return rewriter.notifyMatchFailure(
-        loc, "Hadamard product is currently not supported");
-  }
-
   // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims]
   lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims,
                                contractingDims, lhsOtherDims, lhsReduceDims,
@@ -548,7 +620,12 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,

   // perform matmul
   auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
-  result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
+
+  if (contractingDims.size() != 0) {
+    result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
+  } else {
+    result = rewriter.create<Torch::AtenMulTensorOp>(loc, outType, lhs, rhs);
+  }

   // generate ideal result dims.
   generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
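The "Hadamard product" guard removed in the previous hunk appears to be superseded by this multiply fallback: after diagonalization an operand pair can legitimately have no contracting dimensions left, in which case the pairwise step is an elementwise product rather than a matmul. A hedged PyTorch illustration of such a case (the values are illustrative):

import torch

x = torch.rand(4, 4)
y = torch.rand(4, 4)
# After diagonalization, "ii,ii->i" is effectively "i,i->i": there are no
# contracting dims, so the result is an elementwise product of diagonals.
assert torch.allclose(
    torch.einsum("ii,ii->i", x, y),
    torch.diagonal(x) * torch.diagonal(y),
)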
@@ -1777,6 +1854,13 @@ public:
             op, "Unexpected character in equations encountered");
       }
     }
+
+    if (!diagonalizeInputAndRewriteEquation(op.getLoc(), rewriter, equation,
+                                            inputTensors)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Failed to handle diagonalization");
+    }
+
     SmallVector<char> resultTokens;
     SmallVector<SmallVector<char>> inputTokens;
     if (!parseEquation(equation, inputTokens, resultTokens)) {
@@ -1303,6 +1303,27 @@ def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6))


+class EinsumStaticDiagonalDimensionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 5, 4, 4], torch.float32, True),
+            ([5, 4, 5, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, tensor1, tensor2):
+        return torch.ops.aten.einsum("iijj,ijij->ji", [tensor1, tensor2])
+
+
+@register_test_case(module_factory=lambda: EinsumStaticDiagonalDimensionModule())
+def EinsumStaticDiagonalDimensionModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5, 4, 4), tu.rand(5, 4, 5, 4))
+
+
 class EinsumStaticContractRhsModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
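As a sanity check on the numerics the new e2e test expects, the same equation can be reproduced in plain PyTorch by diagonalizing each operand by hand and then running the rewritten equation over the results. This is only an illustrative sketch; the intermediate names below are not from the patch.

import torch

t1 = torch.rand(5, 5, 4, 4)
t2 = torch.rand(5, 4, 5, 4)
# Hand-diagonalized operands; torch.diagonal appends the kept dim last,
# so both collapse to "ij" operands of shape (5, 4).
d1 = torch.diagonal(torch.diagonal(t1, dim1=0, dim2=1), dim1=0, dim2=1)
d2 = torch.diagonal(torch.diagonal(t2, dim1=0, dim2=2), dim1=0, dim2=1)
assert torch.allclose(
    torch.einsum("iijj,ijij->ji", t1, t2),
    torch.einsum("ij,ij->ji", d1, d2),
)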