From a02e14e9712b41ec629d071994bdc19990f2c91a Mon Sep 17 00:00:00 2001
From: Wu Yuan
Date: Fri, 14 Jun 2024 10:52:09 +0800
Subject: [PATCH] [FxImporter] Add
 aten._scaled_dot_product_flash_attention_for_cpu to default decomposition
 table (#3456)

---
 projects/pt1/e2e_testing/xfail_sets.py     | 5 +----
 projects/pt1/python/torch_mlir/dynamo.py   | 1 +
 python/torch_mlir/extras/fx_decomp_util.py | 1 +
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 25d7df1fd..be9498a53 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -248,8 +248,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
     # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
-    # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
-    "ScaledDotProductAttentionDifferentModule_basic",
     # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
     "AtenEmbeddingBagStaticModule_basic",
     # Lowering not present for this case
@@ -731,7 +729,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
@@ -1978,6 +1975,7 @@ MAKE_FX_TOSA_PASS_SET = (
         "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
         "ViewSizeDimLedByCollapsedOnesModule_basic",
         "ViewSizeFromOtherTensor_basic",
+        "ScaledDotProductAttentionDifferentModule_basic",
     }
 ) - {
     ### Test failing in make_fx_tosa but not in tosa
@@ -3349,7 +3347,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py
index 2c339be98..1c202ed3a 100644
--- a/projects/pt1/python/torch_mlir/dynamo.py
+++ b/projects/pt1/python/torch_mlir/dynamo.py
@@ -65,6 +65,7 @@ def _get_decomposition_table():
         aten.sigmoid_backward,
         aten._native_batch_norm_legit,
         aten.squeeze,
+        aten._scaled_dot_product_flash_attention_for_cpu,
     ]
     # TODO: enable test once 2.1.0 is stable
     if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py
index 754fb4132..868dc26c6 100644
--- a/python/torch_mlir/extras/fx_decomp_util.py
+++ b/python/torch_mlir/extras/fx_decomp_util.py
@@ -48,6 +48,7 @@ DEFAULT_DECOMPOSITIONS = [
     torch.ops.aten.triu.default,
     torch.ops.aten.nan_to_num.default,
     torch.ops.aten.unbind,
+    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
 ]
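
Note for readers (not part of the patch): below is a minimal sketch of what the
new decomposition entry accomplishes, using plain PyTorch APIs rather than the
torch-mlir import path. It assumes a torch build recent enough to route
scaled_dot_product_attention on CPU to the flash-attention-for-cpu kernel and
to register a decomposition for that op; the Attention module, tensor shapes,
and variable names are illustrative, not taken from the patch.

    # Sketch only: shows the op this patch teaches the importer to decompose.
    import torch
    from torch._decomp import get_decompositions

    class Attention(torch.nn.Module):
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    q = k = v = torch.randn(2, 4, 8, 16)
    exported = torch.export.export(Attention(), (q, k, v))

    # On CPU the exported graph contains
    # aten._scaled_dot_product_flash_attention_for_cpu, which torch-mlir cannot
    # lower directly (the "Unregistered operation" failures xfailed above).
    sdpa_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default
    print(any(n.target is sdpa_cpu for n in exported.graph.nodes))  # expect True

    # Running the registered decomposition over the op, as the updated tables
    # above arrange for, rewrites it into matmul/softmax-style ops that the
    # torch-mlir backends already lower.
    decomposed = exported.run_decompositions(
        get_decompositions(
            [torch.ops.aten._scaled_dot_product_flash_attention_for_cpu]
        )
    )
    print(any(n.target is sdpa_cpu for n in decomposed.graph.nodes))  # expect False

Because get_decomposition_table() in fx_decomp_util.py feeds the whole
DEFAULT_DECOMPOSITIONS list through torch._decomp.get_decompositions in the
same way, the FxImporter-based test paths pick this up automatically, which is
why the corresponding xfail entries can be dropped.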