mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Enable Stablehlo backend with arith dialect (#2139)
parent
4216c7d622
commit
5223f990df
|
@ -301,6 +301,41 @@ TORCHDYNAMO_CRASHING_SET = {
|
|||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
"AllBoolFalseModule_basic",
|
||||
"AllBoolTrueModule_basic",
|
||||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
"AtenSubFloatModule_basic",
|
||||
"BoolFloatConstantModule_basic",
|
||||
"BoolIntConstantModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"IntFloatModule_basic",
|
||||
"IsFloatingPointFloat_True",
|
||||
"IsFloatingPointInt_False",
|
||||
"LenStrModule_basic",
|
||||
"MeanDimAllReduceKeepdimModule_basic",
|
||||
"MeanDimAllReduceModule_basic",
|
||||
"MeanDimDtypeModule_basic",
|
||||
"MeanDimKeepdimModule_basic",
|
||||
"MeanDimModule_basic",
|
||||
"MeanDimNegativeModule_basic",
|
||||
"NumelZeroRankModule_basic",
|
||||
"PowIntFloatModule_basic",
|
||||
"PrimMaxIntModule_basic",
|
||||
"PrimMinIntModule_basic",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
"SqrtIntConstantModule_basic",
|
||||
"StdBiasedModule_basic",
|
||||
"StdDimBiasedModule_basic",
|
||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||
"VarBiasedModule_basic",
|
||||
"VarDimBiasedModule_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
"VarMeanDimBiasedModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"MaskedFillScalarIntValueStaticModule_basic",
|
||||
"MaskedFillScalarFloatValueStaticModule_basic",
|
||||
|
|
|
@ -128,6 +128,8 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|||
// Generate Stablehlo ops.
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
|
||||
options.enableStaticShape, options.enableI32Index));
|
||||
// Lowering remained ops to Arith
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
@ -137,6 +139,7 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|||
// Finish the type conversion from `torch` types to the types of the
|
||||
// StableHLO backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
|
||||
|
|
Loading…
Reference in New Issue