mirror of https://github.com/llvm/torch-mlir
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled the mask and is_causal parameters for torch.aten.scaled_dot_product_attention, with relevant comments and tests. The tests added highlight the new capabilities introduced in this PR, including:
- attention with an F16 mask
- attention with a boolean mask
- causal attention with same Q, K, V shapes
- causal attention with different Q, K, V shapes
Also made sure that a mask and is_causal cannot both be supplied.
parent 0a788e0467
commit e86f56bc76
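For context, a minimal usage sketch of the call patterns this commit enables (shapes are borrowed from the e2e tests below; the sketch assumes a recent PyTorch build and is not part of the commit):

```python
import torch

query = torch.randn(2, 3, 8, 16)
key = torch.randn(2, 3, 12, 16)
value = torch.randn(2, 3, 12, 20)
mask = torch.randn(2, 3, 8, 12)

# Additive float mask (a torch.bool mask of the same shape also works).
out_masked = torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)

# Causal attention: the lowering synthesizes the triangular mask itself.
out_causal = torch.ops.aten.scaled_dot_product_attention(
    query, key, value, is_causal=True
)

# Supplying both an explicit mask and is_causal=True is rejected by the lowering.
```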
@@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
      ["generateScalarImplementation"]>]> {
  let summary = "Attention operator";
  let description = [{
    This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
    the attention. Each of the inputs has shape BxNxd where B is the
    of the batch dimension, N is the sequence length and d is head dimension.
    Typically N >>> d. Mathematically, the attention is defined as
    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
    this operator also performs scaling, masking and dropout, but we leave
    that out of the current implementation.
    This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
    optional mask(M) to compute the attention. These tensors must take on shapes
    BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
    shapes, B represents the batch dimension, M represents sequence length, N
    represents head dimension, and K1 and K2 are hidden dimensions.
    Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
    has shape BxMxN. Usually, this operator also performs scaling, masking and
    dropout, but we leave that out of the current implementation.
  }];

  let arguments = (ins Variadic<AnyShaped>:$inputs,
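A minimal NumPy sketch of the updated definition, matmul(softmax(matmul(Q, transpose(K)) + M), V), using the documented BxMxK1 / BxK2xK1 / BxK2xN / BxMxK2 shapes (the concrete sizes below are illustrative only, not taken from the commit):

```python
import numpy as np

B, M, K1, K2, N = 2, 8, 16, 12, 20

Q = np.random.randn(B, M, K1).astype(np.float32)
K = np.random.randn(B, K2, K1).astype(np.float32)
V = np.random.randn(B, K2, N).astype(np.float32)
Mask = np.random.randn(B, M, K2).astype(np.float32)  # additive mask

scores = Q @ K.transpose(0, 2, 1) + Mask          # BxMxK2
scores -= scores.max(axis=-1, keepdims=True)      # numerical stability only
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)    # softmax over K2
out = weights @ V                                 # BxMxN, as documented
assert out.shape == (B, M, N)
```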
@@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
    Value getValue() {
      return getInputOperand(2)->get();
    }
    std::optional<Value> getAttnMask() {
      if (getNumInputs() < 4) {
        return std::nullopt;
      }
      return getInputOperand(3)->get();
    }
    Value getOutput() {
      return getOutputOperand(0)->get();
    }

@@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
    ShapedType getValueType() {
      return cast<ShapedType>(getValue().getType());
    }
    std::optional<ShapedType> getAttnMaskType() {
      if (getAttnMask()) {
        return cast<ShapedType>((*getAttnMask()).getType());
      }
      return std::nullopt;
    }
    ShapedType getOutputType() {
      return cast<ShapedType>(getOutput().getType());
    }

@@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
    int64_t getValueRank() {
      return getValueType().getRank();
    }
    std::optional<int64_t> getAttnMaskRank() {
      if (getAttnMask()) {
        return (*getAttnMaskType()).getRank();
      }
      return std::nullopt;
    }
    int64_t getOutputRank() {
      return getOutputType().getRank();
    }
@@ -1578,7 +1578,16 @@ public:
  LogicalResult
  matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value mask = op.getAttnMask();

    auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
    auto query = adaptor.getQuery();
    auto value = adaptor.getValue();
    auto key = adaptor.getKey();
    auto mask = adaptor.getAttnMask();
    auto queryTy = cast<ShapedType>(query.getType());
    auto valueTy = cast<ShapedType>(value.getType());
    auto keyTy = cast<ShapedType>(key.getType());

    Value dropoutP = op.getDropoutP();
    Value isCausal = op.getIsCausal();
    Value scale = op.getScale();

@@ -1586,18 +1595,77 @@ public:
    Type elementType =
        cast<ShapedType>(adaptor.getQuery().getType()).getElementType();

    // Verify inputs (only support defaults)
    if (!isa<Torch::NoneType>(mask.getType()))
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "attention masking not supported");
    double dropout;
    if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
        dropout > 0.0)
      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");

    bool causal;
    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
      return rewriter.notifyMatchFailure(
          op.getLoc(), "causal attention masking not supported");
    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
      if (!isa<Torch::NoneType>(mask.getType())) {
        return rewriter.notifyMatchFailure(
            op.getLoc(), "expected no attention mask when isCausal is true");
      }

      SmallVector<OpFoldResult> maskSizes;

      if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
        auto seqLenQ =
            rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
        auto seqLenK =
            rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
        maskSizes = {seqLenQ, seqLenK};
        for (int i = queryTy.getRank() - 3; i >= 0; --i) {
          auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
          maskSizes.insert(maskSizes.begin(), batchSize);
        }
      } else { // Dynamic shape case: <?x?x...x?xf32> for example
        for (int i = 0; i < queryTy.getRank() - 2; ++i) {
          Value batchSize =
              rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
          maskSizes.push_back(batchSize);
        }
        Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
                                                       queryTy.getRank() - 2);
        Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
                                                       keyTy.getRank() - 2);
        maskSizes.push_back(seqLenQ);
        maskSizes.push_back(seqLenK);
      }

      Type maskType = getElementTypeOrSelf(queryTy);
      Value emptyMask =
          rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);

      Value zero = rewriter.create<arith::ConstantOp>(
          op.getLoc(),
          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
      Value negInf = rewriter.create<arith::ConstantOp>(
          op.getLoc(),
          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));

      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
                 .getResult(0);

      int64_t rank = cast<ShapedType>(queryTy).getRank();
      AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
      SmallVector<utils::IteratorType> iteratorTypes(
          rank, utils::IteratorType::parallel);
      auto genericOp = rewriter.create<linalg::GenericOp>(
          op.getLoc(), mask.getType(), ValueRange{}, mask,
          SmallVector<AffineMap>{maskMap}, iteratorTypes,
          [&](OpBuilder &b, Location loc, ValueRange args) {
            Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
            Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);

            Value cond =
                b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
            Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
            b.create<linalg::YieldOp>(loc, select);
          });
      mask = genericOp.getResult(0);
    }

    if (!isa<Torch::NoneType>(scale.getType())) {
      double scaleFloat;
      if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
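For reference, the additive causal mask this block materializes holds 0 wherever query position i may attend to key position j (i >= j) and -inf elsewhere, and is simply added to Q·Kᵀ before the softmax. A small NumPy sketch with illustrative sizes:

```python
import numpy as np

seq_q, seq_k = 4, 5
i = np.arange(seq_q)[:, None]
j = np.arange(seq_k)[None, :]

# 0 where key position j is visible to query position i (i >= j), -inf otherwise.
causal_mask = np.where(i >= j, 0.0, -np.inf).astype(np.float32)
print(causal_mask)
```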
@@ -1611,14 +1679,6 @@ public:
      return rewriter.notifyMatchFailure(
          op.getLoc(), "grouped query attention not supported");

    auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
    auto query = adaptor.getQuery();
    auto value = adaptor.getValue();
    auto key = adaptor.getKey();
    auto queryTy = cast<ShapedType>(query.getType());
    auto valueTy = cast<ShapedType>(value.getType());
    auto keyTy = cast<ShapedType>(key.getType());

    if (queryTy.getRank() != valueTy.getRank() ||
        queryTy.getRank() != keyTy.getRank())
      return rewriter.notifyMatchFailure(op, "operand ranks do not match");

@@ -1659,6 +1719,9 @@ public:
    query = collapseBatch(query);
    key = collapseBatch(key);
    value = collapseBatch(value);
    if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
      mask = collapseBatch(mask);
    }

    SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
    SmallVector<int64_t> valueSizes(

@@ -1672,13 +1735,17 @@ public:
    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
                                        elementType);

    SmallVector<Value> inputs = SmallVector<Value>{query, key, value};

    if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
      inputs.push_back(mask);
    }

    // Overwrite with tm_tensor::attention
    Value attention =
        rewriter
            .create<AttentionOp>(loc, outType,
                                 SmallVector<Value>{query, key, value},
                                 SmallVector<Value>{output})
            .getResult()[0];
    Value attention = rewriter
                          .create<AttentionOp>(loc, outType, inputs,
                                               SmallVector<Value>{output})
                          .getResult()[0];

    if (opTy != outType) {
      attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
@@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
  Operation *op = getOperation();
  ShapedType queryType = getQueryType();
  ShapedType keyType = getKeyType();
  ShapedType valueType = getValueType();

  auto optionalMaskType = getAttnMaskType();
  ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();

  ArrayRef<int64_t> queryShape = queryType.getShape();
  ArrayRef<int64_t> keyShape = keyType.getShape();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  ArrayRef<int64_t> maskShape =
      optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();

  for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
    if (keyShape[i] != queryShape[i])
    if (keyShape[i] != queryShape[i]) {
      return op->emitOpError("query and key batch mismatch");
    }
  }
  if (keyShape.back() != queryShape.back())
  if (keyShape.back() != queryShape.back()) {
    return op->emitOpError("query and key head dimension mismatch");
  }

  for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
    if (valueShape[i] != queryShape[i]) {
      return op->emitOpError("query and value batch dimension mismatch");
    }
  }
  if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
    return op->emitOpError("key and value sequence length dimension mismatch");
  }
  if (optionalMaskType) {
    for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
      if (maskShape[i] != queryShape[i]) {
        return op->emitOpError("query and mask batch dimension mismatch");
      }
    }
    if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
      return op->emitOpError(
          "mask sequence length and query sequence length mismatch");
    }
    if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
      return op->emitOpError(
          "mask sequence length and key sequence length mismatch");
    }
  }
  return success();
}
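In plain terms, the verifier enforces the shape contract from the op description: Q is BxMxK1, K is BxK2xK1, V is BxK2xN, and the optional mask is BxMxK2. A hypothetical Python helper sketching the same checks (the helper name is made up for illustration, not part of the commit):

```python
def check_attention_shapes(q, k, v, mask=None):
    """Mirror of the AttentionOp verifier: q=BxMxK1, k=BxK2xK1, v=BxK2xN, mask=BxMxK2."""
    assert q[:-2] == k[:-2], "query and key batch mismatch"
    assert q[-1] == k[-1], "query and key head dimension mismatch"
    assert q[:-2] == v[:-2], "query and value batch dimension mismatch"
    assert k[-2] == v[-2], "key and value sequence length dimension mismatch"
    if mask is not None:
        assert mask[:-2] == q[:-2], "query and mask batch dimension mismatch"
        assert mask[-2] == q[-2], "mask and query sequence length mismatch"
        assert mask[-1] == k[-2], "mask and key sequence length mismatch"

# e.g. the shapes used by ScaledDotProductAttentionMaskModule below:
check_attention_shapes((2, 3, 8, 16), (2, 3, 12, 16), (2, 3, 12, 20), (2, 3, 8, 12))
```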
@@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
  Value query = getQuery();
  Value key = getKey();
  Value value = getValue();

  auto optionalMask = getAttnMask();
  Value mask = optionalMask ? *optionalMask : Value();

  Value output = getOutput();
  auto queryType = cast<MemRefType>(query.getType());
  auto keyType = cast<MemRefType>(key.getType());
  auto valueType = cast<MemRefType>(value.getType());
  auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
  auto queryRank = queryType.getRank();
  auto keyRank = keyType.getRank();
  auto valueRank = valueType.getRank();

@@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,

  Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                            b.getFloatAttr(elementType, 0.0));
  Value negInfF = b.create<arith::ConstantOp>(
      loc, elementType,
      b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));

  // TODO: This needs to be fixed, it assumes everything is dynamic however if
  // any shapes are static the `memref.alloc` generated is illegal.

@@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                     /*transposed=*/true);

  // weight = softmax(weight)
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value dim = weightDynSizes[weightRank - 1];
  Value scaleFactor = b.create<math::SqrtOp>(
      loc, b.create<arith::UIToFPOp>(
               loc, elementType,
               b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                              queryDynSizes[queryRank - 1])));

  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });

  // Apply mask to weights if mask is given
  if (mask) {
    b.create<scf::ParallelOp>(
        loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
        SmallVector<Value>(weightRank, one),
        [&](OpBuilder &b, Location loc, ValueRange localIVs) {
          Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
          Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
          if (maskType.getElementType().isInteger(1)) {
            maskValue =
                b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
          }
          Value maskedWeight =
              b.create<arith::AddFOp>(loc, weightValue, maskValue);
          b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
        });
  }

  // calculate max(weight)
  Value init = b.create<memref::LoadOp>(loc, weight,
                                        SmallVector<Value>(weightRank, zero));
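The masking step above scales the Q·Kᵀ scores by 1/sqrt(K1), converts a boolean mask into an additive one (true becomes 0, false becomes -inf), and adds it elementwise to the scores. A NumPy sketch of that step with illustrative shapes (the helper name is made up for illustration):

```python
import numpy as np

def apply_mask(weight, mask):
    """weight: scaled Q @ K^T scores; mask: additive float mask or boolean mask."""
    if mask.dtype == np.bool_:
        # True means "keep": contribute 0; False means "drop": contribute -inf.
        mask = np.where(mask, 0.0, -np.inf).astype(weight.dtype)
    return weight + mask

scores = (np.random.randn(8, 12) / np.sqrt(16.0)).astype(np.float32)
bool_mask = np.random.rand(8, 12) > 0.5
masked = apply_mask(scores, bool_mask)
```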
@@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<arith::SubFOp>(loc, x, globalMax);
        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });
  // calculate exp(weight)

@@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        SmallVector<Value> sumIVs(localIVs);
        sumIVs.pop_back();

        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
        x = b.create<arith::DivFOp>(loc, x, sum);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
        Value divResult = b.create<arith::DivFOp>(loc, x, sum);

        // Set to 0 if sum is 0 (can occur during boolean mask / large negative
        // QK)
        Value isSumZero =
            b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
        Value result =
            b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);

        b.create<memref::StoreOp>(loc, result, weight, localIVs);
      });

  // output = weight @ value
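A fully masked row makes the exp-sum zero, so the plain division would produce NaNs; the select above returns 0 for such rows instead. A NumPy sketch of that guarded normalization (mirroring the global-max subtraction used here; the helper name is illustrative):

```python
import numpy as np

def safe_softmax(scores):
    """Subtract the global max, exponentiate, then divide by the row sum,
    returning 0 where the row sum is 0 (i.e. fully masked rows)."""
    shifted = scores - scores.max()          # global max, as in the scalar implementation
    exp = np.exp(shifted)
    denom = exp.sum(axis=-1, keepdims=True)
    safe_denom = np.where(denom == 0.0, 1.0, denom)
    return np.where(denom == 0.0, 0.0, exp / safe_denom)

scores = np.array([[0.5, 1.0, -np.inf],
                   [-np.inf, -np.inf, -np.inf]], dtype=np.float32)  # second row fully masked
print(safe_softmax(scores))  # second row is all zeros instead of NaN
```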
@@ -34,7 +34,13 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
    LINALG_XFAIL_SET = LINALG_XFAIL_SET | {
        # Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
        # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
        "ScaledDotProductAttentionBoolMaskModule_basic",
        "ScaledDotProductAttentionDifferentCausalModule_basic",
        "ScaledDotProductAttentionDifferentModule_basic",
        "ScaledDotProductAttentionMaskModule_basic",
        "ScaledDotProductAttentionSameCausalModule_basic",
        "ScaledDotProductAttentionSameDynamicModule_basic",
        "ScaledDotProductAttentionSameModule_basic",
    }

@@ -498,7 +504,13 @@ FX_IMPORTER_XFAIL_SET = {
    "ViewCollapseDynamicWithAtenSizeIntModule_basic",
    "ViewSizeFromOtherTensor_basic",
    "WeightNormInterfaceModule_basic",
    # REMOVE WHEN ENABLE_GQA IS ADDED
    "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionDifferentCausalModule_basic",
    "ScaledDotProductAttentionDifferentModule_basic",
    "ScaledDotProductAttentionMaskModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
    "ScaledDotProductAttentionSameDynamicModule_basic",
    "ScaledDotProductAttentionSameModule_basic",
}

@@ -780,6 +792,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
    "RsubInt0d_NumToTensor_Module_basic",
    "ScalarConstantTupleModule_basic",
    "ScalarImplicitFloatModule_basic",
    # REMOVE WHEN ENABLE_GQA IS ADDED
    "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionDifferentCausalModule_basic",
    "ScaledDotProductAttentionDifferentModule_basic",
    "ScaledDotProductAttentionMaskModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
    "ScaledDotProductAttentionSameDynamicModule_basic",
    "ScaledDotProductAttentionSameModule_basic",
    "ScatterReduceFloatMaxModule",
    "ScatterReduceFloatMaxModuleIncludeSelf",
    "ScatterReduceFloatMeanModule",

@@ -2179,6 +2199,8 @@ MAKE_FX_TOSA_PASS_SET = (
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
    MAKE_FX_TOSA_PASS_SET = MAKE_FX_TOSA_PASS_SET | {
        "ScaledDotProductAttentionDifferentModule_basic",
        "ScaledDotProductAttentionMaskModule_basic",
        "ScaledDotProductAttentionSameModule_basic",
    }

LTC_CRASHING_SET = {

@@ -2932,6 +2954,12 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
        "ElementwiseBitwiseAndStaticShapeModule_basic",
    }

if torch_version_for_comparison() >= version.parse("2.5.0.dev"):
    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
        # ERROR: value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=+nan, max=+nan, mean=+nan) is not close to golden value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=-2.394, max=+2.454, mean=-0.02828)
        "ScaledDotProductAttentionBoolMaskModule_basic",
    }

if torch_version_for_comparison() < version.parse("2.4.0.dev"):
    STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
        "AtenIntMM_basic",

@@ -3009,8 +3037,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ReduceAminmaxSingleDim_basic",
    "ReduceAnyDimFloatModule_basic",
    "RenormModuleFloat16_basic",
    "ScaledDotProductAttentionDifferentModule_basic",
    "ScaledDotProductAttentionSameModule_basic",
    # REMOVE WHEN ENABLE_GQA IS ADDED
    "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionDifferentCausalModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
    "ScaledDotProductAttentionSameDynamicModule_basic",
    "ScatterAddStaticModule_basic",
    "TensorsConcatComplex128FloatModule_basic",
    "TensorsConcatComplex128IntModule_basic",

@@ -4548,7 +4579,11 @@ ONNX_TOSA_XFAIL_SET = {
    "ScalarConstantTupleModule_basic",
    "ScalarImplicitFloatModule_basic",
    "ScalarImplicitIntModule_basic",
    "ScaledDotProductAttentionSameModule_basic",
    # REMOVE WHEN ENABLE_GQA IS ADDED
    "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionDifferentCausalModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
    "ScaledDotProductAttentionSameDynamicModule_basic",
    "ScatterReduceFloatMaxModule",
    "ScatterReduceFloatMaxModuleIncludeSelf",
    "ScatterReduceFloatMeanModule",
@@ -5107,9 +5107,9 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
            ([1, 5, 5], torch.float32, True),
            ([1, 5, 5], torch.float32, True),
            ([1, 5, 5], torch.float32, True),
        ]
    )
    def forward(self, query, key, value):

@@ -5124,6 +5124,58 @@ def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
    module.forward(query, key, value)


class ScaledDotProductAttentionSameDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)


@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameDynamicModule())
def ScaledDotProductAttentionSameDynamicModule_basic(module, tu: TestUtils):
    query = torch.randn(1, 5, 5, dtype=torch.float32)
    key = torch.randn(1, 5, 5, dtype=torch.float32)
    value = torch.randn(1, 5, 5, dtype=torch.float32)
    module.forward(query, key, value)


class ScaledDotProductAttentionSameCausalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(
            query, key, value, is_causal=True
        )


@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameCausalModule())
def ScaledDotProductAttentionSameCausalModule_basic(module, tu: TestUtils):
    query = torch.randn(1, 5, 5, dtype=torch.float32)
    key = torch.randn(1, 5, 5, dtype=torch.float32)
    value = torch.randn(1, 5, 5, dtype=torch.float32)
    module.forward(query, key, value)


class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -5132,9 +5184,9 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
    @annotate_args(
        [
            None,
            ([2, 3, 8, 4], torch.float32, True),
            ([2, 3, 16, 4], torch.float32, True),
            ([2, 3, 16, 4], torch.float32, True),
            ([2, 3, 8, 16], torch.float32, True),
            ([2, 3, 12, 16], torch.float32, True),
            ([2, 3, 12, 20], torch.float32, True),
        ]
    )
    def forward(self, query, key, value):

@@ -5143,12 +5195,95 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):

@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
    query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
    key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
    value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
    module.forward(query, key, value)


class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3, 8, 16], torch.float32, True),
            ([2, 3, 12, 16], torch.float32, True),
            ([2, 3, 12, 20], torch.float32, True),
        ]
    )
    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(
            query, key, value, is_causal=True
        )


@register_test_case(
    module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
)
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
    module.forward(query, key, value)


class ScaledDotProductAttentionMaskModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3, 8, 16], torch.float32, True),
            ([2, 3, 12, 16], torch.float32, True),
            ([2, 3, 12, 20], torch.float32, True),
            ([2, 3, 8, 12], torch.float32, True),
        ]
    )
    def forward(self, query, key, value, mask):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)


@register_test_case(module_factory=lambda: ScaledDotProductAttentionMaskModule())
def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
    module.forward(query, key, value, mask)


class ScaledDotProductAttentionBoolMaskModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 3, 8, 16], torch.float32, True),
            ([2, 3, 12, 16], torch.float32, True),
            ([2, 3, 12, 20], torch.float32, True),
            ([2, 3, 8, 12], torch.bool, True),
        ]
    )
    def forward(self, query, key, value, mask):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)


@register_test_case(module_factory=lambda: ScaledDotProductAttentionBoolMaskModule())
def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) > 0.5
    module.forward(query, key, value, mask)


# ==============================================================================