[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)

Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
pull/3654/merge
rohan-tan-bhowmik 2024-09-09 15:51:41 -07:00 committed by GitHub
parent 0a788e0467
commit e86f56bc76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 384 additions and 48 deletions

View File

@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
["generateScalarImplementation"]>]> {
let summary = "Attention operator";
let description = [{
This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
the attention. Each of the inputs has shape BxNxd where B is the
of the batch dimension, N is the sequence length and d is head dimension.
Typically N >>> d. Mathematically, the attention is defined as
matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
this operator also performs scaling, masking and dropout, but we leave
that out of the current implementation.
This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
optional mask(M) to compute the attention. These tensors must take on shapes
BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
shapes, B represents the batch dimension, M represents sequence length, N
represents head dimension, and K1 and K2 are hidden dimensions.
Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
has shape BxMxN. Usually, this operator also performs scaling, masking and
dropout, but we leave that out of the current implementation.
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
Value getValue() {
return getInputOperand(2)->get();
}
std::optional<Value> getAttnMask() {
if (getNumInputs() < 4) {
return std::nullopt;
}
return getInputOperand(3)->get();
}
Value getOutput() {
return getOutputOperand(0)->get();
}
@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
ShapedType getValueType() {
return cast<ShapedType>(getValue().getType());
}
std::optional<ShapedType> getAttnMaskType() {
if (getAttnMask()){
return cast<ShapedType>((*getAttnMask()).getType());
}
return std::nullopt;
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
int64_t getValueRank() {
return getValueType().getRank();
}
std::optional<int64_t> getAttnMaskRank() {
if (getAttnMask()){
return (*getAttnMaskType()).getRank();
}
return std::nullopt;
}
int64_t getOutputRank() {
return getOutputType().getRank();
}

View File

@ -1578,7 +1578,16 @@ public:
LogicalResult
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value mask = op.getAttnMask();
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto mask = adaptor.getAttnMask();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());
Value dropoutP = op.getDropoutP();
Value isCausal = op.getIsCausal();
Value scale = op.getScale();
@ -1586,18 +1595,77 @@ public:
Type elementType =
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
// Verify inputs (only support defaults)
if (!isa<Torch::NoneType>(mask.getType()))
return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported");
double dropout;
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
dropout > 0.0)
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
bool causal;
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
if (!isa<Torch::NoneType>(mask.getType())) {
return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported");
op.getLoc(), "expected no attention mask when isCausal is true");
}
SmallVector<OpFoldResult> maskSizes;
if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
auto seqLenQ =
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
auto seqLenK =
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
maskSizes = {seqLenQ, seqLenK};
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
maskSizes.insert(maskSizes.begin(), batchSize);
}
} else { // Dynamic shape case: <?x?x...x?xf32> for example
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
Value batchSize =
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
maskSizes.push_back(batchSize);
}
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
queryTy.getRank() - 2);
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
keyTy.getRank() - 2);
maskSizes.push_back(seqLenQ);
maskSizes.push_back(seqLenK);
}
Type maskType = getElementTypeOrSelf(queryTy);
Value emptyMask =
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
Value zero = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
Value negInf = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
.getResult(0);
int64_t rank = cast<ShapedType>(queryTy).getRank();
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
auto genericOp = rewriter.create<linalg::GenericOp>(
op.getLoc(), mask.getType(), ValueRange{}, mask,
SmallVector<AffineMap>{maskMap}, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);
Value cond =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
b.create<linalg::YieldOp>(loc, select);
});
mask = genericOp.getResult(0);
}
if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
@ -1611,14 +1679,6 @@ public:
return rewriter.notifyMatchFailure(
op.getLoc(), "grouped query attention not supported");
auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());
if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank())
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
@ -1659,6 +1719,9 @@ public:
query = collapseBatch(query);
key = collapseBatch(key);
value = collapseBatch(value);
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
mask = collapseBatch(mask);
}
SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
SmallVector<int64_t> valueSizes(
@ -1672,11 +1735,15 @@ public:
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
elementType);
SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
inputs.push_back(mask);
}
// Overwrite with tm_tensor::attention
Value attention =
rewriter
.create<AttentionOp>(loc, outType,
SmallVector<Value>{query, key, value},
Value attention = rewriter
.create<AttentionOp>(loc, outType, inputs,
SmallVector<Value>{output})
.getResult()[0];

View File

@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
Operation *op = getOperation();
ShapedType queryType = getQueryType();
ShapedType keyType = getKeyType();
ShapedType valueType = getValueType();
auto optionalMaskType = getAttnMaskType();
ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();
ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape();
ArrayRef<int64_t> valueShape = valueType.getShape();
ArrayRef<int64_t> maskShape =
optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (keyShape[i] != queryShape[i])
if (keyShape[i] != queryShape[i]) {
return op->emitOpError("query and key batch mismatch");
}
if (keyShape.back() != queryShape.back())
}
if (keyShape.back() != queryShape.back()) {
return op->emitOpError("query and key head dimension mismatch");
}
for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (valueShape[i] != queryShape[i]) {
return op->emitOpError("query and value batch dimension mismatch");
}
}
if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
return op->emitOpError("key and value sequence length dimension mismatch");
}
if (optionalMaskType) {
for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
if (maskShape[i] != queryShape[i]) {
return op->emitOpError("query and mask batch dimension mismatch");
}
}
if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
return op->emitOpError(
"mask sequence length and query sequence length mismatch");
}
if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
return op->emitOpError(
"mask sequence lengt and key sequence length mismatch");
}
}
return success();
}
@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value query = getQuery();
Value key = getKey();
Value value = getValue();
auto optionalMask = getAttnMask();
Value mask = optionalMask ? *optionalMask : Value();
Value output = getOutput();
auto queryType = cast<MemRefType>(query.getType());
auto keyType = cast<MemRefType>(key.getType());
auto valueType = cast<MemRefType>(value.getType());
auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
auto queryRank = queryType.getRank();
auto keyRank = keyType.getRank();
auto valueRank = valueType.getRank();
@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
b.getFloatAttr(elementType, 0.0));
Value negInfF = b.create<arith::ConstantOp>(
loc, elementType,
b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));
// TODO: This needs to be fixed, it assumes everything is dynamic however if
// any shapes are static the `memref.alloc` generated is illegal.
@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
/*transposed=*/true);
// weight = softmax(weight)
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value dim = weightDynSizes[weightRank - 1];
Value scaleFactor = b.create<math::SqrtOp>(
loc, b.create<arith::UIToFPOp>(
loc, elementType,
b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
queryDynSizes[queryRank - 1])));
// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});
// Apply mask to weights if mask is given
if (mask) {
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
if (maskType.getElementType().isInteger(1)) {
maskValue =
b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
}
Value maskedWeight =
b.create<arith::AddFOp>(loc, weightValue, maskValue);
b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
});
}
// calculate max(weight)
Value init = b.create<memref::LoadOp>(loc, weight,
SmallVector<Value>(weightRank, zero));
@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::SubFOp>(loc, x, globalMax);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});
// calculate exp(weight)
@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
SmallVector<Value> sumIVs(localIVs);
sumIVs.pop_back();
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
x = b.create<arith::DivFOp>(loc, x, sum);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
Value divResult = b.create<arith::DivFOp>(loc, x, sum);
// Set to 0 if sum is 0 (can occur during boolean mask / large negative
// QK)
Value isSumZero =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
Value result =
b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);
b.create<memref::StoreOp>(loc, result, weight, localIVs);
});
// output = weight @ value

