[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)

Enabled the mask and is_causal parameters for torch.aten.scaled_dot_product_attention, along with relevant comments and tests.

The tests added highlight the new capabilities introduced in this PR, including:

- Attention with an F16 mask
- Attention with a Boolean mask
- Causal attention with the same Q, K, V shapes
- Causal attention with different Q, K, V shapes

Also made sure that one cannot supply both a mask and is_causal=True; the short sketch below illustrates the calls these paths correspond to.
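
For reference, the following plain-PyTorch sketch mirrors the paths these tests exercise (example shapes only, not part of the diff):

import torch

q = torch.randn(2, 3, 8, 16)   # batch x heads x seq_q x head_dim
k = torch.randn(2, 3, 12, 16)  # batch x heads x seq_k x head_dim
v = torch.randn(2, 3, 12, 20)  # value head dim may differ from query/key

float_mask = torch.randn(2, 3, 8, 12)        # additive float mask (seq_q x seq_k)
bool_mask = torch.randn(2, 3, 8, 12) > 0.5   # boolean mask: True means "attend"

out_float = torch.ops.aten.scaled_dot_product_attention(q, k, v, float_mask)
out_bool = torch.ops.aten.scaled_dot_product_attention(q, k, v, bool_mask)
out_causal = torch.ops.aten.scaled_dot_product_attention(q, k, v, is_causal=True)

# Supplying both a mask and is_causal=True is rejected by this lowering.
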
rohan-tan-bhowmik 2024-09-09 15:51:41 -07:00 committed by GitHub
parent 0a788e0467
commit e86f56bc76
5 changed files with 384 additions and 48 deletions


@@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
       ["generateScalarImplementation"]>]> {
   let summary = "Attention operator";
   let description = [{
-    This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
-    the attention. Each of the inputs has shape BxNxd where B is the
-    of the batch dimension, N is the sequence length and d is head dimension.
-    Typically N >>> d. Mathematically, the attention is defined as
-    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
-    this operator also performs scaling, masking and dropout, but we leave
-    that out of the current implementation.
+    This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
+    optional mask(M) to compute the attention. These tensors must take on shapes
+    BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
+    shapes, B represents the batch dimension, M represents sequence length, N
+    represents head dimension, and K1 and K2 are hidden dimensions.
+    Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
+    has shape BxMxN. Usually, this operator also performs scaling, masking and
+    dropout, but we leave that out of the current implementation.
   }];

   let arguments = (ins Variadic<AnyShaped>:$inputs,
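
As a reference for the updated description, the computation it specifies looks like the following plain-PyTorch sketch (assumed example sizes; the tm_tensor scalar implementation additionally divides the QK product by sqrt of the head dimension before the softmax):

import torch

B, M, K1, K2, N = 2, 8, 16, 12, 20       # assumed example sizes
Q = torch.randn(B, M, K1)                # query: BxMxK1
K = torch.randn(B, K2, K1)               # key:   BxK2xK1
V = torch.randn(B, K2, N)                # value: BxK2xN
attn_mask = torch.randn(B, M, K2)        # mask:  BxMxK2 (optional, additive)

weights = torch.softmax(Q @ K.transpose(-1, -2) + attn_mask, dim=-1)  # BxMxK2
out = weights @ V                                                     # BxMxN
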
@@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     Value getValue() {
       return getInputOperand(2)->get();
     }
+    std::optional<Value> getAttnMask() {
+      if (getNumInputs() < 4) {
+        return std::nullopt;
+      }
+      return getInputOperand(3)->get();
+    }
     Value getOutput() {
       return getOutputOperand(0)->get();
     }
@@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     ShapedType getValueType() {
      return cast<ShapedType>(getValue().getType());
     }
+    std::optional<ShapedType> getAttnMaskType() {
+      if (getAttnMask()){
+        return cast<ShapedType>((*getAttnMask()).getType());
+      }
+      return std::nullopt;
+    }
     ShapedType getOutputType() {
       return cast<ShapedType>(getOutput().getType());
     }
@@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     int64_t getValueRank() {
       return getValueType().getRank();
     }
+    std::optional<int64_t> getAttnMaskRank() {
+      if (getAttnMask()){
+        return (*getAttnMaskType()).getRank();
+      }
+      return std::nullopt;
+    }
     int64_t getOutputRank() {
       return getOutputType().getRank();
     }


@@ -1578,7 +1578,16 @@ public:
   LogicalResult
   matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Value mask = op.getAttnMask();
+    auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
+    auto query = adaptor.getQuery();
+    auto value = adaptor.getValue();
+    auto key = adaptor.getKey();
+    auto mask = adaptor.getAttnMask();
+    auto queryTy = cast<ShapedType>(query.getType());
+    auto valueTy = cast<ShapedType>(value.getType());
+    auto keyTy = cast<ShapedType>(key.getType());
+
     Value dropoutP = op.getDropoutP();
     Value isCausal = op.getIsCausal();
     Value scale = op.getScale();
@@ -1586,18 +1595,77 @@ public:
     Type elementType =
        cast<ShapedType>(adaptor.getQuery().getType()).getElementType();

-    // Verify inputs (only support defaults)
-    if (!isa<Torch::NoneType>(mask.getType()))
-      return rewriter.notifyMatchFailure(op.getLoc(),
-                                         "attention masking not supported");
     double dropout;
     if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
         dropout > 0.0)
       return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");

     bool causal;
-    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
-      return rewriter.notifyMatchFailure(
-          op.getLoc(), "causal attention masking not supported");
+    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
+      if (!isa<Torch::NoneType>(mask.getType())) {
+        return rewriter.notifyMatchFailure(
+            op.getLoc(), "expected no attention mask when isCausal is true");
+      }
+
+      SmallVector<OpFoldResult> maskSizes;
+
+      if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
+        auto seqLenQ =
+            rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
+        auto seqLenK =
+            rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
+        maskSizes = {seqLenQ, seqLenK};
+        for (int i = queryTy.getRank() - 3; i >= 0; --i) {
+          auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
+          maskSizes.insert(maskSizes.begin(), batchSize);
+        }
+      } else { // Dynamic shape case: <?x?x...x?xf32> for example
+        for (int i = 0; i < queryTy.getRank() - 2; ++i) {
+          Value batchSize =
+              rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
+          maskSizes.push_back(batchSize);
+        }
+        Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
+                                                       queryTy.getRank() - 2);
+        Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                       keyTy.getRank() - 2);
+        maskSizes.push_back(seqLenQ);
+        maskSizes.push_back(seqLenK);
+      }
+
+      Type maskType = getElementTypeOrSelf(queryTy);
+      Value emptyMask =
+          rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
+
+      Value zero = rewriter.create<arith::ConstantOp>(
+          op.getLoc(),
+          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
+
+      Value negInf = rewriter.create<arith::ConstantOp>(
+          op.getLoc(),
+          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
+
+      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
+                 .getResult(0);
+
+      int64_t rank = cast<ShapedType>(queryTy).getRank();
+      AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
+      SmallVector<utils::IteratorType> iteratorTypes(
+          rank, utils::IteratorType::parallel);
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          op.getLoc(), mask.getType(), ValueRange{}, mask,
+          SmallVector<AffineMap>{maskMap}, iteratorTypes,
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
+            Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);
+
+            Value cond =
+                b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
+            Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
+            b.create<linalg::YieldOp>(loc, select);
+          });
+      mask = genericOp.getResult(0);
+    }
+
     if (!isa<Torch::NoneType>(scale.getType())) {
       double scaleFloat;
       if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
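
When is_causal is set (and no explicit mask is given), the branch above no longer bails out; it materializes an additive causal mask whose entry (i, j) is 0 when the query position i is at or past the key position j, and -inf otherwise, then passes it to tm_tensor.attention like a user-provided mask. A rough Python equivalent, with batch dimensions elided (the lowering also covers dynamic shapes via tensor.dim):

import torch

seq_q, seq_k = 8, 12                     # example sizes
i = torch.arange(seq_q).unsqueeze(-1)    # query positions as a column
j = torch.arange(seq_k).unsqueeze(0)     # key positions as a row

zeros = torch.zeros(seq_q, seq_k)
neg_inf = torch.full((seq_q, seq_k), float("-inf"))
causal_mask = torch.where(i >= j, zeros, neg_inf)   # seq_q x seq_k, additive
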
@@ -1611,14 +1679,6 @@ public:
       return rewriter.notifyMatchFailure(
           op.getLoc(), "grouped query attention not supported");

-    auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
-    auto query = adaptor.getQuery();
-    auto value = adaptor.getValue();
-    auto key = adaptor.getKey();
-    auto queryTy = cast<ShapedType>(query.getType());
-    auto valueTy = cast<ShapedType>(value.getType());
-    auto keyTy = cast<ShapedType>(key.getType());
     if (queryTy.getRank() != valueTy.getRank() ||
         queryTy.getRank() != keyTy.getRank())
       return rewriter.notifyMatchFailure(op, "operand ranks do not match");
@@ -1659,6 +1719,9 @@ public:
     query = collapseBatch(query);
     key = collapseBatch(key);
     value = collapseBatch(value);
+    if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
+      mask = collapseBatch(mask);
+    }

     SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
     SmallVector<int64_t> valueSizes(
@@ -1672,13 +1735,17 @@ public:
     Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
                                         elementType);

+    SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
+    if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
+      inputs.push_back(mask);
+    }
+
     // Overwrite with tm_tensor::attention
-    Value attention =
-        rewriter
-            .create<AttentionOp>(loc, outType,
-                                 SmallVector<Value>{query, key, value},
-                                 SmallVector<Value>{output})
-            .getResult()[0];
+    Value attention = rewriter
+                          .create<AttentionOp>(loc, outType, inputs,
+                                               SmallVector<Value>{output})
+                          .getResult()[0];

     if (opTy != outType) {
       attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,


@@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
   Operation *op = getOperation();
   ShapedType queryType = getQueryType();
   ShapedType keyType = getKeyType();
+  ShapedType valueType = getValueType();
+  auto optionalMaskType = getAttnMaskType();
+  ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();
+
   ArrayRef<int64_t> queryShape = queryType.getShape();
   ArrayRef<int64_t> keyShape = keyType.getShape();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  ArrayRef<int64_t> maskShape =
+      optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();
+
   for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
-    if (keyShape[i] != queryShape[i])
+    if (keyShape[i] != queryShape[i]) {
       return op->emitOpError("query and key batch mismatch");
+    }
   }
-  if (keyShape.back() != queryShape.back())
+  if (keyShape.back() != queryShape.back()) {
     return op->emitOpError("query and key head dimension mismatch");
+  }
+  for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
+    if (valueShape[i] != queryShape[i]) {
+      return op->emitOpError("query and value batch dimension mismatch");
+    }
+  }
+  if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
+    return op->emitOpError("key and value sequence length dimension mismatch");
+  }
+  if (optionalMaskType) {
+    for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
+      if (maskShape[i] != queryShape[i]) {
+        return op->emitOpError("query and mask batch dimension mismatch");
+      }
+    }
+    if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
+      return op->emitOpError(
+          "mask sequence length and query sequence length mismatch");
+    }
+    if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
+      return op->emitOpError(
+          "mask sequence length and key sequence length mismatch");
+    }
+  }
   return success();
 }
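
Spelled out, the extended verifier enforces the following shape relations between Q, K, V, and the optional mask (an illustrative Python check, not the MLIR verifier itself; check_attention_shapes is a made-up helper name):

def check_attention_shapes(q, k, v, mask=None):
    # q: B x M x K1, k: B x K2 x K1, v: B x K2 x N, mask: B x M x K2
    *qb, m, k1 = q
    *kb, k2, k1_key = k
    *vb, k2_val, n = v
    assert kb == qb, "query and key batch mismatch"
    assert k1_key == k1, "query and key head dimension mismatch"
    assert vb == qb, "query and value batch dimension mismatch"
    assert k2_val == k2, "key and value sequence length dimension mismatch"
    if mask is not None:
        *mb, m_mask, k2_mask = mask
        assert mb == qb, "query and mask batch dimension mismatch"
        assert m_mask == m, "mask and query sequence length mismatch"
        assert k2_mask == k2, "mask and key sequence length mismatch"

check_attention_shapes((2, 8, 16), (2, 12, 16), (2, 12, 20), (2, 8, 12))
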
@@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
   Value query = getQuery();
   Value key = getKey();
   Value value = getValue();
+  auto optionalMask = getAttnMask();
+  Value mask = optionalMask ? *optionalMask : Value();
   Value output = getOutput();
   auto queryType = cast<MemRefType>(query.getType());
   auto keyType = cast<MemRefType>(key.getType());
   auto valueType = cast<MemRefType>(value.getType());
+  auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
   auto queryRank = queryType.getRank();
   auto keyRank = keyType.getRank();
   auto valueRank = valueType.getRank();
@@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
   Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                             b.getFloatAttr(elementType, 0.0));
+  Value negInfF = b.create<arith::ConstantOp>(
+      loc, elementType,
+      b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));

   // TODO: This needs to be fixed, it assumes everything is dynamic however if
   // any shapes are static the `memref.alloc` generated is illegal.
@@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                       /*transposed=*/true);

   // weight = softmax(weight)
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Value dim = weightDynSizes[weightRank - 1];
   Value scaleFactor = b.create<math::SqrtOp>(
       loc, b.create<arith::UIToFPOp>(
                loc, elementType,
                b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                               queryDynSizes[queryRank - 1])));
-  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
-  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+
+  b.create<scf::ParallelOp>(
+      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
+      SmallVector<Value>(weightRank, one),
+      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
+        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
+        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
+        b.create<memref::StoreOp>(loc, x, weight, localIVs);
+      });
+
+  // Apply mask to weights if mask is given
+  if (mask) {
+    b.create<scf::ParallelOp>(
+        loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
+        SmallVector<Value>(weightRank, one),
+        [&](OpBuilder &b, Location loc, ValueRange localIVs) {
+          Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
+          Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
+          if (maskType.getElementType().isInteger(1)) {
+            maskValue =
+                b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
+          }
+          Value maskedWeight =
+              b.create<arith::AddFOp>(loc, weightValue, maskValue);
+          b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
+        });
+  }
+
   // calculate max(weight)
   Value init = b.create<memref::LoadOp>(loc, weight,
                                         SmallVector<Value>(weightRank, zero));
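
Read as plain Python, the mask application added above amounts to the following (illustrative only; in the MLIR it runs as an scf.parallel loop over the already-scaled QK weights, using the zeroF/negInfF constants defined earlier):

import torch

def apply_mask(weights: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # A boolean mask is first converted elementwise to 0 (keep) / -inf (drop);
    # a float mask is used as-is. Either way it is added to the weights.
    if mask.dtype == torch.bool:
        mask = torch.where(mask, torch.zeros_like(weights),
                           torch.full_like(weights, float("-inf")))
    return weights + mask
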
@@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
       [&](OpBuilder &b, Location loc, ValueRange localIVs) {
         Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
         x = b.create<arith::SubFOp>(loc, x, globalMax);
-        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
         b.create<memref::StoreOp>(loc, x, weight, localIVs);
       });
   // calculate exp(weight)
@@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
       [&](OpBuilder &b, Location loc, ValueRange localIVs) {
         SmallVector<Value> sumIVs(localIVs);
         sumIVs.pop_back();

         Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
         Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
-        x = b.create<arith::DivFOp>(loc, x, sum);
-        b.create<memref::StoreOp>(loc, x, weight, localIVs);
+        Value divResult = b.create<arith::DivFOp>(loc, x, sum);
+
+        // Set to 0 if sum is 0 (can occur during boolean mask / large negative
+        // QK)
+        Value isSumZero =
+            b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
+        Value result =
+            b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);
+        b.create<memref::StoreOp>(loc, result, weight, localIVs);
       });

   // output = weight @ value
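
The adjusted normalization roughly corresponds to this Python sketch: subtract a single global max, exponentiate, sum per row, and force rows whose exp-sum is 0 (e.g. a boolean mask disabling every key, or very negative masked QK values) to 0 instead of dividing by zero:

import torch

def normalize(weight: torch.Tensor) -> torch.Tensor:
    exp_w = torch.exp(weight - weight.max())   # single global max, as in the op
    row_sum = exp_w.sum(dim=-1, keepdim=True)
    div = exp_w / row_sum                      # may be 0/0 = NaN
    # Rows with a zero sum are set to 0 instead of propagating NaN.
    return torch.where(row_sum == 0, torch.zeros_like(div), div)
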


@@ -34,7 +34,13 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
     LINALG_XFAIL_SET = LINALG_XFAIL_SET | {
         # Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
+        # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
+        "ScaledDotProductAttentionBoolMaskModule_basic",
+        "ScaledDotProductAttentionDifferentCausalModule_basic",
         "ScaledDotProductAttentionDifferentModule_basic",
+        "ScaledDotProductAttentionMaskModule_basic",
+        "ScaledDotProductAttentionSameCausalModule_basic",
+        "ScaledDotProductAttentionSameDynamicModule_basic",
         "ScaledDotProductAttentionSameModule_basic",
     }
@@ -498,7 +504,13 @@ FX_IMPORTER_XFAIL_SET = {
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewSizeFromOtherTensor_basic",
     "WeightNormInterfaceModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentModule_basic",
+    "ScaledDotProductAttentionMaskModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScaledDotProductAttentionSameModule_basic",
 }
@@ -780,6 +792,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionDifferentModule_basic",
+    "ScaledDotProductAttentionMaskModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
+    "ScaledDotProductAttentionSameModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
@@ -2179,6 +2199,8 @@ MAKE_FX_TOSA_PASS_SET = (
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
     MAKE_FX_TOSA_PASS_SET = MAKE_FX_TOSA_PASS_SET | {
         "ScaledDotProductAttentionDifferentModule_basic",
+        "ScaledDotProductAttentionMaskModule_basic",
+        "ScaledDotProductAttentionSameModule_basic",
     }

 LTC_CRASHING_SET = {
@@ -2932,6 +2954,12 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
         "ElementwiseBitwiseAndStaticShapeModule_basic",
     }

+if torch_version_for_comparison() >= version.parse("2.5.0.dev"):
+    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
+        # ERROR: value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=+nan, max=+nan, mean=+nan) is not close to golden value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=-2.394, max=+2.454, mean=-0.02828)
+        "ScaledDotProductAttentionBoolMaskModule_basic",
+    }
+
 if torch_version_for_comparison() < version.parse("2.4.0.dev"):
     STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
         "AtenIntMM_basic",
@@ -3009,8 +3037,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ReduceAminmaxSingleDim_basic",
     "ReduceAnyDimFloatModule_basic",
     "RenormModuleFloat16_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex128IntModule_basic",
@@ -4548,7 +4579,11 @@ ONNX_TOSA_XFAIL_SET = {
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",


@@ -5107,9 +5107,9 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
     @annotate_args(
         [
             None,
-            ([-1, -1, -1], torch.float32, True),
-            ([-1, -1, -1], torch.float32, True),
-            ([-1, -1, -1], torch.float32, True),
+            ([1, 5, 5], torch.float32, True),
+            ([1, 5, 5], torch.float32, True),
+            ([1, 5, 5], torch.float32, True),
         ]
     )
     def forward(self, query, key, value):
@@ -5124,6 +5124,58 @@ def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
     module.forward(query, key, value)


+class ScaledDotProductAttentionSameDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(query, key, value)
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameDynamicModule())
+def ScaledDotProductAttentionSameDynamicModule_basic(module, tu: TestUtils):
+    query = torch.randn(1, 5, 5, dtype=torch.float32)
+    key = torch.randn(1, 5, 5, dtype=torch.float32)
+    value = torch.randn(1, 5, 5, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
+class ScaledDotProductAttentionSameCausalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(
+            query, key, value, is_causal=True
+        )
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameCausalModule())
+def ScaledDotProductAttentionSameCausalModule_basic(module, tu: TestUtils):
+    query = torch.randn(1, 5, 5, dtype=torch.float32)
+    key = torch.randn(1, 5, 5, dtype=torch.float32)
+    value = torch.randn(1, 5, 5, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
 class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -5132,9 +5184,9 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
     @annotate_args(
         [
             None,
-            ([2, 3, 8, 4], torch.float32, True),
-            ([2, 3, 16, 4], torch.float32, True),
-            ([2, 3, 16, 4], torch.float32, True),
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
         ]
     )
     def forward(self, query, key, value):
@@ -5143,12 +5195,95 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
 def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
-    query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
-    key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
-    value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
     module.forward(query, key, value)


+class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(
+            query, key, value, is_causal=True
+        )
+
+
+@register_test_case(
+    module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
+)
+def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
+class ScaledDotProductAttentionMaskModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
+            ([2, 3, 8, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value, mask):
+        return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionMaskModule())
+def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
+    module.forward(query, key, value, mask)
+
+
+class ScaledDotProductAttentionBoolMaskModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
+            ([2, 3, 8, 12], torch.bool, True),
+        ]
+    )
+    def forward(self, query, key, value, mask):
+        return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionBoolMaskModule())
+def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) > 0.5
+    module.forward(query, key, value, mask)
+
+
 # ==============================================================================