[Stablehlo] Enable Stablehlo backend with arith dialect (#2139)

pull/1499/head
Yuanqiang Liu 2023-05-26 22:57:57 +08:00 committed by GitHub
parent 4216c7d622
commit 5223f990df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 0 deletions

View File

@ -301,6 +301,41 @@ TORCHDYNAMO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenSubFloatModule_basic",
"BoolFloatConstantModule_basic",
"BoolIntConstantModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"IntFloatModule_basic",
"IsFloatingPointFloat_True",
"IsFloatingPointInt_False",
"LenStrModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimDtypeModule_basic",
"MeanDimKeepdimModule_basic",
"MeanDimModule_basic",
"MeanDimNegativeModule_basic",
"NumelZeroRankModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SqrtIntConstantModule_basic",
"StdBiasedModule_basic",
"StdDimBiasedModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"VarBiasedModule_basic",
"VarDimBiasedModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanDimBiasedModule_basic",
"ConstantBoolParameterModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",

View File

@ -128,6 +128,8 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
// Generate Stablehlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index));
// Lowering remained ops to Arith
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
@ -137,6 +139,7 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
// Finish the type conversion from `torch` types to the types of the
// StableHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());