[linalg] Broadcast batch for mask on sdpa lowering (#3824)

Attention often broadcasts a mask across the batch dimensions, as masking
is usually the same across attention heads. This change optionally
materializes that broadcast on the mask's batch dimensions, doing so only
when they do not already match the key's batch dimensions.
pull/3828/head
Rob Suderman 2024-10-31 17:59:24 -07:00 committed by GitHub
parent 5aa323dd29
commit 25738b8c19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 25 deletions

View File

@ -1661,6 +1661,7 @@ public:
auto valueTy = cast<ShapedType>(value.getType()); auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType()); auto keyTy = cast<ShapedType>(key.getType());
auto loc = op.getLoc();
Value dropoutP = op.getDropoutP(); Value dropoutP = op.getDropoutP();
Value isCausal = op.getIsCausal(); Value isCausal = op.getIsCausal();
Value scale = op.getScale(); Value scale = op.getScale();
@ -1671,13 +1672,13 @@ public:
double dropout; double dropout;
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) || if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
dropout > 0.0) dropout > 0.0)
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported"); return rewriter.notifyMatchFailure(loc, "dropout not supported");
bool causal; bool causal;
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) { if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
if (!isa<Torch::NoneType>(mask.getType())) { if (!isa<Torch::NoneType>(mask.getType())) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op.getLoc(), "expected no attention mask when isCausal is true"); loc, "expected no attention mask when isCausal is true");
} }
SmallVector<int64_t> maskStatic; SmallVector<int64_t> maskStatic;
@ -1685,35 +1686,32 @@ public:
for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) { for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
maskStatic.push_back(queryTy.getDimSize(i)); maskStatic.push_back(queryTy.getDimSize(i));
if (maskStatic.back() == ShapedType::kDynamic) if (maskStatic.back() == ShapedType::kDynamic)
maskDyn.push_back( maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
} }
maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2)); maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
if (maskStatic.back() == ShapedType::kDynamic) if (maskStatic.back() == ShapedType::kDynamic)
maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key, maskDyn.push_back(
keyTy.getRank() - 2)); rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));
Type maskType = getElementTypeOrSelf(queryTy); Type maskType = getElementTypeOrSelf(queryTy);
Value emptyMask = rewriter.create<tensor::EmptyOp>( Value emptyMask =
op.getLoc(), maskStatic, maskType, maskDyn); rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
Value zero = rewriter.create<arith::ConstantOp>( Value zero = rewriter.create<arith::ConstantOp>(
op.getLoc(), loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
Value negInf = rewriter.create<arith::ConstantOp>( Value negInf = rewriter.create<arith::ConstantOp>(
op.getLoc(), loc,
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY)); rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask) mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);
.getResult(0);
int64_t rank = cast<ShapedType>(queryTy).getRank(); int64_t rank = cast<ShapedType>(queryTy).getRank();
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank); AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
SmallVector<utils::IteratorType> iteratorTypes( SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel); rank, utils::IteratorType::parallel);
auto genericOp = rewriter.create<linalg::GenericOp>( auto genericOp = rewriter.create<linalg::GenericOp>(
op.getLoc(), mask.getType(), ValueRange{}, mask, loc, mask.getType(), ValueRange{}, mask,
SmallVector<AffineMap>{maskMap}, iteratorTypes, SmallVector<AffineMap>{maskMap}, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2); Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
@ -1727,18 +1725,78 @@ public:
mask = genericOp.getResult(0); mask = genericOp.getResult(0);
} }
// Broadcast the batch dimensions of the mask:
if (!isa<Torch::NoneType>(mask.getType())) {
auto maskTy = cast<RankedTensorType>(mask.getType());
int64_t rank = maskTy.getRank();
bool needsBroadcast = false;
for (int i = 0, s = rank - 2; i < s; ++i) {
needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
}
if (needsBroadcast) {
SmallVector<int64_t> maskShape;
SmallVector<Value> maskDynDims;
SmallVector<AffineExpr> maskExprs;
for (int i = 0, s = rank - 2; i < s; ++i) {
maskShape.push_back(keyTy.getDimSize(i));
if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
maskExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
maskExprs.push_back(rewriter.getAffineDimExpr(i));
}
if (keyTy.isDynamicDim(i)) {
maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
}
}
maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
maskShape.push_back(maskTy.getDimSize(rank - 2));
maskShape.push_back(maskTy.getDimSize(rank - 1));
if (maskTy.isDynamicDim(rank - 2))
maskDynDims.push_back(
rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
if (maskTy.isDynamicDim(rank - 1))
maskDynDims.push_back(
rewriter.create<tensor::DimOp>(loc, mask, rank - 1));
SmallVector<AffineMap> affineMaps = {
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
op.getContext()),
rewriter.getMultiDimIdentityMap(rank)};
SmallVector<utils::IteratorType> findMaxIteratorTypes(
rank, utils::IteratorType::parallel);
Value emptyMask = rewriter.create<tensor::EmptyOp>(
loc, maskShape, maskTy.getElementType(), maskDynDims);
Value newMask =
rewriter
.create<linalg::GenericOp>(
loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
affineMaps, findMaxIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
mask = newMask;
}
}
if (!isa<Torch::NoneType>(scale.getType())) { if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat; double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) || if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
scaleFloat != 1.0) scaleFloat != 1.0)
return rewriter.notifyMatchFailure(op.getLoc(), return rewriter.notifyMatchFailure(loc, "only default scale supported");
"only default scale supported");
} }
bool isGQAEnabled; bool isGQAEnabled;
if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) || if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
isGQAEnabled) isGQAEnabled)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op.getLoc(), "grouped query attention not supported"); loc, "grouped query attention not supported");
if (queryTy.getRank() != valueTy.getRank() || if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank()) queryTy.getRank() != keyTy.getRank())
@ -1753,7 +1811,6 @@ public:
reassociation[1].push_back(valueTy.getRank() - 2); reassociation[1].push_back(valueTy.getRank() - 2);
reassociation[2].push_back(valueTy.getRank() - 1); reassociation[2].push_back(valueTy.getRank() - 1);
auto loc = op.getLoc();
auto collapseBatch = [&rewriter, &reassociation, auto collapseBatch = [&rewriter, &reassociation,
loc](Value value) -> Value { loc](Value value) -> Value {
auto valueTy = cast<ShapedType>(value.getType()); auto valueTy = cast<ShapedType>(value.getType());
@ -1788,13 +1845,12 @@ public:
SmallVector<int64_t> valueSizes( SmallVector<int64_t> valueSizes(
cast<ShapedType>(value.getType()).getShape()); cast<ShapedType>(value.getType()).getShape());
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1]; outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
SmallVector<Value> outSizesDynamic( SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
getTensorSizes(rewriter, op.getLoc(), query));
outSizesDynamic[outSizesDynamic.size() - 1] = outSizesDynamic[outSizesDynamic.size() - 1] =
getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1]; getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
Type outType = RankedTensorType::get(outSizes, elementType); Type outType = RankedTensorType::get(outSizes, elementType);
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic, Value output =
elementType); createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);
SmallVector<Value> inputs = SmallVector<Value>{query, key, value}; SmallVector<Value> inputs = SmallVector<Value>{query, key, value};

View File

@ -5501,7 +5501,7 @@ class ScaledDotProductAttentionMaskModule(torch.nn.Module):
([2, 3, 8, 16], torch.float32, True), ([2, 3, 8, 16], torch.float32, True),
([2, 3, 12, 16], torch.float32, True), ([2, 3, 12, 16], torch.float32, True),
([2, 3, 12, 20], torch.float32, True), ([2, 3, 12, 20], torch.float32, True),
([2, 3, 8, 12], torch.float32, True), ([2, 1, 8, 12], torch.float32, True),
] ]
) )
def forward(self, query, key, value, mask): def forward(self, query, key, value, mask):
@ -5513,7 +5513,7 @@ def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32) query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32) key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32) value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) mask = torch.randn(2, 1, 8, 12, dtype=torch.float32)
module.forward(query, key, value, mask) module.forward(query, key, value, mask)