mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix aten.arange's lowering to stablehlo (#3138)
* promote to f64 to do division, avoid division on i64 (floor div) * refactor torch-to-stablehlo-pipeline (pull/3141/head)
parent
aa5e150313
commit
88533b1968
|
@ -1492,15 +1492,17 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
|
||||
// Get length of the 1-d output tensor
|
||||
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
|
||||
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
|
||||
|
||||
Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
|
||||
loc, RankedTensorType::get({1}, dtype), divOut);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
|
||||
resultLength = rewriter.create<stablehlo::ConvertOp>(
|
||||
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
||||
}
|
||||
// promote div to f64
|
||||
Type divType = RankedTensorType::get({}, rewriter.getF64Type());
|
||||
Value divOut = rewriter.create<stablehlo::DivOp>(
|
||||
loc, rewriter.create<stablehlo::ConvertOp>(loc, divType, subOut),
|
||||
rewriter.create<stablehlo::ConvertOp>(loc, divType, step));
|
||||
// ceil to i64
|
||||
Value resultLength = rewriter.create<stablehlo::ConvertOp>(
|
||||
loc, RankedTensorType::get({}, rewriter.getI64Type()),
|
||||
rewriter.create<stablehlo::CeilOp>(loc, divOut));
|
||||
resultLength = rewriter.create<stablehlo::ReshapeOp>(
|
||||
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
||||
|
||||
Value window =
|
||||
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||
|
|
|
@ -142,11 +142,6 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|||
// Lowering Chlo ops to Stablehlo
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
stablehlo::createChloLegalizeToStablehloPass());
|
||||
// Canonicalize Stablehlo dynamic ops to static ops
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
stablehlo::createStablehloCanonicalizeDynamismPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
||||
// Lowering remained ops to Arith
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
|
||||
|
||||
|
@ -162,7 +157,17 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
|
||||
// Verify that we have lowered to Stablehlo and Chlo ops.
|
||||
// Verify that we have lowered to Stablehlo ops.
|
||||
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
||||
|
||||
// Canonicalize Stablehlo dynamic ops to static ops
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
stablehlo::createStablehloCanonicalizeDynamismPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addPass(stablehlo::createStablehloRefineShapesPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
stablehlo::createStablehloCanonicalizeDynamismPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -18,7 +18,6 @@ __all__ = [
|
|||
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
|
||||
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
||||
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
|
||||
"canonicalize",
|
||||
"func.func(stablehlo-aggressive-simplification)",
|
||||
"stablehlo-legalize-to-linalg",
|
||||
"canonicalize"
|
||||
|
|
Loading…
Reference in New Issue