mirror of https://github.com/llvm/torch-mlir
[torch] Fix attention on linalg for dynamic shapes (#3714)
Current version does not work for a mixture of dynamic and static shaped batch dimensions. Rework to grab the correct dynamic shapes. --------- Co-authored-by: dan <danimal197@gmail.com>pull/3717/head
parent
3f46348e8e
commit
5ce48dfacd
|
@ -1607,35 +1607,23 @@ public:
|
|||
op.getLoc(), "expected no attention mask when isCausal is true");
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> maskSizes;
|
||||
SmallVector<int64_t> maskStatic;
|
||||
SmallVector<Value> maskDyn;
|
||||
for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
|
||||
maskStatic.push_back(queryTy.getDimSize(i));
|
||||
if (maskStatic.back() == ShapedType::kDynamic)
|
||||
maskDyn.push_back(
|
||||
rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
|
||||
}
|
||||
|
||||
if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
|
||||
auto seqLenQ =
|
||||
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
|
||||
auto seqLenK =
|
||||
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
|
||||
maskSizes = {seqLenQ, seqLenK};
|
||||
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
|
||||
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
|
||||
maskSizes.insert(maskSizes.begin(), batchSize);
|
||||
}
|
||||
} else { // Dynamic shape case: <?x?x...x?xf32> for example
|
||||
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
|
||||
Value batchSize =
|
||||
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
|
||||
maskSizes.push_back(batchSize);
|
||||
}
|
||||
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
|
||||
queryTy.getRank() - 2);
|
||||
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
|
||||
keyTy.getRank() - 2);
|
||||
maskSizes.push_back(seqLenQ);
|
||||
maskSizes.push_back(seqLenK);
|
||||
}
|
||||
maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
|
||||
if (maskStatic.back() == ShapedType::kDynamic)
|
||||
maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
|
||||
keyTy.getRank() - 2));
|
||||
|
||||
Type maskType = getElementTypeOrSelf(queryTy);
|
||||
Value emptyMask =
|
||||
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
|
||||
Value emptyMask = rewriter.create<tensor::EmptyOp>(
|
||||
op.getLoc(), maskStatic, maskType, maskDyn);
|
||||
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(),
|
||||
|
|
|
@ -37,6 +37,7 @@ if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
|||
# WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionMaskModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
|
@ -833,6 +834,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionMaskModule_basic",
|
||||
|
@ -3176,6 +3178,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScatterAddStaticModule_basic",
|
||||
"TensorsConcatComplex128FloatModule_basic",
|
||||
|
@ -4679,6 +4682,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ScalarImplicitIntModule_basic",
|
||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||
"ScatterReduceFloatMaxModule",
|
||||
|
|
|
@ -5370,6 +5370,35 @@ class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
|
||||
)
|
||||
def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils):
|
||||
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||
module.forward(query, key, value)
|
||||
|
||||
|
||||
class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3, -1, 16], torch.float32, True),
|
||||
([2, 3, -1, 16], torch.float32, True),
|
||||
([2, 3, -1, 20], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, query, key, value):
|
||||
return torch.ops.aten.scaled_dot_product_attention(
|
||||
query, key, value, is_causal=True
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule()
|
||||
)
|
||||
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
|
||||
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||
|
|
Loading…
Reference in New Issue