From 25738b8c19fe74e325b6bdfcd33e3e550304bf6f Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 31 Oct 2024 17:59:24 -0700
Subject: [PATCH] [linalg] Broadcast batch for mask on sdpa lowering (#3824)

An attention mask is often broadcast across the batch/head dimensions,
since the same mask is usually applied to every attention head.
Optionally materialize this broadcast in the SDPA lowering when the
mask's leading dimensions do not already match the key's.
---
 .../TorchToTMTensor/TorchToTMTensor.cpp       | 102 ++++++++++++++----
 .../torch_mlir_e2e_test/test_suite/basic.py   |   4 +-
 2 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 94d715411..e154f5cb9 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -1661,6 +1661,7 @@ public:
     auto valueTy = cast<ShapedType>(value.getType());
     auto keyTy = cast<ShapedType>(key.getType());
 
+    auto loc = op.getLoc();
     Value dropoutP = op.getDropoutP();
     Value isCausal = op.getIsCausal();
     Value scale = op.getScale();
@@ -1671,13 +1672,13 @@ public:
     double dropout;
     if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
         dropout > 0.0)
-      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
+      return rewriter.notifyMatchFailure(loc, "dropout not supported");
 
     bool causal;
     if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
       if (!isa<Torch::NoneType>(mask.getType())) {
         return rewriter.notifyMatchFailure(
-            op.getLoc(), "expected no attention mask when isCausal is true");
+            loc, "expected no attention mask when isCausal is true");
       }
 
       SmallVector<int64_t> maskStatic;
@@ -1685,35 +1686,32 @@ public:
       for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
         maskStatic.push_back(queryTy.getDimSize(i));
         if (maskStatic.back() == ShapedType::kDynamic)
-          maskDyn.push_back(
-              rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
+          maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
       }
 
       maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
       if (maskStatic.back() == ShapedType::kDynamic)
-        maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
-                                                         keyTy.getRank() - 2));
+        maskDyn.push_back(
+            rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));
 
       Type maskType = getElementTypeOrSelf(queryTy);
-      Value emptyMask = rewriter.create<tensor::EmptyOp>(
-          op.getLoc(), maskStatic, maskType, maskDyn);
+      Value emptyMask =
+          rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
 
       Value zero = rewriter.create<arith::ConstantOp>(
-          op.getLoc(),
-          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
+          loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
       Value negInf = rewriter.create<arith::ConstantOp>(
-          op.getLoc(),
+          loc,
           rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
 
-      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
-                 .getResult(0);
+      mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);
 
       int64_t rank = cast<ShapedType>(queryTy).getRank();
       AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
       SmallVector<utils::IteratorType> iteratorTypes(
           rank, utils::IteratorType::parallel);
       auto genericOp = rewriter.create<linalg::GenericOp>(
-          op.getLoc(), mask.getType(), ValueRange{}, mask,
+          loc, mask.getType(), ValueRange{}, mask,
           SmallVector<AffineMap>{maskMap}, iteratorTypes,
           [&](OpBuilder &b, Location loc, ValueRange args) {
             Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
@@ -1727,18 +1725,78 @@ public:
       mask = genericOp.getResult(0);
     }
 
+    // Broadcast the batch dimensions of the mask:
+    if (!isa<Torch::NoneType>(mask.getType())) {
+      auto maskTy = cast<RankedTensorType>(mask.getType());
+      int64_t rank = maskTy.getRank();
+      bool needsBroadcast = false;
+      for (int i = 0, s = rank - 2; i < s; ++i) {
+        needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
+      }
+
+      if (needsBroadcast) {
+        SmallVector<int64_t> maskShape;
+        SmallVector<Value> maskDynDims;
+
+        SmallVector<AffineExpr> maskExprs;
+        for (int i = 0, s = rank - 2; i < s; ++i) {
+          maskShape.push_back(keyTy.getDimSize(i));
+
+          if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
+            maskExprs.push_back(rewriter.getAffineConstantExpr(0));
+          } else {
+            maskExprs.push_back(rewriter.getAffineDimExpr(i));
+          }
+
+          if (keyTy.isDynamicDim(i)) {
+            maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
+          }
+        }
+
+        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
+        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
+        maskShape.push_back(maskTy.getDimSize(rank - 2));
+        maskShape.push_back(maskTy.getDimSize(rank - 1));
+        if (maskTy.isDynamicDim(rank - 2))
+          maskDynDims.push_back(
+              rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
+        if (maskTy.isDynamicDim(rank - 1))
+          maskDynDims.push_back(
+              rewriter.create<tensor::DimOp>(loc, mask, rank - 1));
+
+        SmallVector<AffineMap> affineMaps = {
+            AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
+                           op.getContext()),
+            rewriter.getMultiDimIdentityMap(rank)};
+        SmallVector<utils::IteratorType> findMaxIteratorTypes(
+            rank, utils::IteratorType::parallel);
+
+        Value emptyMask = rewriter.create<tensor::EmptyOp>(
+            loc, maskShape, maskTy.getElementType(), maskDynDims);
+        Value newMask =
+            rewriter
+                .create<linalg::GenericOp>(
+                    loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
+                    affineMaps, findMaxIteratorTypes,
+                    [&](OpBuilder &b, Location loc, ValueRange args) {
+                      b.create<linalg::YieldOp>(loc, args[0]);
+                    })
+                .getResult(0);
+        mask = newMask;
+      }
+    }
+
     if (!isa<Torch::NoneType>(scale.getType())) {
       double scaleFloat;
       if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
           scaleFloat != 1.0)
-        return rewriter.notifyMatchFailure(op.getLoc(),
-                                           "only default scale supported");
+        return rewriter.notifyMatchFailure(loc, "only default scale supported");
     }
 
     bool isGQAEnabled;
     if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
         isGQAEnabled)
       return rewriter.notifyMatchFailure(
-          op.getLoc(), "grouped query attention not supported");
+          loc, "grouped query attention not supported");
 
     if (queryTy.getRank() != valueTy.getRank() ||
         queryTy.getRank() != keyTy.getRank())
@@ -1753,7 +1811,6 @@ public:
     reassociation[1].push_back(valueTy.getRank() - 2);
     reassociation[2].push_back(valueTy.getRank() - 1);
 
-    auto loc = op.getLoc();
     auto collapseBatch = [&rewriter, &reassociation,
                           loc](Value value) -> Value {
       auto valueTy = cast<ShapedType>(value.getType());
@@ -1788,13 +1845,12 @@ public:
     SmallVector<int64_t> valueSizes(
         cast<ShapedType>(value.getType()).getShape());
     outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
-    SmallVector<Value> outSizesDynamic(
-        getTensorSizes(rewriter, op.getLoc(), query));
+    SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
     outSizesDynamic[outSizesDynamic.size() - 1] =
-        getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
+        getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
     Type outType = RankedTensorType::get(outSizes, elementType);
-    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
-                                        elementType);
+    Value output =
+        createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);
 
     SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index bef16f3ef..bc87cc67d 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5501,7 +5501,7 @@ class ScaledDotProductAttentionMaskModule(torch.nn.Module):
             ([2, 3, 8, 16], torch.float32, True),
             ([2, 3, 12, 16], torch.float32, True),
             ([2, 3, 12, 20], torch.float32, True),
-            ([2, 3, 8, 12], torch.float32, True),
+            ([2, 1, 8, 12], torch.float32, True),
         ]
     )
     def forward(self, query, key, value, mask):
@@ -5513,7 +5513,7 @@ def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
     query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
     key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
     value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
-    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
+    mask = torch.randn(2, 1, 8, 12, dtype=torch.float32)
     module.forward(query, key, value, mask)
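
A standalone sketch (not part of the patch): it shows, at the PyTorch level, the
mask broadcast that the lowering now materializes. The tensor shapes mirror the
updated e2e test; the explicit expand() roughly mimics what the added
linalg.generic does to the mask's batch/head dimensions before the attention op.

    import torch
    import torch.nn.functional as F

    query = torch.randn(2, 3, 8, 16)
    key = torch.randn(2, 3, 12, 16)
    value = torch.randn(2, 3, 12, 20)
    # The mask has a size-1 head dimension and is broadcast across all 3 heads.
    mask = torch.randn(2, 1, 8, 12)

    out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)

    # Materializing the broadcast up front gives the same result; this is
    # roughly what the new linalg.generic does at the IR level.
    out_ref = F.scaled_dot_product_attention(
        query, key, value, attn_mask=mask.expand(2, 3, 8, 12)
    )
    assert torch.allclose(out, out_ref)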