Mirror of https://github.com/llvm/torch-mlir

[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456)

parent 919b599ebe
commit a02e14e971
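For context, the op being added is what PyTorch's CPU flash-attention path traces to. A minimal sketch of the failure mode this commit addresses, assuming a plain torch.export workflow (module name and shapes are illustrative, not from this repo):

    import torch
    import torch.nn.functional as F

    class Sdpa(torch.nn.Module):
        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v)

    # On CPU, torch.export can lower F.scaled_dot_product_attention to
    # aten._scaled_dot_product_flash_attention_for_cpu, which the importer
    # previously rejected as an unregistered operation.
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    prog = torch.export.export(Sdpa(), (q, k, v))
    print(prog.graph)  # may show _scaled_dot_product_flash_attention_for_cpu

With the op in the default decomposition tables, it is rewritten into ops the importer already registers, so the affected test moves out of the xfail sets below.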
@@ -248,8 +248,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
     # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
-    # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
-    "ScaledDotProductAttentionDifferentModule_basic",
     # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
     "AtenEmbeddingBagStaticModule_basic",
     # Lowering not present for this case
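The test removed from the xfail set above exercises attention where the query and key/value sequence lengths differ. Roughly what it covers (the exact annotations and shapes in the e2e suite may differ):

    import torch

    class ScaledDotProductAttentionDifferent(torch.nn.Module):
        def forward(self, query, key, value):
            return torch.ops.aten.scaled_dot_product_attention(query, key, value)

    query = torch.randn(3, 2, 8, 4)   # (batch, heads, seq_q, head_dim)
    key = torch.randn(3, 2, 16, 4)    # seq_kv differs from seq_q
    value = torch.randn(3, 2, 16, 4)
    out = ScaledDotProductAttentionDifferent()(query, key, value)
    print(out.shape)  # torch.Size([3, 2, 8, 4])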
@@ -731,7 +729,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
@@ -1978,6 +1975,7 @@ MAKE_FX_TOSA_PASS_SET = (
         "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
         "ViewSizeDimLedByCollapsedOnesModule_basic",
         "ViewSizeFromOtherTensor_basic",
+        "ScaledDotProductAttentionDifferentModule_basic",
     }
 ) - {
     ### Test failing in make_fx_tosa but not in tosa
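The `}` and `) - {` context lines above are the tail of a set expression; a condensed, runnable sketch of the pattern (placeholder members are illustrative, not taken from the real file):

    # Pass set = TOSA pass set, plus tests passing only under make_fx,
    # minus tests failing only under make_fx.
    TOSA_PASS_SET = {"ExampleModule_basic"}  # illustrative stand-in
    MAKE_FX_TOSA_PASS_SET = (
        TOSA_PASS_SET
        | {
            "ScaledDotProductAttentionDifferentModule_basic",
        }
    ) - {
        "SomeMakeFxOnlyFailure_basic",  # illustrative placeholder
    }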
@@ -3349,7 +3347,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
@@ -65,6 +65,7 @@ def _get_decomposition_table():
         aten.sigmoid_backward,
         aten._native_batch_norm_legit,
         aten.squeeze,
+        aten._scaled_dot_product_flash_attention_for_cpu,
     ]
     # TODO: enable test once 2.1.0 is stable
     if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
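This hunk extends the table used on the TorchDynamo path. A sketch of how such a table is consumed, assuming the surrounding code follows the usual torch._decomp pattern (the function f and the shapes are illustrative):

    import torch
    from torch._decomp import get_decompositions
    from torch.fx.experimental.proxy_tensor import make_fx

    aten = torch.ops.aten
    decomp_table = get_decompositions(
        [aten._scaled_dot_product_flash_attention_for_cpu]
    )

    def f(q, k, v):
        # The CPU flash-attention op returns (output, logsumexp); keep output.
        return aten._scaled_dot_product_flash_attention_for_cpu(q, k, v)[0]

    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    # make_fx inlines the registered decomposition while tracing, so the
    # resulting graph contains only ops the Torch backend registers.
    g = make_fx(f, decomposition_table=decomp_table)(q, k, v)
    print(g.graph)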
@@ -48,6 +48,7 @@ DEFAULT_DECOMPOSITIONS = [
     torch.ops.aten.triu.default,
     torch.ops.aten.nan_to_num.default,
     torch.ops.aten.unbind,
+    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
 ]
 
 
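This final hunk adds the op for the FxImporter path. A sketch of how DEFAULT_DECOMPOSITIONS takes effect there, assuming the exported program is rewritten with run_decompositions before import (simplified wiring; only the decomposition step is the point):

    import torch
    import torch.nn.functional as F
    from torch._decomp import get_decompositions

    class Sdpa(torch.nn.Module):
        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v)

    args = tuple(torch.randn(2, 4, 8, 16) for _ in range(3))
    prog = torch.export.export(Sdpa(), args)
    table = get_decompositions(
        [torch.ops.aten._scaled_dot_product_flash_attention_for_cpu]
    )
    # After run_decompositions, the flash-attention-for-cpu op no longer
    # appears in the graph handed to the importer.
    prog = prog.run_decompositions(table)
    print(prog.graph)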