mirror of https://github.com/llvm/torch-mlir
[torch] Fix attention on linalg for dynamic shapes (#3714)
Current version does not work for a mixture of dynamic and static shaped batch dimensions. Rework to grab the correct dynamic shapes. --------- Co-authored-by: dan <danimal197@gmail.com>pull/3717/head
parent
3f46348e8e
commit
5ce48dfacd
|
@ -1607,35 +1607,23 @@ public:
|
||||||
op.getLoc(), "expected no attention mask when isCausal is true");
|
op.getLoc(), "expected no attention mask when isCausal is true");
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> maskSizes;
|
SmallVector<int64_t> maskStatic;
|
||||||
|
SmallVector<Value> maskDyn;
|
||||||
|
for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
|
||||||
|
maskStatic.push_back(queryTy.getDimSize(i));
|
||||||
|
if (maskStatic.back() == ShapedType::kDynamic)
|
||||||
|
maskDyn.push_back(
|
||||||
|
rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
|
||||||
|
}
|
||||||
|
|
||||||
if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
|
maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
|
||||||
auto seqLenQ =
|
if (maskStatic.back() == ShapedType::kDynamic)
|
||||||
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
|
maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
|
||||||
auto seqLenK =
|
keyTy.getRank() - 2));
|
||||||
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
|
|
||||||
maskSizes = {seqLenQ, seqLenK};
|
|
||||||
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
|
|
||||||
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
|
|
||||||
maskSizes.insert(maskSizes.begin(), batchSize);
|
|
||||||
}
|
|
||||||
} else { // Dynamic shape case: <?x?x...x?xf32> for example
|
|
||||||
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
|
|
||||||
Value batchSize =
|
|
||||||
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
|
|
||||||
maskSizes.push_back(batchSize);
|
|
||||||
}
|
|
||||||
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
|
|
||||||
queryTy.getRank() - 2);
|
|
||||||
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
|
|
||||||
keyTy.getRank() - 2);
|
|
||||||
maskSizes.push_back(seqLenQ);
|
|
||||||
maskSizes.push_back(seqLenK);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type maskType = getElementTypeOrSelf(queryTy);
|
Type maskType = getElementTypeOrSelf(queryTy);
|
||||||
Value emptyMask =
|
Value emptyMask = rewriter.create<tensor::EmptyOp>(
|
||||||
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
|
op.getLoc(), maskStatic, maskType, maskDyn);
|
||||||
|
|
||||||
Value zero = rewriter.create<arith::ConstantOp>(
|
Value zero = rewriter.create<arith::ConstantOp>(
|
||||||
op.getLoc(),
|
op.getLoc(),
|
||||||
|
|
|
@ -37,6 +37,7 @@ if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
||||||
# WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
|
# WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
"ScaledDotProductAttentionMaskModule_basic",
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
@ -833,6 +834,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"SafeSoftmaxNonNoneDtypeModule_basic",
|
"SafeSoftmaxNonNoneDtypeModule_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
"ScaledDotProductAttentionMaskModule_basic",
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
@ -3176,6 +3178,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
"ScatterAddStaticModule_basic",
|
"ScatterAddStaticModule_basic",
|
||||||
"TensorsConcatComplex128FloatModule_basic",
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
|
@ -4679,6 +4682,7 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
# REMOVE WHEN ENABLE_GQA IS ADDED
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameCausalModule_basic",
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
"ScaledDotProductAttentionSameDynamicModule_basic",
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScatterReduceFloatMaxModule",
|
"ScatterReduceFloatMaxModule",
|
||||||
|
|
|
@ -5370,6 +5370,35 @@ class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
|
||||||
@register_test_case(
|
@register_test_case(
|
||||||
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
|
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
|
||||||
)
|
)
|
||||||
|
def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils):
|
||||||
|
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||||
|
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||||
|
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||||
|
module.forward(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 3, -1, 16], torch.float32, True),
|
||||||
|
([2, 3, -1, 16], torch.float32, True),
|
||||||
|
([2, 3, -1, 20], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
return torch.ops.aten.scaled_dot_product_attention(
|
||||||
|
query, key, value, is_causal=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule()
|
||||||
|
)
|
||||||
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
|
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
|
||||||
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||||
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||||
|
|
Loading…
Reference in New Issue