diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
index 90c800ba3..c47eaabf7 100644
--- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
+++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
         ["generateScalarImplementation"]>]> {
   let summary = "Attention operator";
   let description = [{
-    This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
-    the attention. Each of the inputs has shape BxNxd where B is the
-    of the batch dimension, N is the sequence length and d is head dimension.
-    Typically N >>> d. Mathematically, the attention is defined as
-    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
-    this operator also performs scaling, masking and dropout, but we leave
-    that out of the current implementation.
+    This operator takes in 3 or 4 tensors: query(Q), key(K), value(V), and an
+    optional mask(M), and computes the attention. The tensors must have shapes
+    BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M, where B is the
+    batch dimension, M is the sequence length, N is the head dimension, and K1
+    and K2 are hidden dimensions. Attention is defined as
+    matmul(softmax(matmul(Q, transpose(K)) + M), V) and has shape BxMxN.
+    Usually, this operator also performs scaling and dropout, but we leave
+    those out of the current implementation.
   }];
 
   let arguments = (ins Variadic:$inputs,
@@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     Value getValue() {
       return getInputOperand(2)->get();
     }
+    std::optional<Value> getAttnMask() {
+      if (getNumInputs() < 4) {
+        return std::nullopt;
+      }
+      return getInputOperand(3)->get();
+    }
     Value getOutput() {
       return getOutputOperand(0)->get();
     }
@@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     ShapedType getValueType() {
      return cast<ShapedType>(getValue().getType());
     }
+    std::optional<ShapedType> getAttnMaskType() {
+      if (getAttnMask()) {
+        return cast<ShapedType>((*getAttnMask()).getType());
+      }
+      return std::nullopt;
+    }
     ShapedType getOutputType() {
       return cast<ShapedType>(getOutput().getType());
     }
@@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     int64_t getValueRank() {
       return getValueType().getRank();
     }
+    std::optional<int64_t> getAttnMaskRank() {
+      if (getAttnMask()) {
+        return (*getAttnMaskType()).getRank();
+      }
+      return std::nullopt;
+    }
     int64_t getOutputRank() {
       return getOutputType().getRank();
     }
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index e52a373bd..4a87d6888 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -1578,7 +1578,16 @@ public:
   LogicalResult
   matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Value mask = op.getAttnMask();
+
+    auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
+    auto query = adaptor.getQuery();
+    auto value = adaptor.getValue();
+    auto key = adaptor.getKey();
+    auto mask = adaptor.getAttnMask();
+    auto queryTy = cast<ShapedType>(query.getType());
+    auto valueTy = cast<ShapedType>(value.getType());
+    auto keyTy = cast<ShapedType>(key.getType());
+
     Value dropoutP = op.getDropoutP();
     Value isCausal = op.getIsCausal();
     Value scale = op.getScale();
@@ -1586,18 +1595,77 @@ public:
     Type elementType =
         cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
 
-    // Verify inputs (only support defaults)
-    if (!isa<Torch::NoneType>(mask.getType()))
-      return rewriter.notifyMatchFailure(op.getLoc(),
-                                         "attention masking not supported");
     double dropout;
     if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
         dropout > 0.0)
       return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
+
     bool causal;
-    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
-      return rewriter.notifyMatchFailure(
-          op.getLoc(), "causal attention masking not supported");
+    if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
+      if (!isa<Torch::NoneType>(mask.getType())) {
+        return rewriter.notifyMatchFailure(
+            op.getLoc(), "expected no attention mask when isCausal is true");
+      }
+
+      SmallVector<OpFoldResult> maskSizes;
+
+      if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
+        auto seqLenQ =
+            rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
+        auto seqLenK =
+            rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
+        maskSizes = {seqLenQ, seqLenK};
+        for (int i = queryTy.getRank() - 3; i >= 0; --i) {
+          auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
+          maskSizes.insert(maskSizes.begin(), batchSize);
+        }
+      } else { // Dynamic shape case
+        for (int i = 0; i < queryTy.getRank() - 2; ++i) {
+          Value batchSize =
+              rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
+          maskSizes.push_back(batchSize);
+        }
+        Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
+                                                       queryTy.getRank() - 2);
+        Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                       keyTy.getRank() - 2);
+        maskSizes.push_back(seqLenQ);
+        maskSizes.push_back(seqLenK);
+      }
+
+      Type maskType = getElementTypeOrSelf(queryTy);
+      Value emptyMask =
+          rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
+
+      Value zero = rewriter.create<arith::ConstantOp>(
+          op.getLoc(),
+          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
+      Value negInf = rewriter.create<arith::ConstantOp>(
+          op.getLoc(),
+          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
+
+      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
+                 .getResult(0);
+
+      int64_t rank = cast<ShapedType>(queryTy).getRank();
+      AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
+      SmallVector<utils::IteratorType> iteratorTypes(
+          rank, utils::IteratorType::parallel);
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          op.getLoc(), mask.getType(), ValueRange{}, mask,
+          SmallVector<AffineMap>{maskMap}, iteratorTypes,
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
+            Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);
+
+            Value cond =
+                b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
+            Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
+            b.create<linalg::YieldOp>(loc, select);
+          });
+      mask = genericOp.getResult(0);
+    }
+
     if (!isa<Torch::NoneType>(scale.getType())) {
       double scaleFloat;
       if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
@@ -1611,14 +1679,6 @@ public:
       return rewriter.notifyMatchFailure(
           op.getLoc(), "grouped query attention not supported");
 
-    auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
-    auto query = adaptor.getQuery();
-    auto value = adaptor.getValue();
-    auto key = adaptor.getKey();
-    auto queryTy = cast<ShapedType>(query.getType());
-    auto valueTy = cast<ShapedType>(value.getType());
-    auto keyTy = cast<ShapedType>(key.getType());
-
     if (queryTy.getRank() != valueTy.getRank() ||
         queryTy.getRank() != keyTy.getRank())
       return rewriter.notifyMatchFailure(op, "operand ranks do not match");
@@ -1659,6 +1719,9 @@ public:
     query = collapseBatch(query);
     key = collapseBatch(key);
     value = collapseBatch(value);
+    if (!isa<Torch::NoneType>(mask.getType())) {
+      mask = collapseBatch(mask);
+    }
     SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
     SmallVector<int64_t> valueSizes(
         cast<ShapedType>(value.getType()).getShape());
@@ -1672,13 +1735,17 @@ public:
 
     Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
                                         elementType);
 
+    SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
+
+    if (!isa<Torch::NoneType>(mask.getType())) {
+      inputs.push_back(mask);
+    }
+
     // Overwrite with tm_tensor::attention
-    Value attention =
-        rewriter
-            .create<AttentionOp>(loc, outType,
-                                 SmallVector<Value>{query, key, value},
-                                 SmallVector<Value>{output})
-            .getResult()[0];
+    Value attention = rewriter
+                          .create<AttentionOp>(loc, outType, inputs,
+                                               SmallVector<Value>{output})
+                          .getResult()[0];
 
     if (opTy != outType) {
       attention = rewriter.create<tensor::CastOp>(loc, opTy, attention,
diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
index 943eda423..9a90b4cac 100644
--- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
+++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
@@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
   Operation *op = getOperation();
   ShapedType queryType = getQueryType();
   ShapedType keyType = getKeyType();
+  ShapedType valueType = getValueType();
+
+  auto optionalMaskType = getAttnMaskType();
+  ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();
+
   ArrayRef<int64_t> queryShape = queryType.getShape();
   ArrayRef<int64_t> keyShape = keyType.getShape();
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  ArrayRef<int64_t> maskShape =
+      optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();
+
   for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
-    if (keyShape[i] != queryShape[i])
+    if (keyShape[i] != queryShape[i]) {
       return op->emitOpError("query and key batch mismatch");
+    }
   }
-  if (keyShape.back() != queryShape.back())
+  if (keyShape.back() != queryShape.back()) {
     return op->emitOpError("query and key head dimension mismatch");
+  }
+
+  for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
+    if (valueShape[i] != queryShape[i]) {
+      return op->emitOpError("query and value batch dimension mismatch");
+    }
+  }
+  if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
+    return op->emitOpError("key and value sequence length dimension mismatch");
+  }
+  if (optionalMaskType) {
+    for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
+      if (maskShape[i] != queryShape[i]) {
+        return op->emitOpError("query and mask batch dimension mismatch");
+      }
+    }
+    if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
+      return op->emitOpError(
+          "mask sequence length and query sequence length mismatch");
+    }
+    if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
+      return op->emitOpError(
+          "mask sequence length and key sequence length mismatch");
+    }
+  }
   return success();
 }
@@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
   Value query = getQuery();
   Value key = getKey();
   Value value = getValue();
+
+  auto optionalMask = getAttnMask();
+  Value mask = optionalMask ? *optionalMask : Value();
+
   Value output = getOutput();
   auto queryType = cast<MemRefType>(query.getType());
   auto keyType = cast<MemRefType>(key.getType());
   auto valueType = cast<MemRefType>(value.getType());
+  auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
   auto queryRank = queryType.getRank();
   auto keyRank = keyType.getRank();
   auto valueRank = valueType.getRank();
@@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
 
   Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                             b.getFloatAttr(elementType, 0.0));
+  Value negInfF = b.create<arith::ConstantOp>(
+      loc, elementType,
+      b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));
 
   // TODO: This needs to be fixed, it assumes everything is dynamic however if
   // any shapes are static the `memref.alloc` generated is illegal.
@@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                            /*transposed=*/true);
 
   // weight = softmax(weight)
-  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Value dim = weightDynSizes[weightRank - 1];
   Value scaleFactor = b.create<math::SqrtOp>(
       loc, b.create<arith::UIToFPOp>(
               loc, elementType,
               b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                              queryDynSizes[queryRank - 1])));
+
+  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  b.create<scf::ParallelOp>(
+      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
+      SmallVector<Value>(weightRank, one),
+      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
+        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
+        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
+        b.create<memref::StoreOp>(loc, x, weight, localIVs);
+      });
+
+  // Apply mask to weights if mask is given
+  if (mask) {
+    b.create<scf::ParallelOp>(
+        loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
+        SmallVector<Value>(weightRank, one),
+        [&](OpBuilder &b, Location loc, ValueRange localIVs) {
+          Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
+          Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
+          if (maskType.getElementType().isInteger(1)) {
+            maskValue =
+                b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
+          }
+          Value maskedWeight =
+              b.create<arith::AddFOp>(loc, weightValue, maskValue);
+          b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
+        });
+  }
+
   // calculate max(weight)
   Value init = b.create<memref::LoadOp>(loc, weight,
                                         SmallVector<Value>(weightRank, zero));
@@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
       [&](OpBuilder &b, Location loc, ValueRange localIVs) {
         Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
         x = b.create<arith::SubFOp>(loc, x, globalMax);
-        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
         b.create<memref::StoreOp>(loc, x, weight, localIVs);
       });
   // calculate exp(weight)
@@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
       [&](OpBuilder &b, Location loc, ValueRange localIVs) {
         SmallVector<Value> sumIVs(localIVs);
         sumIVs.pop_back();
+
         Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
         Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
-        x = b.create<arith::DivFOp>(loc, x, sum);
-        b.create<memref::StoreOp>(loc, x, weight, localIVs);
+        Value divResult = b.create<arith::DivFOp>(loc, x, sum);
+
+        // Set result to 0 if sum is 0 (can occur with boolean masks or large
+        // negative QK values)
+        Value isSumZero =
+            b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
+        Value result =
+            b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);
+
+        b.create<memref::StoreOp>(loc, result, weight, localIVs);
       });
 
   // output = weight @ value
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 80831d8ea..cb981b327 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -34,7 +34,13 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
     LINALG_XFAIL_SET = LINALG_XFAIL_SET | {
         # Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
+        # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
+        "ScaledDotProductAttentionBoolMaskModule_basic",
+        "ScaledDotProductAttentionDifferentCausalModule_basic",
         "ScaledDotProductAttentionDifferentModule_basic",
+        "ScaledDotProductAttentionMaskModule_basic",
+        "ScaledDotProductAttentionSameCausalModule_basic",
+        "ScaledDotProductAttentionSameDynamicModule_basic",
         "ScaledDotProductAttentionSameModule_basic",
     }
 
@@ -498,7 +504,13 @@ FX_IMPORTER_XFAIL_SET = {
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewSizeFromOtherTensor_basic",
     "WeightNormInterfaceModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentModule_basic",
+    "ScaledDotProductAttentionMaskModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScaledDotProductAttentionSameModule_basic",
 }
 
@@ -780,6 +792,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionDifferentModule_basic",
+    "ScaledDotProductAttentionMaskModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
+    "ScaledDotProductAttentionSameModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
@@ -2179,6 +2199,8 @@ MAKE_FX_TOSA_PASS_SET = (
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
     MAKE_FX_TOSA_PASS_SET = MAKE_FX_TOSA_PASS_SET | {
         "ScaledDotProductAttentionDifferentModule_basic",
+        "ScaledDotProductAttentionMaskModule_basic",
+        "ScaledDotProductAttentionSameModule_basic",
     }
 
 LTC_CRASHING_SET = {
@@ -2932,6 +2954,12 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
         "ElementwiseBitwiseAndStaticShapeModule_basic",
     }
 
+if torch_version_for_comparison() >= version.parse("2.5.0.dev"):
+    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
+        # ERROR: value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=+nan, max=+nan, mean=+nan) is not close to golden value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=-2.394, max=+2.454, mean=-0.02828)
+        "ScaledDotProductAttentionBoolMaskModule_basic",
+    }
+
 if torch_version_for_comparison() < version.parse("2.4.0.dev"):
     STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
         "AtenIntMM_basic",
@@ -3009,8 +3037,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ReduceAminmaxSingleDim_basic",
     "ReduceAnyDimFloatModule_basic",
     "RenormModuleFloat16_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionSameCausalModule_basic",
+    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
     "TensorsConcatComplex128IntModule_basic",
@@ -4548,7 +4579,11 @@ ONNX_TOSA_XFAIL_SET = {
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
+    # REMOVE WHEN ENABLE_GQA IS ADDED
+    "ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", "ScatterReduceFloatMaxModule", "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMeanModule", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 481a89b18..b33f8e3ee 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5107,9 +5107,9 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module): @annotate_args( [ None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), + ([1, 5, 5], torch.float32, True), + ([1, 5, 5], torch.float32, True), + ([1, 5, 5], torch.float32, True), ] ) def forward(self, query, key, value): @@ -5124,6 +5124,58 @@ def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): module.forward(query, key, value) +class ScaledDotProductAttentionSameDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention(query, key, value) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameDynamicModule()) +def ScaledDotProductAttentionSameDynamicModule_basic(module, tu: TestUtils): + query = torch.randn(1, 5, 5, dtype=torch.float32) + key = torch.randn(1, 5, 5, dtype=torch.float32) + value = torch.randn(1, 5, 5, dtype=torch.float32) + module.forward(query, key, value) + + +class ScaledDotProductAttentionSameCausalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention( + query, key, value, is_causal=True + ) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameCausalModule()) +def ScaledDotProductAttentionSameCausalModule_basic(module, tu: TestUtils): + query = torch.randn(1, 5, 5, dtype=torch.float32) + key = torch.randn(1, 5, 5, dtype=torch.float32) + value = torch.randn(1, 5, 5, dtype=torch.float32) + module.forward(query, key, value) + + class ScaledDotProductAttentionDifferentModule(torch.nn.Module): def __init__(self): super().__init__() @@ -5132,9 +5184,9 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module): @annotate_args( [ None, - ([2, 3, 8, 4], torch.float32, True), - ([2, 3, 16, 4], torch.float32, True), - ([2, 3, 16, 4], torch.float32, True), + ([2, 3, 8, 16], torch.float32, True), + ([2, 3, 12, 16], torch.float32, True), + ([2, 3, 12, 20], torch.float32, True), ] ) def forward(self, query, key, value): @@ -5143,12 +5195,95 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module): @register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule()) def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils): - query = torch.randn(2, 3, 8, 4, dtype=torch.float32) - key = torch.randn(2, 3, 16, 4, dtype=torch.float32) - value = torch.randn(2, 3, 16, 4, dtype=torch.float32) + query = torch.randn(2, 3, 8, 16, 
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
     module.forward(query, key, value)
 
 
+class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(
+            query, key, value, is_causal=True
+        )
+
+
+@register_test_case(
+    module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
+)
+def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
+class ScaledDotProductAttentionMaskModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
+            ([2, 3, 8, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value, mask):
+        return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionMaskModule())
+def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
+    module.forward(query, key, value, mask)
+
+
+class ScaledDotProductAttentionBoolMaskModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 8, 16], torch.float32, True),
+            ([2, 3, 12, 16], torch.float32, True),
+            ([2, 3, 12, 20], torch.float32, True),
+            ([2, 3, 8, 12], torch.bool, True),
+        ]
+    )
+    def forward(self, query, key, value, mask):
+        return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
+
+
+@register_test_case(module_factory=lambda: ScaledDotProductAttentionBoolMaskModule())
+def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) > 0.5
+    module.forward(query, key, value, mask)
+
+
 # ==============================================================================
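
For reference, the masked attention semantics that the lowering above targets can be sanity-checked against eager PyTorch with a small script. This is an illustrative sketch and not part of the patch; the helper name masked_sdpa_reference is made up here, and the shapes simply mirror the new ScaledDotProductAttentionMaskModule test case.

# Illustrative sketch only -- not part of the patch. Mirrors the lowering above:
# weight = softmax(Q @ K^T / sqrt(K1) + M); output = weight @ V, where a boolean
# mask M is first mapped to 0 (attend) / -inf (mask out).
import math

import torch


def masked_sdpa_reference(query, key, value, mask=None):
    # query: ...xMxK1, key: ...xK2xK1, value: ...xK2xN, mask: ...xMxK2
    weight = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    if mask is not None:
        if mask.dtype == torch.bool:
            # True means "attend", False means "mask out".
            mask = torch.zeros_like(weight).masked_fill(~mask, -float("inf"))
        weight = weight + mask
    return torch.softmax(weight, dim=-1) @ value


query = torch.randn(2, 3, 8, 16)
key = torch.randn(2, 3, 12, 16)
value = torch.randn(2, 3, 12, 20)
mask = torch.randn(2, 3, 8, 12)
expected = torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
print(torch.allclose(masked_sdpa_reference(query, key, value, mask), expected, atol=1e-5))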