[Stablehlo] fix aten.arange's lowering to stablehlo (#3138)

* promote the operands to f64 before dividing, to avoid floor division on i64
* refactor the torch-to-stablehlo pipeline
Yuanqiang Liu 2024-04-11 15:55:56 +08:00 committed by GitHub
parent aa5e150313
commit 88533b1968
3 changed files with 22 additions and 16 deletions
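
The first bullet is easiest to see with concrete numbers: the length of torch.arange(start, end, step) is ceil((end - start) / step), but stablehlo.divide on i64 operands discards the remainder, so the old lowering could compute a length one element short. A minimal sketch of the failure mode (values chosen for illustration):

import math

# torch.arange(0, 10, 3) -> [0, 3, 6, 9], i.e. 4 elements
start, end, step = 0, 10, 3

floor_len = (end - start) // step           # 3: integer division drops the
                                            # remainder and loses an element
ceil_len = math.ceil((end - start) / step)  # 4: promote, divide, then ceil

assert (floor_len, ceil_len) == (3, 4)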


@@ -1492,15 +1492,17 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   // Get length of the 1-d output tensor
   Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
-  Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
-  Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
-      loc, RankedTensorType::get({1}, dtype), divOut);
-  if (dtype.isa<mlir::FloatType>()) {
-    resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
-    resultLength = rewriter.create<stablehlo::ConvertOp>(
-        loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
-  }
+  // promote div to f64
+  Type divType = RankedTensorType::get({}, rewriter.getF64Type());
+  Value divOut = rewriter.create<stablehlo::DivOp>(
+      loc, rewriter.create<stablehlo::ConvertOp>(loc, divType, subOut),
+      rewriter.create<stablehlo::ConvertOp>(loc, divType, step));
+  // ceil to i64
+  Value resultLength = rewriter.create<stablehlo::ConvertOp>(
+      loc, RankedTensorType::get({}, rewriter.getI64Type()),
+      rewriter.create<stablehlo::CeilOp>(loc, divOut));
+  resultLength = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
 
   Value window =
       rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
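
For reference, a rough NumPy analogue of the op sequence the new code emits; the final scale-and-offset step is not part of this hunk and is assumed from the surrounding lowering:

import numpy as np

def arange_length(start, end, step):
    # Mirrors the emitted ops: subtract, promote to f64, divide,
    # ceil, convert back to i64, then reshape to a rank-1 tensor.
    sub = np.asarray(end - start, dtype=np.int64)    # stablehlo.subtract (i64)
    div = sub.astype(np.float64) / np.float64(step)  # stablehlo.convert + divide
    length = np.ceil(div).astype(np.int64)           # stablehlo.ceil + convert
    return length.reshape(1)                         # stablehlo.reshape to {1}

# stablehlo.dynamic_iota then materializes this many elements, which the
# surrounding lowering turns into start + iota * step (assumed, not shown here).
assert arange_length(0, 10, 3)[0] == 4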


@@ -142,11 +142,6 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
   // Lowering Chlo ops to Stablehlo
   pm.addNestedPass<func::FuncOp>(
       stablehlo::createChloLegalizeToStablehloPass());
 
-  // Canonicalize Stablehlo dynamic ops to static ops
-  pm.addNestedPass<func::FuncOp>(
-      stablehlo::createStablehloCanonicalizeDynamismPass());
-  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-
   // Lowering remained ops to Arith
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
@@ -162,7 +157,17 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
   pm.addNestedPass<func::FuncOp>(
       TorchConversion::createFinalizingBackendTypeConversionPass());
 
-  // Verify that we have lowered to Stablehlo and Chlo ops.
+  // Verify that we have lowered to Stablehlo ops.
   pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
+  // Canonicalize Stablehlo dynamic ops to static ops
+  pm.addNestedPass<func::FuncOp>(
+      stablehlo::createStablehloCanonicalizeDynamismPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addPass(stablehlo::createStablehloRefineShapesPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  pm.addNestedPass<func::FuncOp>(
+      stablehlo::createStablehloCanonicalizeDynamismPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
 }
 #endif


@@ -18,7 +18,6 @@ __all__ = [
 # The pipeline of func.func passes that lower the STABLEHLO backend contract to the
 # Linalg-on-Tensors backend contract accepted by RefBackend.
 STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
-    "canonicalize",
     "func.func(stablehlo-aggressive-simplification)",
     "stablehlo-legalize-to-linalg",
     "canonicalize"