View File

@ -34,7 +34,13 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
LINALG_XFAIL_SET = LINALG_XFAIL_SET | {
# Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7
# WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}
@ -498,7 +504,13 @@ FX_IMPORTER_XFAIL_SET = {
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"WeightNormInterfaceModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}
@ -780,6 +792,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"RsubInt0d_NumToTensor_Module_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
@ -2179,6 +2199,8 @@ MAKE_FX_TOSA_PASS_SET = (
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
MAKE_FX_TOSA_PASS_SET = MAKE_FX_TOSA_PASS_SET | {
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}
LTC_CRASHING_SET = {
@ -2932,6 +2954,12 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"):
"ElementwiseBitwiseAndStaticShapeModule_basic",
}
if torch_version_for_comparison() >= version.parse("2.5.0.dev"):
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
# ERROR: value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=+nan, max=+nan, mean=+nan) is not close to golden value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=-2.394, max=+2.454, mean=-0.02828)
"ScaledDotProductAttentionBoolMaskModule_basic",
}
if torch_version_for_comparison() < version.parse("2.4.0.dev"):
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
"AtenIntMM_basic",
@ -3009,8 +3037,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReduceAminmaxSingleDim_basic",
"ReduceAnyDimFloatModule_basic",
"RenormModuleFloat16_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionSameModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterAddStaticModule_basic",
"TensorsConcatComplex128FloatModule_basic",
"TensorsConcatComplex128IntModule_basic",
@ -4548,7 +4579,11 @@ ONNX_TOSA_XFAIL_SET = {
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"ScaledDotProductAttentionSameModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",

View File

@ -5107,9 +5107,9 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module):
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([1, 5, 5], torch.float32, True),
([1, 5, 5], torch.float32, True),
([1, 5, 5], torch.float32, True),
]
)
def forward(self, query, key, value):
@ -5124,6 +5124,58 @@ def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils):
module.forward(query, key, value)
class ScaledDotProductAttentionSameDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(query, key, value)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameDynamicModule())
def ScaledDotProductAttentionSameDynamicModule_basic(module, tu: TestUtils):
query = torch.randn(1, 5, 5, dtype=torch.float32)
key = torch.randn(1, 5, 5, dtype=torch.float32)
value = torch.randn(1, 5, 5, dtype=torch.float32)
module.forward(query, key, value)
class ScaledDotProductAttentionSameCausalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(
query, key, value, is_causal=True
)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameCausalModule())
def ScaledDotProductAttentionSameCausalModule_basic(module, tu: TestUtils):
query = torch.randn(1, 5, 5, dtype=torch.float32)
key = torch.randn(1, 5, 5, dtype=torch.float32)
value = torch.randn(1, 5, 5, dtype=torch.float32)
module.forward(query, key, value)
class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -5132,9 +5184,9 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
@annotate_args(
[
None,
([2, 3, 8, 4], torch.float32, True),
([2, 3, 16, 4], torch.float32, True),
([2, 3, 16, 4], torch.float32, True),
([2, 3, 8, 16], torch.float32, True),
([2, 3, 12, 16], torch.float32, True),
([2, 3, 12, 20], torch.float32, True),
]
)
def forward(self, query, key, value):
@ -5143,12 +5195,95 @@ class ScaledDotProductAttentionDifferentModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule())
def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 4, dtype=torch.float32)
key = torch.randn(2, 3, 16, 4, dtype=torch.float32)
value = torch.randn(2, 3, 16, 4, dtype=torch.float32)
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
module.forward(query, key, value)
class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3, 8, 16], torch.float32, True),
([2, 3, 12, 16], torch.float32, True),
([2, 3, 12, 20], torch.float32, True),
]
)
def forward(self, query, key, value):
return torch.ops.aten.scaled_dot_product_attention(
query, key, value, is_causal=True
)
@register_test_case(
module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
)
def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
module.forward(query, key, value)
class ScaledDotProductAttentionMaskModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3, 8, 16], torch.float32, True),
([2, 3, 12, 16], torch.float32, True),
([2, 3, 12, 20], torch.float32, True),
([2, 3, 8, 12], torch.float32, True),
]
)
def forward(self, query, key, value, mask):
return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionMaskModule())
def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
module.forward(query, key, value, mask)
class ScaledDotProductAttentionBoolMaskModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3, 8, 16], torch.float32, True),
([2, 3, 12, 16], torch.float32, True),
([2, 3, 12, 20], torch.float32, True),
([2, 3, 8, 12], torch.bool, True),
]
)
def forward(self, query, key, value, mask):
return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask)
@register_test_case(module_factory=lambda: ScaledDotProductAttentionBoolMaskModule())
def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils):
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) > 0.5
module.forward(query, key, value, mask)
# ==============================================================================