mirror of https://github.com/llvm/torch-mlir
[linalg] Broadcast batch for mask on sdpa lowering (#3824)
Attention often broadcasts a mask across the batch dimensions, since the same masking is usually applied across attention heads. This change optionally materializes that broadcast into the mask dimensions during the linalg lowering.
parent 5aa323dd29
commit 25738b8c19
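For context (this snippet is not part of the commit), a minimal PyTorch sketch of the broadcast the lowering now materializes, using the shapes from the updated e2e test below: the mask's head dimension is 1 and is applied identically across the 3 attention heads.

import torch
import torch.nn.functional as F

query = torch.randn(2, 3, 8, 16)
key = torch.randn(2, 3, 12, 16)
value = torch.randn(2, 3, 12, 20)
mask = torch.randn(2, 1, 8, 12)  # head dim is 1, broadcast across the 3 heads

# PyTorch broadcasts the additive mask implicitly; the linalg lowering now
# performs the equivalent expansion explicitly before applying the mask.
implicit = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
explicit = F.scaled_dot_product_attention(
    query, key, value, attn_mask=mask.expand(2, 3, 8, 12)
)
torch.testing.assert_close(implicit, explicit)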
@@ -1661,6 +1661,7 @@ public:
     auto valueTy = cast<ShapedType>(value.getType());
     auto keyTy = cast<ShapedType>(key.getType());
 
+    auto loc = op.getLoc();
     Value dropoutP = op.getDropoutP();
     Value isCausal = op.getIsCausal();
     Value scale = op.getScale();
@@ -1671,13 +1672,13 @@ public:
     double dropout;
     if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
         dropout > 0.0)
-      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
+      return rewriter.notifyMatchFailure(loc, "dropout not supported");
 
     bool causal;
     if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
       if (!isa<Torch::NoneType>(mask.getType())) {
         return rewriter.notifyMatchFailure(
-            op.getLoc(), "expected no attention mask when isCausal is true");
+            loc, "expected no attention mask when isCausal is true");
       }
 
       SmallVector<int64_t> maskStatic;
@@ -1685,35 +1686,32 @@ public:
       for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
         maskStatic.push_back(queryTy.getDimSize(i));
         if (maskStatic.back() == ShapedType::kDynamic)
-          maskDyn.push_back(
-              rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
+          maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
       }
 
       maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
       if (maskStatic.back() == ShapedType::kDynamic)
-        maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
-                                                         keyTy.getRank() - 2));
+        maskDyn.push_back(
+            rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));
 
       Type maskType = getElementTypeOrSelf(queryTy);
-      Value emptyMask = rewriter.create<tensor::EmptyOp>(
-          op.getLoc(), maskStatic, maskType, maskDyn);
+      Value emptyMask =
+          rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
 
       Value zero = rewriter.create<arith::ConstantOp>(
-          op.getLoc(),
-          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
+          loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
       Value negInf = rewriter.create<arith::ConstantOp>(
-          op.getLoc(),
+          loc,
           rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
 
-      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
-                 .getResult(0);
+      mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);
 
       int64_t rank = cast<ShapedType>(queryTy).getRank();
       AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
       SmallVector<utils::IteratorType> iteratorTypes(
           rank, utils::IteratorType::parallel);
       auto genericOp = rewriter.create<linalg::GenericOp>(
-          op.getLoc(), mask.getType(), ValueRange{}, mask,
+          loc, mask.getType(), ValueRange{}, mask,
           SmallVector<AffineMap>{maskMap}, iteratorTypes,
           [&](OpBuilder &b, Location loc, ValueRange args) {
             Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
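The hunk above only renames op.getLoc() to loc around the causal-mask materialization: when is_causal is set, a [batch..., L, S] tensor is filled with 0 and then rewritten by a linalg.generic driven by linalg.index ops (the generic's body is truncated in this hunk). Assuming the standard causal convention of -inf wherever the key position is ahead of the query position, a rough NumPy model looks like this (illustrative only, not code from the patch):

import numpy as np

L, S = 8, 12  # query and key sequence lengths from the e2e test
i = np.arange(L)[:, None]  # plays the role of linalg.index(rank - 2)
j = np.arange(S)[None, :]  # plays the role of linalg.index(rank - 1)
causal_mask = np.where(j > i, -np.inf, 0.0)  # 0 on/below the diagonal, -inf above
# The leading batch/head dimensions of the materialized mask repeat this pattern.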
@@ -1727,18 +1725,78 @@ public:
       mask = genericOp.getResult(0);
     }
 
+    // Broadcast the batch dimensions of the mask:
+    if (!isa<Torch::NoneType>(mask.getType())) {
+      auto maskTy = cast<RankedTensorType>(mask.getType());
+      int64_t rank = maskTy.getRank();
+      bool needsBroadcast = false;
+      for (int i = 0, s = rank - 2; i < s; ++i) {
+        needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
+      }
+
+      if (needsBroadcast) {
+        SmallVector<int64_t> maskShape;
+        SmallVector<Value> maskDynDims;
+
+        SmallVector<AffineExpr> maskExprs;
+        for (int i = 0, s = rank - 2; i < s; ++i) {
+          maskShape.push_back(keyTy.getDimSize(i));
+
+          if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
+            maskExprs.push_back(rewriter.getAffineConstantExpr(0));
+          } else {
+            maskExprs.push_back(rewriter.getAffineDimExpr(i));
+          }
+
+          if (keyTy.isDynamicDim(i)) {
+            maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
+          }
+        }
+
+        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
+        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
+        maskShape.push_back(maskTy.getDimSize(rank - 2));
+        maskShape.push_back(maskTy.getDimSize(rank - 1));
+        if (maskTy.isDynamicDim(rank - 2))
+          maskDynDims.push_back(
+              rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
+        if (maskTy.isDynamicDim(rank - 1))
+          maskDynDims.push_back(
+              rewriter.create<tensor::DimOp>(loc, mask, rank - 1));
+
+        SmallVector<AffineMap> affineMaps = {
+            AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
+                           op.getContext()),
+            rewriter.getMultiDimIdentityMap(rank)};
+        SmallVector<utils::IteratorType> findMaxIteratorTypes(
+            rank, utils::IteratorType::parallel);
+
+        Value emptyMask = rewriter.create<tensor::EmptyOp>(
+            loc, maskShape, maskTy.getElementType(), maskDynDims);
+        Value newMask =
+            rewriter
+                .create<linalg::GenericOp>(
+                    loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
+                    affineMaps, findMaxIteratorTypes,
+                    [&](OpBuilder &b, Location loc, ValueRange args) {
+                      b.create<linalg::YieldOp>(loc, args[0]);
+                    })
+                .getResult(0);
+        mask = newMask;
+      }
+    }
+
     if (!isa<Torch::NoneType>(scale.getType())) {
       double scaleFloat;
       if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
           scaleFloat != 1.0)
-        return rewriter.notifyMatchFailure(op.getLoc(),
-                                           "only default scale supported");
+        return rewriter.notifyMatchFailure(loc, "only default scale supported");
     }
     bool isGQAEnabled;
     if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
         isGQAEnabled)
       return rewriter.notifyMatchFailure(
-          op.getLoc(), "grouped query attention not supported");
+          loc, "grouped query attention not supported");
 
     if (queryTy.getRank() != valueTy.getRank() ||
         queryTy.getRank() != keyTy.getRank())
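The added block above is the substance of this change: for each leading mask dimension that disagrees with the key's (in practice a size-1 dimension), the indexing map uses the affine constant 0, so the same mask slice is copied across the full batch/head extent; matching dimensions and the trailing two dimensions use identity indexing. A rough NumPy model of that indexing map follows; the helper name broadcast_mask_batch is illustrative, not from the patch.

import numpy as np

def broadcast_mask_batch(mask, key_batch_dims):
    # Leading dims that differ from the key's are read at index 0 (the affine
    # constant 0 expr); matching dims and the last two dims use identity maps.
    out_shape = tuple(key_batch_dims) + mask.shape[-2:]
    out = np.empty(out_shape, dtype=mask.dtype)
    for idx in np.ndindex(*out_shape):
        src = tuple(
            0 if mask.shape[d] != out_shape[d] else idx[d]
            for d in range(len(out_shape) - 2)
        ) + idx[-2:]
        out[idx] = mask[src]
    return out

mask = np.random.randn(2, 1, 8, 12)            # e2e test mask shape
expanded = broadcast_mask_batch(mask, (2, 3))  # key's leading (batch, head) dims
assert expanded.shape == (2, 3, 8, 12)
assert np.array_equal(expanded, np.broadcast_to(mask, (2, 3, 8, 12)))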
@@ -1753,7 +1811,6 @@ public:
     reassociation[1].push_back(valueTy.getRank() - 2);
     reassociation[2].push_back(valueTy.getRank() - 1);
 
-    auto loc = op.getLoc();
     auto collapseBatch = [&rewriter, &reassociation,
                           loc](Value value) -> Value {
       auto valueTy = cast<ShapedType>(value.getType());
@@ -1788,13 +1845,12 @@ public:
     SmallVector<int64_t> valueSizes(
         cast<ShapedType>(value.getType()).getShape());
     outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
-    SmallVector<Value> outSizesDynamic(
-        getTensorSizes(rewriter, op.getLoc(), query));
+    SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
     outSizesDynamic[outSizesDynamic.size() - 1] =
-        getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
+        getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
     Type outType = RankedTensorType::get(outSizes, elementType);
-    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
-                                        elementType);
+    Value output =
+        createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);
 
     SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
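As the hunk above shows, the attention output keeps the query's sizes except for the last dimension, which is taken from the value tensor; with the e2e test shapes below, a quick sanity check (illustrative, not part of the patch):

query_shape, value_shape = (2, 3, 8, 16), (2, 3, 12, 20)
out_shape = query_shape[:-1] + (value_shape[-1],)
assert out_shape == (2, 3, 8, 20)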
@@ -5501,7 +5501,7 @@ class ScaledDotProductAttentionMaskModule(torch.nn.Module):
             ([2, 3, 8, 16], torch.float32, True),
             ([2, 3, 12, 16], torch.float32, True),
             ([2, 3, 12, 20], torch.float32, True),
-            ([2, 3, 8, 12], torch.float32, True),
+            ([2, 1, 8, 12], torch.float32, True),
         ]
     )
     def forward(self, query, key, value, mask):
@@ -5513,7 +5513,7 @@ def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
     query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
     key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
     value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
-    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
+    mask = torch.randn(2, 1, 8, 12, dtype=torch.float32)
     module.forward(query, key, value, mask)