diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 4a87d6888..b0b0b0df2 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -1607,35 +1607,23 @@ public:
             op.getLoc(), "expected no attention mask when isCausal is true");
       }

-      SmallVector<OpFoldResult> maskSizes;
-
-      if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
-        auto seqLenQ =
-            rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
-        auto seqLenK =
-            rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
-        maskSizes = {seqLenQ, seqLenK};
-        for (int i = queryTy.getRank() - 3; i >= 0; --i) {
-          auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
-          maskSizes.insert(maskSizes.begin(), batchSize);
-        }
-      } else { // Dynamic shape case: <?x?x?x?xf32> for example
-        for (int i = 0; i < queryTy.getRank() - 2; ++i) {
-          Value batchSize =
-              rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
-          maskSizes.push_back(batchSize);
-        }
-        Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
-                                                       queryTy.getRank() - 2);
-        Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
-                                                       keyTy.getRank() - 2);
-        maskSizes.push_back(seqLenQ);
-        maskSizes.push_back(seqLenK);
+      SmallVector<int64_t> maskStatic;
+      SmallVector<Value> maskDyn;
+      for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
+        maskStatic.push_back(queryTy.getDimSize(i));
+        if (maskStatic.back() == ShapedType::kDynamic)
+          maskDyn.push_back(
+              rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
       }

+      maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
+      if (maskStatic.back() == ShapedType::kDynamic)
+        maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                         keyTy.getRank() - 2));
+
       Type maskType = getElementTypeOrSelf(queryTy);
-      Value emptyMask =
-          rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
+      Value emptyMask = rewriter.create<tensor::EmptyOp>(
+          op.getLoc(), maskStatic, maskType, maskDyn);

       Value zero = rewriter.create<arith::ConstantOp>(
           op.getLoc(),
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0e66d1cd3..8230f5e5a 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -37,6 +37,7 @@ if torch_version_for_comparison() < version.parse("2.5.0.dev"):
         # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
         "ScaledDotProductAttentionBoolMaskModule_basic",
         "ScaledDotProductAttentionDifferentCausalModule_basic",
+        "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
         "ScaledDotProductAttentionDifferentModule_basic",
         "ScaledDotProductAttentionMaskModule_basic",
         "ScaledDotProductAttentionSameCausalModule_basic",
@@ -833,6 +834,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "SafeSoftmaxNonNoneDtypeModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
@@ -3176,6 +3178,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
@@ -4679,6 +4682,7 @@ ONNX_TOSA_XFAIL_SET = {
     "ScalarImplicitIntModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", "ScaledDotProductAttentionSameCausalModule_basic", "ScaledDotProductAttentionSameDynamicModule_basic", "ScatterReduceFloatMaxModule", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index ce9a254f6..cb6aa7fc1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5370,6 +5370,35 @@ class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module): @register_test_case( module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule() ) +def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils): + query = torch.randn(2, 3, 8, 16, dtype=torch.float32) + key = torch.randn(2, 3, 12, 16, dtype=torch.float32) + value = torch.randn(2, 3, 12, 20, dtype=torch.float32) + module.forward(query, key, value) + + +class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, -1, 16], torch.float32, True), + ([2, 3, -1, 16], torch.float32, True), + ([2, 3, -1, 20], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention( + query, key, value, is_causal=True + ) + + +@register_test_case( + module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule() +) def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils): query = torch.randn(2, 3, 8, 16, dtype=torch.float32) key = torch.randn(2, 3, 12, 16, dtype=torch.float32)