[torch] Fix attention on linalg for dynamic shapes (#3714)

The current lowering does not work for a mixture of dynamic and static shaped
batch dimensions. Rework the mask size computation to grab the correct dynamic shapes.

---------

Co-authored-by: dan <danimal197@gmail.com>
author Rob Suderman 2024-09-18 12:52:54 -07:00 (committed by GitHub)
parent 3f46348e8e
commit 5ce48dfacd
3 changed files with 47 additions and 26 deletions
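
For context, the shape mix that used to break looks like the sketch below: batch and head dimensions are static, the sequence dimensions are dynamic, and is_causal=True forces the lowering to materialize a causal mask whose sizes come partly from the static type and partly from runtime dim queries. This is an illustrative eager-mode snippet, not part of the patch.

import torch

# Batch (2) and head (3) dims are static; in the compiled module the
# sequence dims (8 and 12) are exported as dynamic (-1), matching the
# new ScaledDotProductAttentionDifferentDynamicCausalModule test below.
query = torch.randn(2, 3, 8, 16)
key = torch.randn(2, 3, 12, 16)
value = torch.randn(2, 3, 12, 20)

# is_causal=True means the lowering must build a [2, 3, seq_q, seq_k]
# causal mask, where seq_q comes from query and seq_k from key.
out = torch.ops.aten.scaled_dot_product_attention(
    query, key, value, is_causal=True
)
print(out.shape)  # torch.Size([2, 3, 8, 20])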


@@ -1607,35 +1607,23 @@ public:
           op.getLoc(), "expected no attention mask when isCausal is true");
     }
-    SmallVector<OpFoldResult> maskSizes;
-    if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
-      auto seqLenQ =
-          rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
-      auto seqLenK =
-          rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
-      maskSizes = {seqLenQ, seqLenK};
-      for (int i = queryTy.getRank() - 3; i >= 0; --i) {
-        auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
-        maskSizes.insert(maskSizes.begin(), batchSize);
-      }
-    } else { // Dynamic shape case: <?x?x...x?xf32> for example
-      for (int i = 0; i < queryTy.getRank() - 2; ++i) {
-        Value batchSize =
-            rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
-        maskSizes.push_back(batchSize);
-      }
-      Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
-                                                     queryTy.getRank() - 2);
-      Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
-                                                     keyTy.getRank() - 2);
-      maskSizes.push_back(seqLenQ);
-      maskSizes.push_back(seqLenK);
-    }
+    SmallVector<int64_t> maskStatic;
+    SmallVector<Value> maskDyn;
+    for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
+      maskStatic.push_back(queryTy.getDimSize(i));
+      if (maskStatic.back() == ShapedType::kDynamic)
+        maskDyn.push_back(
+            rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
+    }
+
+    maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
+    if (maskStatic.back() == ShapedType::kDynamic)
+      maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                       keyTy.getRank() - 2));
+
     Type maskType = getElementTypeOrSelf(queryTy);
-    Value emptyMask =
-        rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
+    Value emptyMask = rewriter.create<tensor::EmptyOp>(
+        op.getLoc(), maskStatic, maskType, maskDyn);
     Value zero = rewriter.create<arith::ConstantOp>(
         op.getLoc(),
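
The rewritten mask construction follows the tensor.empty convention: a list of static sizes containing kDynamic placeholders, plus one runtime Value per dynamic dimension. Below is a rough Python analogy of the maskStatic/maskDyn split; the mask_sizes and dim_of names are illustrative, not from the patch.

DYNAMIC = -1  # stand-in for ShapedType::kDynamic


def mask_sizes(query_shape, key_shape, dim_of):
    # query_shape/key_shape use DYNAMIC for unknown dims; dim_of(name, i)
    # stands in for tensor::DimOp and yields the runtime size of dim i.
    static, dynamic = [], []
    # Batch/head dims plus the query sequence dim come from the query type.
    for i in range(len(query_shape) - 1):
        static.append(query_shape[i])
        if static[-1] == DYNAMIC:
            dynamic.append(dim_of("query", i))
    # The key sequence length supplies the last mask dimension.
    static.append(key_shape[-2])
    if static[-1] == DYNAMIC:
        dynamic.append(dim_of("key", len(key_shape) - 2))
    return static, dynamic


# Static batch/head dims, dynamic sequence dims, as in the new e2e test.
static, dynamic = mask_sizes(
    [2, 3, DYNAMIC, 16],
    [2, 3, DYNAMIC, 16],
    dim_of=lambda name, i: f"dim({name}, {i})",
)
print(static)   # [2, 3, -1, -1]
print(dynamic)  # ['dim(query, 2)', 'dim(key, 2)']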


@@ -37,6 +37,7 @@ if torch_version_for_comparison() < version.parse("2.5.0.dev"):
         # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
         "ScaledDotProductAttentionBoolMaskModule_basic",
         "ScaledDotProductAttentionDifferentCausalModule_basic",
+        "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
         "ScaledDotProductAttentionDifferentModule_basic",
         "ScaledDotProductAttentionMaskModule_basic",
         "ScaledDotProductAttentionSameCausalModule_basic",
@@ -833,6 +834,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "SafeSoftmaxNonNoneDtypeModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
@@ -3176,6 +3178,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
@@ -4679,6 +4682,7 @@ ONNX_TOSA_XFAIL_SET = {
     "ScalarImplicitIntModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterReduceFloatMaxModule",


@@ -5370,6 +5370,35 @@ class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
 @register_test_case(
     module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
 )
+def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
+class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, -1, 16], torch.float32, True),
+            ([2, 3, -1, 16], torch.float32, True),
+            ([2, 3, -1, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(
+            query, key, value, is_causal=True
+        )
+
+
+@register_test_case(
+    module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule()
+)
 def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
     query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
     key = torch.randn(2, 3, 12, 16, dtype=torch.float32)