mirror of https://github.com/llvm/torch-mlir
[linalg] Broadcast batch for mask on sdpa lowering (#3824)
Attention often broadcasts a mask across the batch dimension as masking is usually performed the same across attention heads. Added this materialization to the mask dimensions optionally.pull/3828/head
parent
5aa323dd29
commit
25738b8c19
|
@ -1661,6 +1661,7 @@ public:
|
|||
auto valueTy = cast<ShapedType>(value.getType());
|
||||
auto keyTy = cast<ShapedType>(key.getType());
|
||||
|
||||
auto loc = op.getLoc();
|
||||
Value dropoutP = op.getDropoutP();
|
||||
Value isCausal = op.getIsCausal();
|
||||
Value scale = op.getScale();
|
||||
|
@ -1671,13 +1672,13 @@ public:
|
|||
double dropout;
|
||||
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
|
||||
dropout > 0.0)
|
||||
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
|
||||
return rewriter.notifyMatchFailure(loc, "dropout not supported");
|
||||
|
||||
bool causal;
|
||||
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
|
||||
if (!isa<Torch::NoneType>(mask.getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op.getLoc(), "expected no attention mask when isCausal is true");
|
||||
loc, "expected no attention mask when isCausal is true");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> maskStatic;
|
||||
|
@ -1685,35 +1686,32 @@ public:
|
|||
for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
|
||||
maskStatic.push_back(queryTy.getDimSize(i));
|
||||
if (maskStatic.back() == ShapedType::kDynamic)
|
||||
maskDyn.push_back(
|
||||
rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
|
||||
maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
|
||||
}
|
||||
|
||||
maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
|
||||
if (maskStatic.back() == ShapedType::kDynamic)
|
||||
maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
|
||||
keyTy.getRank() - 2));
|
||||
maskDyn.push_back(
|
||||
rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));
|
||||
|
||||
Type maskType = getElementTypeOrSelf(queryTy);
|
||||
Value emptyMask = rewriter.create<tensor::EmptyOp>(
|
||||
op.getLoc(), maskStatic, maskType, maskDyn);
|
||||
Value emptyMask =
|
||||
rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
|
||||
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(),
|
||||
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
|
||||
loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
|
||||
Value negInf = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(),
|
||||
loc,
|
||||
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
|
||||
|
||||
mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
|
||||
.getResult(0);
|
||||
mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);
|
||||
|
||||
int64_t rank = cast<ShapedType>(queryTy).getRank();
|
||||
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
rank, utils::IteratorType::parallel);
|
||||
auto genericOp = rewriter.create<linalg::GenericOp>(
|
||||
op.getLoc(), mask.getType(), ValueRange{}, mask,
|
||||
loc, mask.getType(), ValueRange{}, mask,
|
||||
SmallVector<AffineMap>{maskMap}, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
|
||||
|
@ -1727,18 +1725,78 @@ public:
|
|||
mask = genericOp.getResult(0);
|
||||
}
|
||||
|
||||
// Broadcast the batch dimensions of the mask:
|
||||
if (!isa<Torch::NoneType>(mask.getType())) {
|
||||
auto maskTy = cast<RankedTensorType>(mask.getType());
|
||||
int64_t rank = maskTy.getRank();
|
||||
bool needsBroadcast = false;
|
||||
for (int i = 0, s = rank - 2; i < s; ++i) {
|
||||
needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
|
||||
}
|
||||
|
||||
if (needsBroadcast) {
|
||||
SmallVector<int64_t> maskShape;
|
||||
SmallVector<Value> maskDynDims;
|
||||
|
||||
SmallVector<AffineExpr> maskExprs;
|
||||
for (int i = 0, s = rank - 2; i < s; ++i) {
|
||||
maskShape.push_back(keyTy.getDimSize(i));
|
||||
|
||||
if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
|
||||
maskExprs.push_back(rewriter.getAffineConstantExpr(0));
|
||||
} else {
|
||||
maskExprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
}
|
||||
|
||||
if (keyTy.isDynamicDim(i)) {
|
||||
maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
|
||||
}
|
||||
}
|
||||
|
||||
maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
|
||||
maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
|
||||
maskShape.push_back(maskTy.getDimSize(rank - 2));
|
||||
maskShape.push_back(maskTy.getDimSize(rank - 1));
|
||||
if (maskTy.isDynamicDim(rank - 2))
|
||||
maskDynDims.push_back(
|
||||
rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
|
||||
if (maskTy.isDynamicDim(rank - 1))
|
||||
maskDynDims.push_back(
|
||||
rewriter.create<tensor::DimOp>(loc, mask, rank - 1));
|
||||
|
||||
SmallVector<AffineMap> affineMaps = {
|
||||
AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
|
||||
op.getContext()),
|
||||
rewriter.getMultiDimIdentityMap(rank)};
|
||||
SmallVector<utils::IteratorType> findMaxIteratorTypes(
|
||||
rank, utils::IteratorType::parallel);
|
||||
|
||||
Value emptyMask = rewriter.create<tensor::EmptyOp>(
|
||||
loc, maskShape, maskTy.getElementType(), maskDynDims);
|
||||
Value newMask =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
|
||||
affineMaps, findMaxIteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
mask = newMask;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isa<Torch::NoneType>(scale.getType())) {
|
||||
double scaleFloat;
|
||||
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
|
||||
scaleFloat != 1.0)
|
||||
return rewriter.notifyMatchFailure(op.getLoc(),
|
||||
"only default scale supported");
|
||||
return rewriter.notifyMatchFailure(loc, "only default scale supported");
|
||||
}
|
||||
bool isGQAEnabled;
|
||||
if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
|
||||
isGQAEnabled)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op.getLoc(), "grouped query attention not supported");
|
||||
loc, "grouped query attention not supported");
|
||||
|
||||
if (queryTy.getRank() != valueTy.getRank() ||
|
||||
queryTy.getRank() != keyTy.getRank())
|
||||
|
@ -1753,7 +1811,6 @@ public:
|
|||
reassociation[1].push_back(valueTy.getRank() - 2);
|
||||
reassociation[2].push_back(valueTy.getRank() - 1);
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto collapseBatch = [&rewriter, &reassociation,
|
||||
loc](Value value) -> Value {
|
||||
auto valueTy = cast<ShapedType>(value.getType());
|
||||
|
@ -1788,13 +1845,12 @@ public:
|
|||
SmallVector<int64_t> valueSizes(
|
||||
cast<ShapedType>(value.getType()).getShape());
|
||||
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
||||
SmallVector<Value> outSizesDynamic(
|
||||
getTensorSizes(rewriter, op.getLoc(), query));
|
||||
SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
|
||||
outSizesDynamic[outSizesDynamic.size() - 1] =
|
||||
getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
|
||||
getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
|
||||
Type outType = RankedTensorType::get(outSizes, elementType);
|
||||
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
|
||||
elementType);
|
||||
Value output =
|
||||
createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);
|
||||
|
||||
SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
|
||||
|
||||
|
|
|
@ -5501,7 +5501,7 @@ class ScaledDotProductAttentionMaskModule(torch.nn.Module):
|
|||
([2, 3, 8, 16], torch.float32, True),
|
||||
([2, 3, 12, 16], torch.float32, True),
|
||||
([2, 3, 12, 20], torch.float32, True),
|
||||
([2, 3, 8, 12], torch.float32, True),
|
||||
([2, 1, 8, 12], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, query, key, value, mask):
|
||||
|
@ -5513,7 +5513,7 @@ def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
|
|||
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
|
||||
mask = torch.randn(2, 1, 8, 12, dtype=torch.float32)
|
||||
module.forward(query, key, value, mask)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue