mirror of https://github.com/llvm/torch-mlir
[Stablehlo] add stablehlo-canonicalize-dynamism when lowering (#3097)
so that many stablehlo e2e testcases could passpull/3100/head
parent
d1f770c620
commit
6cbb2f7ae0
|
@ -142,6 +142,9 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|||
// Lowering Chlo ops to Stablehlo
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
stablehlo::createChloLegalizeToStablehloPass());
|
||||
// Canonicalize Stablehlo dynamic ops to static ops
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
stablehlo::createStablehloCanonicalizeDynamismPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
||||
// Lowering remained ops to Arith
|
||||
|
|
|
@ -782,6 +782,63 @@ STABLEHLO_PASS_SET = {
|
|||
"ZerosModuleFloat3D_basic",
|
||||
"ZerosModuleInt2D_basic",
|
||||
"ZerosModuleInt3D_basic",
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
"AtenEyeMModuleCPUDevice_basic",
|
||||
"AtenEyeMModuleDefaultDtype_basic",
|
||||
"AtenEyeMModuleFalsePinMemory_basic",
|
||||
"AtenEyeMModuleFloat2D_basic",
|
||||
"AtenEyeMModuleInt2D_basic",
|
||||
"AtenEyeModuleCPUDevice_basic",
|
||||
"AtenEyeModuleDefaultDtype_basic",
|
||||
"AtenEyeModuleFalsePinMemory_basic",
|
||||
"AtenEyeModuleFloat2D_basic",
|
||||
"AtenEyeModuleInt2D_basic",
|
||||
"AtenInstanceNormModule_basic",
|
||||
"AtenLinalgCrossBroadcast_basic",
|
||||
"AtenLinalgCrossCustomDim_basic",
|
||||
"AtenLinalgCrossFloat_basic",
|
||||
"AtenLinalgCrossInt_basic",
|
||||
"AtenLinalgCrossNegativeDim_basic",
|
||||
"BucketizeTensorStaticFloatModule_basic",
|
||||
"BucketizeTensorStaticModule_basic",
|
||||
"DropoutTrainStaticShapeModule_basic",
|
||||
"ElementwiseWhereScalarOtherStaticModule_basic",
|
||||
"ElementwiseWhereScalarSelfStaticModule_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleF16_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
"EmbeddingModuleI64_basic",
|
||||
"GatherStaticModule_basic",
|
||||
"IndexSelectDynamicIndexSizeModule_basic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"LayerNormLastDimModule_basic",
|
||||
"LayerNormModule_basic",
|
||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
"MaxPool2dWithIndicesStaticModule_basic",
|
||||
"MeanDimAllReduceKeepdimModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
"NativeLayerNormModule_basic",
|
||||
"NormalizeModule_basic",
|
||||
"PrimsSqueezeModule_basic",
|
||||
"RandModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SqueezeDimModule_static",
|
||||
"SqueezeModule_static",
|
||||
"TriuBroadcastModule_basic",
|
||||
"TriuModule_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UniformStaticShapeModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
|
|
Loading…
Reference in New Issue