[Stablehlo] add stablehlo-canonicalize-dynamism when lowering (#3097)

so that many stablehlo e2e testcases could pass
pull/3100/head
Yuanqiang Liu 2024-04-02 22:47:24 +08:00 committed by GitHub
parent d1f770c620
commit 6cbb2f7ae0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 0 deletions

View File

@ -142,6 +142,9 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
// Lowering Chlo ops to Stablehlo
pm.addNestedPass<func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
// Canonicalize Stablehlo dynamic ops to static ops
pm.addNestedPass<func::FuncOp>(
stablehlo::createStablehloCanonicalizeDynamismPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Lowering remained ops to Arith

View File

@ -782,6 +782,63 @@ STABLEHLO_PASS_SET = {
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"AtenEmbeddingBagStaticModule_basic",
"AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic",
"AtenEyeMModuleFalsePinMemory_basic",
"AtenEyeMModuleFloat2D_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleCPUDevice_basic",
"AtenEyeModuleDefaultDtype_basic",
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenInstanceNormModule_basic",
"AtenLinalgCrossBroadcast_basic",
"AtenLinalgCrossCustomDim_basic",
"AtenLinalgCrossFloat_basic",
"AtenLinalgCrossInt_basic",
"AtenLinalgCrossNegativeDim_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"DropoutTrainStaticShapeModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleF16_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"GatherStaticModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"NormalizeModule_basic",
"PrimsSqueezeModule_basic",
"RandModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertStaticModule_basic",
"SqueezeDimModule_static",
"SqueezeModule_static",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UniformStaticShapeModule_basic",
}
STABLEHLO_CRASHING_SET = {