[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456)

pull/3461/head
Wu Yuan 2024-06-14 10:52:09 +08:00 committed by GitHub
parent 919b599ebe
commit a02e14e971
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 4 deletions

View File

@ -248,8 +248,6 @@ TORCHDYNAMO_XFAIL_SET = {
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
"ScaledDotProductAttentionDifferentModule_basic",
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
"AtenEmbeddingBagStaticModule_basic",
# Lowering not present for this case
@ -731,7 +729,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"RsubInt0d_NumToTensor_Module_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
@ -1978,6 +1975,7 @@ MAKE_FX_TOSA_PASS_SET = (
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"ScaledDotProductAttentionDifferentModule_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
@ -3349,7 +3347,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",

View File

@ -65,6 +65,7 @@ def _get_decomposition_table():
aten.sigmoid_backward,
aten._native_batch_norm_legit,
aten.squeeze,
aten._scaled_dot_product_flash_attention_for_cpu,
]
# TODO: enable test once 2.1.0 is stable
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):

View File

@ -48,6 +48,7 @@ DEFAULT_DECOMPOSITIONS = [
torch.ops.aten.triu.default,
torch.ops.aten.nan_to_num.default,
torch.ops.aten.unbind,
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
]