[Stablehlo] add stablehlo-aggressive-simplification in e2e test (#3109)

* so that more stablehlo e2e testcases would pass.
pull/3110/head
Yuanqiang Liu 2024-04-07 10:48:11 +08:00 committed by GitHub
parent 9d9a05366e
commit 0a00f38a7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 54 additions and 1 deletions

View File

@ -32,7 +32,7 @@ set(LinkedLibs
)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs StablehloLinalgTransforms)
list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses)
endif()
if(TORCH_MLIR_ENABLE_REFBACKEND)

View File

@ -31,6 +31,7 @@
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "stablehlo/conversions/linalg/transforms/Passes.h"
#include "stablehlo/transforms/Passes.h"
#endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
@ -58,6 +59,8 @@ void mlir::torch::registerAllPasses() {
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
mlir::stablehlo::registerStablehloAggressiveSimplificationPass();
mlir::stablehlo::registerStablehloRefineShapesPass();
#endif
#ifdef TORCH_MLIR_ENABLE_REFBACKEND

View File

@ -839,6 +839,54 @@ STABLEHLO_PASS_SET = {
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UniformStaticShapeModule_basic",
"ArangeStartOutViewModule_basic",
"ConvolutionBackwardModule2DStrided_basic",
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"FlattenStaticModule_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"NativeGroupNormModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeAsModule_basic",
"ReshapeExpandModule_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"UniformNoCorrelationModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"UnsafeViewExpandModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandOnesModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
}
STABLEHLO_CRASHING_SET = {

View File

@ -18,6 +18,8 @@ __all__ = [
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
"canonicalize",
"func.func(stablehlo-aggressive-simplification)",
"stablehlo-legalize-to-linalg",
"canonicalize"
])