diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 2e2c108c9..c0b622005 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -32,7 +32,7 @@ set(LinkedLibs ) if(TORCH_MLIR_ENABLE_STABLEHLO) -list(APPEND LinkedLibs StablehloLinalgTransforms) +list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses) endif() if(TORCH_MLIR_ENABLE_REFBACKEND) diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index ce29176c9..e8f9622c3 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -31,6 +31,7 @@ #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "stablehlo/conversions/linalg/transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" #endif void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { @@ -58,6 +59,8 @@ void mlir::torch::registerAllPasses() { #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); + mlir::stablehlo::registerStablehloAggressiveSimplificationPass(); + mlir::stablehlo::registerStablehloRefineShapesPass(); #endif #ifdef TORCH_MLIR_ENABLE_REFBACKEND diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 44594e9b6..6efd56ce2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -839,6 +839,54 @@ STABLEHLO_PASS_SET = { "UnbindIntGetItem_Module_basic", "UnbindIntListUnpack_Module_basic", "UniformStaticShapeModule_basic", + "ArangeStartOutViewModule_basic", + "ConvolutionBackwardModule2DStrided_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", + "EinsumStaticWithEllipsisSlicingModule_basic", + "FlattenStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "NativeGroupNormModule_basic", + "RepeatModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeAsModule_basic", + "ReshapeExpandModule_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + "UniformNoCorrelationModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "UnsafeViewExpandModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandOnesModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", } STABLEHLO_CRASHING_SET = { diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 4c511b9c5..4899549a8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -18,6 +18,8 @@ __all__ = [ # The pipeline of func.func passes that lower the STABLEHLO backend contract to the # Linalg-on-Tensors backend contract accepted by RefBackend. STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ + "canonicalize", + "func.func(stablehlo-aggressive-simplification)", "stablehlo-legalize-to-linalg", "canonicalize" ])