torch-mlir/lib/Dialect/Torch/Transforms/Passes.cpp

160 lines
8.2 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
void mlir::torch::registerTorchPasses() {
mlir::torch::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-function-to-torch-backend-pipeline",
"Pipeline lowering a Torch function to Torch backend form.",
mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-simplification-pipeline",
"Pipeline simplifying computations in the program.",
mlir::torch::Torch::createTorchSimplificationPipeline);
mlir::PassPipelineRegistration<>(
"torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.",
mlir::torch::Torch::createTorchShapeRefinementPipeline);
}
// Populates `pm` with the pipeline that lowers TorchScript object graph IR to
// the Torch backend form: prune unused symbols, globalize, inline, and then
// continue with the function-level pipeline using the same `options`.
void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
  // Importing TorchScript IR pulls in the entire "compilation unit", which may
  // hold many functions unrelated to the current program. Those break
  // torch-globalization-pipeline -- for example, they may reference types that
  // were not imported as part of the root `torch.nn.Module`. They end up as
  // unreferenced private functions, so symbol-dce removes them cleanly.
  pm.addPass(createSymbolDCEPass());

  // Globalize the program. Everything downstream in the compiler assumes a
  // globalized program, which makes analyses and transforms significantly
  // easier to write.
  pm.addPass(createPrepareForGlobalizeObjectGraphPass());
  pm.addPass(createGlobalizeObjectGraphPass());

  // "Lower" `torch.global_slot` ops by deleting the unused ones. This is
  // required for now because backends have no lowering path for them. Torch
  // usually inserts a few unused global slots, so this affects every module,
  // even those without any explicit slots.
  // TODO: Support global slots in backends.
  pm.addPass(createSymbolDCEPass());

  // Shape inference is currently not powerful enough to handle calls, so
  // inline everything.
  // TODO: Improve shape inference.
  pm.addPass(createInlinerPass());

  // The remaining steps are those of the function-level pipeline.
  createTorchFunctionToTorchBackendPipeline(pm, options);
}
// Populates `pm` with the pipeline that lowers a Torch function to the Torch
// backend form.
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
  // Fold in user annotations and strip Python-isms from signatures.
  pm.addPass(createAdjustCallingConventionsPass());

  // The bulk of the lowering to the backend contract happens in this pass,
  // configured from the pipeline options. See the pass documentation for more
  // information.
  pm.addPass(createLowerToBackendContractPass(
      options.maxIterations, options.decompose, options.backendLegalOps));
}
// A simplification pipeline to establish the invariants of the backend
// contract (see `satisfiedBackendContract` in `LowerToBackendContract`).
//
// We structure this so that a single run of this pipeline is enough for
// most models, but it is possible for it to take multiple runs to fully
// clean things up when there are cyclic dependencies between certain
// simplifications, such as a decomposition relying on shape refinement which
// depends on another decomposition.
//
// Although technically this pipeline is an implementation detail of
// LowerToBackendContract, we expose it here to help debugging.
//
// LowerToBackendContract will run this pipeline as many times as necessary, but
// in general, it is costly to re-run this pipeline, since all the passes do
// O(module size) work. We want the number of iterations of this pipeline
// to be bounded by meaningful "always in practice small" program properties,
// such as loop nesting depth, number of sequentially dependent steps of
// constant global slots proving that other global slots are dead, etc.
//
// It is generally always possible to construct a pathological input that will
// exceed the number of iterations. If we do find practical cases with
// O(module size) number of iterations of this simplification pipeline, then
// we may need to adjust the approach, such as to do some of the transformations
// together at finer granularity.
void mlir::torch::Torch::createTorchSimplificationPipeline(
    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
  // Per-function canonicalization is interleaved between most of the steps
  // below, so give that step a name.
  const auto canonicalizeFunctions = [&pm] {
    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  };

  // General cleanup.
  canonicalizeFunctions();

  // Inline global slots, which exposes many simplification opportunities
  // coming from constant hyperparameters, weights, etc.
  pm.addPass(createInlineGlobalSlotsPass());
  // Once all the global slots are proven to be gone, the module initializer
  // can be erased.
  pm.addPass(createEraseModuleInitializerPass());
  // Clean up again to avoid needing to go back around the fixed-point
  // iteration.
  canonicalizeFunctions();

  // Reduce variants of ops to a smaller set of primitives.
  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
  canonicalizeFunctions();

  // Remove dead global slots.
  pm.addPass(createSymbolDCEPass());

  // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
  pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
  // Update the return op to return value tensors.
  pm.addPass(Torch::createRefinePublicReturnPass());
  canonicalizeFunctions();

  // Shape refinement has to come before RefineTypes (which primarily does
  // dtype inference), because Torch type promotion rules actually depend on
  // the shape of the operand.
  createTorchShapeRefinementPipeline(pm);
  // Refine types in the program, which mainly means inferring dtypes of ops.
  pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
  // Propagate to ABI return types the shape/dtype information discovered by
  // the previous pass. Doing this is ABI-compatible for our backends.
  pm.addPass(Torch::createRefinePublicReturnPass());
  // With the information obtained from RefineTypes, this can fold away some
  // branches before maximize value semantics runs, which only works with
  // basic blocks.
  canonicalizeFunctions();

  // Decomposition of complex ops is optional; nothing else follows it.
  if (!options.decompose)
    return;
  pm.addNestedPass<func::FuncOp>(
      Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
  canonicalizeFunctions();
}
// Populates `pm` with the shape refinement pipeline: reify shape
// calculations, inline, simplify them, and finally drop them.
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
  // The simplification step runs both before and after CSE.
  const auto simplifyShapeCalculations = [&pm] {
    pm.addNestedPass<func::FuncOp>(
        Torch::createSimplifyShapeCalculationsPass());
  };

  // Reify the shape functions for each op that is present in the shape
  // library.
  pm.addPass(Torch::createReifyShapeCalculationsPass());

  // Inline the shape functions to enable analysis and transformation.
  // TODO: Only inline shape functions (this will currently inline everything).
  pm.addPass(createInlinerPass());

  // Try to simplify the shape calculations. This is unfortunately an
  // "optimize as hard as possible" kind of thing, so it is inherently somewhat
  // brittle. The idea is to keep strengthening what is done here to support
  // the shape library; thankfully, arbitrary programs need not be supported.
  simplifyShapeCalculations();

  // Run CSE, then see if we can simplify further.
  pm.addNestedPass<func::FuncOp>(createCSEPass());
  simplifyShapeCalculations();

  // Drop shape calculations, leaving behind the shape-refined program.
  pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
}