mirror of https://github.com/llvm/torch-mlir
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product attention + relevant comments + tests. The tests added highlight the new capabilities introduced in this PR, including: Attention with F16 mask Attention with Boolean mask Causal attention with same Q K V shapes Causal attention without Q K V shapes Made sure that one cannot input both mask and is_causal.pull/3654/merge
parent
0a788e0467
commit
e86f56bc76
|
@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
["generateScalarImplementation"]>]> {
|
["generateScalarImplementation"]>]> {
|
||||||
let summary = "Attention operator";
|
let summary = "Attention operator";
|
||||||
let description = [{
|
let description = [{
|
||||||
This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
|
This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
|
||||||
the attention. Each of the inputs has shape BxNxd where B is the
|
optional mask(M) to compute the attention. These tensors must take on shapes
|
||||||
of the batch dimension, N is the sequence length and d is head dimension.
|
BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
|
||||||
Typically N >>> d. Mathematically, the attention is defined as
|
shapes, B represents the batch dimension, M represents sequence length, N
|
||||||
matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
|
represents head dimension, and K1 and K2 are hidden dimensions.
|
||||||
this operator also performs scaling, masking and dropout, but we leave
|
Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
|
||||||
that out of the current implementation.
|
has shape BxMxN. Usually, this operator also performs scaling, masking and
|
||||||
|
dropout, but we leave that out of the current implementation.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins Variadic<AnyShaped>:$inputs,
|
let arguments = (ins Variadic<AnyShaped>:$inputs,
|
||||||
|
@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
Value getValue() {
|
Value getValue() {
|
||||||
return getInputOperand(2)->get();
|
return getInputOperand(2)->get();
|
||||||
}
|
}
|
||||||
|
std::optional<Value> getAttnMask() {
|
||||||
|
if (getNumInputs() < 4) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
return getInputOperand(3)->get();
|
||||||
|
}
|
||||||
Value getOutput() {
|
Value getOutput() {
|
||||||
return getOutputOperand(0)->get();
|
return getOutputOperand(0)->get();
|
||||||
}
|
}
|
||||||
|
@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
ShapedType getValueType() {
|
ShapedType getValueType() {
|
||||||
return cast<ShapedType>(getValue().getType());
|
return cast<ShapedType>(getValue().getType());
|
||||||
}
|
}
|
||||||
|
std::optional<ShapedType> getAttnMaskType() {
|
||||||
|
if (getAttnMask()){
|
||||||
|
return cast<ShapedType>((*getAttnMask()).getType());
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
ShapedType getOutputType() {
|
ShapedType getOutputType() {
|
||||||
return cast<ShapedType>(getOutput().getType());
|
return cast<ShapedType>(getOutput().getType());
|
||||||
}
|
}
|
||||||
|
@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
|
||||||
int64_t getValueRank() {
|
int64_t getValueRank() {
|
||||||
return getValueType().getRank();
|
return getValueType().getRank();
|
||||||
}
|
}
|
||||||
|
std::optional<int64_t> getAttnMaskRank() {
|
||||||
|
if (getAttnMask()){
|
||||||
|
return (*getAttnMaskType()).getRank();
|
||||||
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
int64_t getOutputRank() {
|
int64_t getOutputRank() {
|
||||||
return getOutputType().getRank();
|
return getOutputType().getRank();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1578,7 +1578,16 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
|
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value mask = op.getAttnMask();
|
|
||||||
|
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
|
||||||
|
auto query = adaptor.getQuery();
|
||||||
|
auto value = adaptor.getValue();
|
||||||
|
auto key = adaptor.getKey();
|
||||||
|
auto mask = adaptor.getAttnMask();
|
||||||
|
auto queryTy = cast<ShapedType>(query.getType());
|
||||||
|
auto valueTy = cast<ShapedType>(value.getType());
|
||||||
|
auto keyTy = cast<ShapedType>(key.getType());
|
||||||
|
|
||||||
Value dropoutP = op.getDropoutP();
|
Value dropoutP = op.getDropoutP();
|
||||||
Value isCausal = op.getIsCausal();
|
Value isCausal = op.getIsCausal();
|
||||||
Value scale = op.getScale();
|
Value scale = op.getScale();
|
||||||
|
@ -1586,18 +1595,77 @@ public:
|
||||||
Type elementType =
|
Type elementType =
|
||||||
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
|
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
|
||||||
|
|
||||||
// Verify inputs (only support defaults)
|
|
||||||
if (!isa<Torch::NoneType>(mask.getType()))
|
|
||||||
return rewriter.notifyMatchFailure(op.getLoc(),
|
|
||||||
"attention masking not supported");
|
|
||||||
double dropout;
|
double dropout;
|
||||||
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
|
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
|
||||||
dropout > 0.0)
|
dropout > 0.0)
|
||||||
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
|
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
|
||||||
|
|
||||||
bool causal;
|
bool causal;
|
||||||
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
|
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
|
||||||
return rewriter.notifyMatchFailure(
|
if (!isa<Torch::NoneType>(mask.getType())) {
|
||||||
op.getLoc(), "causal attention masking not supported");
|
return rewriter.notifyMatchFailure(
|
||||||
|
op.getLoc(), "expected no attention mask when isCausal is true");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> maskSizes;
|
||||||
|
|
||||||
|
if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
|
||||||
|
auto seqLenQ =
|
||||||
|
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
|
||||||
|
auto seqLenK =
|
||||||
|
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
|
||||||
|
maskSizes = {seqLenQ, seqLenK};
|
||||||
|
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
|
||||||
|
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
|
||||||
|
maskSizes.insert(maskSizes.begin(), batchSize);
|
||||||
|
}
|
||||||
|
} else { // Dynamic shape case: <?x?x...x?xf32> for example
|
||||||
|
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
|
||||||
|
Value batchSize =
|
||||||
|
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
|
||||||
|
maskSizes.push_back(batchSize);
|
||||||
|
}
|
||||||
|
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
|
||||||
|
queryTy.getRank() - 2);
|
||||||
|
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
|
||||||
|
keyTy.getRank() - 2);
|
||||||
|
maskSizes.push_back(seqLenQ);
|
||||||
|
maskSizes.push_back(seqLenK);
|
||||||
|
}
|
||||||
|
|
||||||
|
Type maskType = getElementTypeOrSelf(queryTy);
|
||||||
|
Value emptyMask =
|
||||||
|
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
|
||||||
|
|
||||||
|
Value zero = rewriter.create<arith::ConstantOp>(
|
||||||
|
op.getLoc(),
|
||||||
|
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
|
||||||
|
Value negInf = rewriter.create<arith::ConstantOp>(
|
||||||
|
op.getLoc(),
|
||||||
|
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
|
||||||
|
|
||||||
|
mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
|
||||||
|
.getResult(0);
|
||||||
|
|
||||||
|
int64_t rank = cast<ShapedType>(queryTy).getRank();
|
||||||
|
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
|
||||||
|
SmallVector<utils::IteratorType> iteratorTypes(
|
||||||
|
rank, utils::IteratorType::parallel);
|
||||||
|
auto genericOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
op.getLoc(), mask.getType(), ValueRange{}, mask,
|
||||||
|
SmallVector<AffineMap>{maskMap}, iteratorTypes,
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
|
||||||
|
Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);
|
||||||
|
|
||||||
|
Value cond =
|
||||||
|
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
|
||||||
|
Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
|
||||||
|
b.create<linalg::YieldOp>(loc, select);
|
||||||
|
});
|
||||||
|
mask = genericOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
if (!isa<Torch::NoneType>(scale.getType())) {
|
if (!isa<Torch::NoneType>(scale.getType())) {
|
||||||
double scaleFloat;
|
double scaleFloat;
|
||||||
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
|
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
|
||||||
|
@ -1611,14 +1679,6 @@ public:
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op.getLoc(), "grouped query attention not supported");
|
op.getLoc(), "grouped query attention not supported");
|
||||||
|
|
||||||
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
|
|
||||||
auto query = adaptor.getQuery();
|
|
||||||
auto value = adaptor.getValue();
|
|
||||||
auto key = adaptor.getKey();
|
|
||||||
auto queryTy = cast<ShapedType>(query.getType());
|
|
||||||
auto valueTy = cast<ShapedType>(value.getType());
|
|
||||||
auto keyTy = cast<ShapedType>(key.getType());
|
|
||||||
|
|
||||||
if (queryTy.getRank() != valueTy.getRank() ||
|
if (queryTy.getRank() != valueTy.getRank() ||
|
||||||
queryTy.getRank() != keyTy.getRank())
|
queryTy.getRank() != keyTy.getRank())
|
||||||
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
|
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
|
||||||
|
@ -1659,6 +1719,9 @@ public:
|
||||||
query = collapseBatch(query);
|
query = collapseBatch(query);
|
||||||
key = collapseBatch(key);
|
key = collapseBatch(key);
|
||||||
value = collapseBatch(value);
|
value = collapseBatch(value);
|
||||||
|
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
|
||||||
|
mask = collapseBatch(mask);
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
|
SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
|
||||||
SmallVector<int64_t> valueSizes(
|
SmallVector<int64_t> valueSizes(
|
||||||
|
@ -1672,13 +1735,17 @@ public:
|
||||||
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
|
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
|
||||||
elementType);
|
elementType);
|
||||||
|
|
||||||
|
SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
|
||||||
|
|
||||||
|
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
|
||||||
|
inputs.push_back(mask);
|
||||||
|
}
|
||||||
|
|
||||||
// Overwrite with tm_tensor::attention
|
// Overwrite with tm_tensor::attention
|
||||||
Value attention =
|
Value attention = rewriter
|
||||||
rewriter
|
.create<AttentionOp>(loc, outType, inputs,
|
||||||
.create<AttentionOp>(loc, outType,
|
SmallVector<Value>{output})
|
||||||
SmallVector<Value>{query, key, value},
|
.getResult()[0];
|
||||||
SmallVector<Value>{output})
|
|
||||||
.getResult()[0];
|
|
||||||
|
|
||||||
if (opTy != outType) {
|
if (opTy != outType) {
|
||||||
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
|
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
|
||||||
|
|
|
@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
|
||||||
Operation *op = getOperation();
|
Operation *op = getOperation();
|
||||||
ShapedType queryType = getQueryType();
|
ShapedType queryType = getQueryType();
|
||||||
ShapedType keyType = getKeyType();
|
ShapedType keyType = getKeyType();
|
||||||
|
ShapedType valueType = getValueType();
|
||||||
|
|
||||||
|
auto optionalMaskType = getAttnMaskType();
|
||||||
|
ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();
|
||||||
|
|
||||||
ArrayRef<int64_t> queryShape = queryType.getShape();
|
ArrayRef<int64_t> queryShape = queryType.getShape();
|
||||||
ArrayRef<int64_t> keyShape = keyType.getShape();
|
ArrayRef<int64_t> keyShape = keyType.getShape();
|
||||||
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
||||||
|
ArrayRef<int64_t> maskShape =
|
||||||
|
optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();
|
||||||
|
|
||||||
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
|
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
|
||||||
if (keyShape[i] != queryShape[i])
|
if (keyShape[i] != queryShape[i]) {
|
||||||
return op->emitOpError("query and key batch mismatch");
|
return op->emitOpError("query and key batch mismatch");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (keyShape.back() != queryShape.back())
|
if (keyShape.back() != queryShape.back()) {
|
||||||
return op->emitOpError("query and key head dimension mismatch");
|
return op->emitOpError("query and key head dimension mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
|
||||||
|
if (valueShape[i] != queryShape[i]) {
|
||||||
|
return op->emitOpError("query and value batch dimension mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
|
||||||
|
return op->emitOpError("key and value sequence length dimension mismatch");
|
||||||
|
}
|
||||||
|
if (optionalMaskType) {
|
||||||
|
for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
|
||||||
|
if (maskShape[i] != queryShape[i]) {
|
||||||
|
return op->emitOpError("query and mask batch dimension mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
|
||||||
|
return op->emitOpError(
|
||||||
|
"mask sequence length and query sequence length mismatch");
|
||||||
|
}
|
||||||
|
if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
|
||||||
|
return op->emitOpError(
|
||||||
|
"mask sequence lengt and key sequence length mismatch");
|
||||||
|
}
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
Value query = getQuery();
|
Value query = getQuery();
|
||||||
Value key = getKey();
|
Value key = getKey();
|
||||||
Value value = getValue();
|
Value value = getValue();
|
||||||
|
|
||||||
|
auto optionalMask = getAttnMask();
|
||||||
|
Value mask = optionalMask ? *optionalMask : Value();
|
||||||
|
|
||||||
Value output = getOutput();
|
Value output = getOutput();
|
||||||
auto queryType = cast<MemRefType>(query.getType());
|
auto queryType = cast<MemRefType>(query.getType());
|
||||||
auto keyType = cast<MemRefType>(key.getType());
|
auto keyType = cast<MemRefType>(key.getType());
|
||||||
auto valueType = cast<MemRefType>(value.getType());
|
auto valueType = cast<MemRefType>(value.getType());
|
||||||
|
auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
|
||||||
auto queryRank = queryType.getRank();
|
auto queryRank = queryType.getRank();
|
||||||
auto keyRank = keyType.getRank();
|
auto keyRank = keyType.getRank();
|
||||||
auto valueRank = valueType.getRank();
|
auto valueRank = valueType.getRank();
|
||||||
|
@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
|
|
||||||
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
|
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
|
||||||
b.getFloatAttr(elementType, 0.0));
|
b.getFloatAttr(elementType, 0.0));
|
||||||
|
Value negInfF = b.create<arith::ConstantOp>(
|
||||||
|
loc, elementType,
|
||||||
|
b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));
|
||||||
|
|
||||||
// TODO: This needs to be fixed, it assumes everything is dynamic however if
|
// TODO: This needs to be fixed, it assumes everything is dynamic however if
|
||||||
// any shapes are static the `memref.alloc` generated is illegal.
|
// any shapes are static the `memref.alloc` generated is illegal.
|
||||||
|
@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
/*transposed=*/true);
|
/*transposed=*/true);
|
||||||
|
|
||||||
// weight = softmax(weight)
|
// weight = softmax(weight)
|
||||||
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
|
||||||
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
|
||||||
Value dim = weightDynSizes[weightRank - 1];
|
Value dim = weightDynSizes[weightRank - 1];
|
||||||
Value scaleFactor = b.create<math::SqrtOp>(
|
Value scaleFactor = b.create<math::SqrtOp>(
|
||||||
loc, b.create<arith::UIToFPOp>(
|
loc, b.create<arith::UIToFPOp>(
|
||||||
loc, elementType,
|
loc, elementType,
|
||||||
b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
|
b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
|
||||||
queryDynSizes[queryRank - 1])));
|
queryDynSizes[queryRank - 1])));
|
||||||
|
|
||||||
|
// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
|
||||||
|
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
b.create<scf::ParallelOp>(
|
||||||
|
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
|
||||||
|
SmallVector<Value>(weightRank, one),
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
|
||||||
|
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
|
||||||
|
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
|
||||||
|
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Apply mask to weights if mask is given
|
||||||
|
if (mask) {
|
||||||
|
b.create<scf::ParallelOp>(
|
||||||
|
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
|
||||||
|
SmallVector<Value>(weightRank, one),
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
|
||||||
|
Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
|
||||||
|
Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
|
||||||
|
if (maskType.getElementType().isInteger(1)) {
|
||||||
|
maskValue =
|
||||||
|
b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
|
||||||
|
}
|
||||||
|
Value maskedWeight =
|
||||||
|
b.create<arith::AddFOp>(loc, weightValue, maskValue);
|
||||||
|
b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// calculate max(weight)
|
// calculate max(weight)
|
||||||
Value init = b.create<memref::LoadOp>(loc, weight,
|
Value init = b.create<memref::LoadOp>(loc, weight,
|
||||||
SmallVector<Value>(weightRank, zero));
|
SmallVector<Value>(weightRank, zero));
|
||||||
|
@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
|
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
|
||||||
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
|
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
|
||||||
x = b.create<arith::SubFOp>(loc, x, globalMax);
|
x = b.create<arith::SubFOp>(loc, x, globalMax);
|
||||||
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
|
|
||||||
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
||||||
});
|
});
|
||||||
// calculate exp(weight)
|
// calculate exp(weight)
|
||||||
|
@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
|
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
|
||||||
SmallVector<Value> sumIVs(localIVs);
|
SmallVector<Value> sumIVs(localIVs);
|
||||||
sumIVs.pop_back();
|
sumIVs.pop_back();
|
||||||
|
|
||||||
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
|
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
|
||||||
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
|
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
|
||||||
x = b.create<arith::DivFOp>(loc, x, sum);
|
Value divResult = b.create<arith::DivFOp>(loc, x, sum);
|
||||||
b.create<memref::StoreOp>(loc, x, weight, localIVs);
|
|
||||||
|
// Set to 0 if sum is 0 (can occur during boolean mask / large negative
|
||||||
|
// QK)
|
||||||
|
Value isSumZero =
|
||||||
|
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
|
||||||
|
Value result =
|
||||||
|
b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);
|
||||||
|
|
||||||
|
b.create<memref::StoreOp>(loc, result, weight, localIVs);
|
||||||
});
|
});
|
||||||
|
|
||||||
// output = weight @ value
|
// output = weight @ value
|
||||||
|
|
|
@ -34,7 +34,13 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||||
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
||||||
LINALG_XFAIL_SET = LINALG_XFAIL_SET | {
|
LINALG_XFAIL_SET = LINALG_XFAIL_SET | {
|
||||||
# Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
|
# Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
|
||||||
|
# WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
|
||||||
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -498,7 +504,13 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"WeightNormInterfaceModule_basic",
|
"WeightNormInterfaceModule_basic",
|
||||||
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -780,6 +792,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
"ScatterReduceFloatMaxModule",
|
"ScatterReduceFloatMaxModule",
|
||||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||||
"ScatterReduceFloatMeanModule",
|
"ScatterReduceFloatMeanModule",
|
||||||
|
@ -2179,6 +2199,8 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
||||||
MAKE_FX_TOSA_PASS_SET = MAKE_FX_TOSA_PASS_SET | {
|
MAKE_FX_TOSA_PASS_SET = MAKE_FX_TOSA_PASS_SET | {
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
LTC_CRASHING_SET = {
|
LTC_CRASHING_SET = {
|
||||||
|
@ -2932,6 +2954,12 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
||||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if torch_version_for_comparison() >= version.parse("2.5.0.dev"):
|
||||||
|
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
|
||||||
|
# ERROR: value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=+nan, max=+nan, mean=+nan) is not close to golden value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=-2.394, max=+2.454, mean=-0.02828)
|
||||||
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
}
|
||||||
|
|
||||||
if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
if torch_version_for_comparison() < version.parse("2.4.0.dev"):
|
||||||
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
|
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
|
@ -3009,8 +3037,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ReduceAminmaxSingleDim_basic",
|
"ReduceAminmaxSingleDim_basic",
|
||||||
"ReduceAnyDimFloatModule_basic",
|
"ReduceAnyDimFloatModule_basic",
|
||||||
"RenormModuleFloat16_basic",
|
"RenormModuleFloat16_basic",
|
||||||
"ScaledDotProductAttentionDifferentModule_basic",
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScatterAddStaticModule_basic",
|
"ScatterAddStaticModule_basic",
|
||||||
"TensorsConcatComplex128FloatModule_basic",
|
"TensorsConcatComplex128FloatModule_basic",
|
||||||
"TensorsConcatComplex128IntModule_basic",
|
"TensorsConcatComplex128IntModule_basic",
|
||||||
|
@ -4548,7 +4579,11 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ScalarConstantTupleModule_basic",
|
"ScalarConstantTupleModule_basic",
|
||||||
"ScalarImplicitFloatModule_basic",
|
"ScalarImplicitFloatModule_basic",
|
||||||
"ScalarImplicitIntModule_basic",
|
"ScalarImplicitIntModule_basic",
|
||||||
"ScaledDotProductAttentionSameModule_basic",
|
# REMOVE WHEN ENABLE_GQA IS ADDED
|
||||||
|
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||||
|
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameCausalModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameDynamicModule_basic",
|
||||||
"ScatterReduceFloatMaxModule",
|
"ScatterReduceFloatMaxModule",
|
||||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||||
"ScatterReduceFloatMeanModule",
|
"ScatterReduceFloatMeanModule",
|
||||||
|
|
|
@ -5107,9 +5107,9 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
|
||||||
@annotate_args(
|
@annotate_args(
|
||||||
[
|
[
|
||||||
None,
|
None,
|
||||||
([-1, -1, -1], torch.float32, True),
|
([1, 5, 5], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([1, 5, 5], torch.float32, True),
|
||||||
([-1, -1, -1], torch.float32, True),
|
([1, 5, 5], torch.float32, True),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, query, key, value):
|
def forward(self, query, key, value):
|
||||||
|
@ -5124,6 +5124,58 @@ def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
|
||||||
module.forward(query, key, value)
|
module.forward(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttentionSameDynamicModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameDynamicModule())
|
||||||
|
def ScaledDotProductAttentionSameDynamicModule_basic(module, tu: TestUtils):
|
||||||
|
query = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
|
key = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
|
value = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
|
module.forward(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttentionSameCausalModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
return torch.ops.aten.scaled_dot_product_attention(
|
||||||
|
query, key, value, is_causal=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameCausalModule())
|
||||||
|
def ScaledDotProductAttentionSameCausalModule_basic(module, tu: TestUtils):
|
||||||
|
query = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
|
key = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
|
value = torch.randn(1, 5, 5, dtype=torch.float32)
|
||||||
|
module.forward(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -5132,9 +5184,9 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
||||||
@annotate_args(
|
@annotate_args(
|
||||||
[
|
[
|
||||||
None,
|
None,
|
||||||
([2, 3, 8, 4], torch.float32, True),
|
([2, 3, 8, 16], torch.float32, True),
|
||||||
([2, 3, 16, 4], torch.float32, True),
|
([2, 3, 12, 16], torch.float32, True),
|
||||||
([2, 3, 16, 4], torch.float32, True),
|
([2, 3, 12, 20], torch.float32, True),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, query, key, value):
|
def forward(self, query, key, value):
|
||||||
|
@ -5143,12 +5195,95 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
|
||||||
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
|
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
|
||||||
query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
|
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||||
key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
|
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||||
value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
|
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||||
module.forward(query, key, value)
|
module.forward(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 3, 8, 16], torch.float32, True),
|
||||||
|
([2, 3, 12, 16], torch.float32, True),
|
||||||
|
([2, 3, 12, 20], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
return torch.ops.aten.scaled_dot_product_attention(
|
||||||
|
query, key, value, is_causal=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
|
||||||
|
)
|
||||||
|
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
|
||||||
|
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||||
|
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||||
|
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||||
|
module.forward(query, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttentionMaskModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 3, 8, 16], torch.float32, True),
|
||||||
|
([2, 3, 12, 16], torch.float32, True),
|
||||||
|
([2, 3, 12, 20], torch.float32, True),
|
||||||
|
([2, 3, 8, 12], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, query, key, value, mask):
|
||||||
|
return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionMaskModule())
|
||||||
|
def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
|
||||||
|
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||||
|
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||||
|
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||||
|
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
|
||||||
|
module.forward(query, key, value, mask)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledDotProductAttentionBoolMaskModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 3, 8, 16], torch.float32, True),
|
||||||
|
([2, 3, 12, 16], torch.float32, True),
|
||||||
|
([2, 3, 12, 20], torch.float32, True),
|
||||||
|
([2, 3, 8, 12], torch.bool, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, query, key, value, mask):
|
||||||
|
return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ScaledDotProductAttentionBoolMaskModule())
|
||||||
|
def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
|
||||||
|
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
|
||||||
|
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
|
||||||
|
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
|
||||||
|
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) > 0.5
|
||||||
|
module.forward(query, key, value, mask)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue