mirror of https://github.com/llvm/torch-mlir
[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456)
parent
919b599ebe
commit
a02e14e971
|
@@ -248,8 +248,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
|
||||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
# AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
# AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
# Lowering not present for this case
|
||||
|
@@ -731,7 +729,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScatterReduceFloatMaxModule",
|
||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||
"ScatterReduceFloatMeanModule",
|
||||
|
@@ -1978,6 +1975,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
|
||||
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
}
|
||||
) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
|
@@ -3349,7 +3347,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScalarImplicitIntModule_basic",
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScatterReduceFloatMaxModule",
|
||||
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||
"ScatterReduceFloatMeanModule",
|
||||
|
|
|
@@ -65,6 +65,7 @@ def _get_decomposition_table():
|
|||
aten.sigmoid_backward,
|
||||
aten._native_batch_norm_legit,
|
||||
aten.squeeze,
|
||||
aten._scaled_dot_product_flash_attention_for_cpu,
|
||||
]
|
||||
# TODO: enable test once 2.1.0 is stable
|
||||
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
|
||||
|
|
|
@@ -48,6 +48,7 @@ DEFAULT_DECOMPOSITIONS = [
|
|||
torch.ops.aten.triu.default,
|
||||
torch.ops.aten.nan_to_num.default,
|
||||
torch.ops.aten.unbind,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue