diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.h b/include/npcomp/Dialect/Torch/Transforms/Passes.h
index f2d6ba6c6..4b6db2b29 100644
--- a/include/npcomp/Dialect/Torch/Transforms/Passes.h
+++ b/include/npcomp/Dialect/Torch/Transforms/Passes.h
@@ -22,10 +22,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
 std::unique_ptr<OperationPass<ModuleOp>>
 createPrepareForGlobalizeObjectGraphPass();
 
-/// Creates a pipeline that "globalizes" the given program.
-/// See the documentation on torch-globalize-object-graph for more details.
-void createGlobalizePipeline(OpPassManager &pm);
-
 /// Creates a pipeline that lowers the object graph IR that is produced by
 /// TorchScript import into the form expected by npcomp-verify-backend-contract.
 void createLowerObjectGraphPipeline(OpPassManager &pm);
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index 1d8822acc..ce824efc4 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -28,9 +28,6 @@ namespace {
 
 void mlir::NPCOMP::registerTorchPasses() {
   ::registerPasses();
-  mlir::PassPipelineRegistration<>(
-      "torch-globalize-pipeline", "Globalization pipeline.",
-      mlir::NPCOMP::Torch::createGlobalizePipeline);
   mlir::PassPipelineRegistration<>(
       "torchscript-to-npcomp-backend-pipeline",
      "Pipeline lowering torch object graph to npcomp backend format.",
@@ -41,82 +38,78 @@ void mlir::NPCOMP::registerTorchPasses() {
       mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline);
 }
 
-void mlir::NPCOMP::Torch::createGlobalizePipeline(OpPassManager &pm) {
+void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(OpPassManager &pm) {
+  // When we import TorchScript IR, we import their entire "compilation unit",
+  // which can contain numerous functions unrelated to the current program,
+  // which breaks torch-globalization-pipeline; for example, there can be
+  // random functions referencing types that haven't been imported
+  // as part of the root `torch.nn.Module` we imported. Those will
+  // be unreferenced private functions which symbol-dce will clean up nicely.
+  pm.addPass(createSymbolDCEPass());
+  // Globalize the program. The rest of the compiler assumes a globalized
+  // program, which makes all analyses and transforms significantly easier
+  // to write.
   pm.addPass(createPrepareForGlobalizeObjectGraphPass());
   pm.addPass(createGlobalizeObjectGraphPass());
+  // "lower" `torch.global_slot` ops by deleting them if unused, which we
+  // currently require because we don't have a lowering path for backends to
+  // handle them.
+  // Torch usually inserts a few unused global slots so this ends up hitting
+  // every single module even if it doesn't have any explicit slots.
+  // TODO: Support global slots in backends.
+  pm.addPass(createSymbolDCEPass());
+  // Currently, our shape inference is not powerful enough to deal with
+  // calls, so inline everything.
+  // TODO: Improve shape inference.
+  pm.addPass(createInlinerPass());
+  // Incorporate user annotations and remove signature Python-isms.
+  pm.addPass(createAdjustCallingConventionsPass());
+
+  createLowerToNpcompBackendPipeline(pm);
 }
 
-void mlir::NPCOMP::Torch::createLowerObjectGraphPipeline(OpPassManager &pm) {
-  // When we import TorchScript IR, we import their entire "compilation unit",
-  // which can contain numerous functions unrelated to the current program,
-  // which breaks torch-globalization-pipeline; for example, there can be
-  // random functions referencing types that haven't been imported
-  // as part of the root `torch.nn.Module` we imported. Those will
-  // be unreferenced private functions which symbol-dce will clean up nicely.
-  pm.addPass(createSymbolDCEPass());
-  // Globalize the program. The rest of the compiler assumes a globalized
-  // program, which makes all analyses and transforms significantly easier
-  // to write.
-  pm.addPass(createPrepareForGlobalizeObjectGraphPass());
-  pm.addPass(createGlobalizeObjectGraphPass());
-  // "lower" `torch.global_slot` ops by deleting them if unused, which we
-  // currently require because we don't have a lowering path for backends to
-  // handle them.
-  // Torch usually inserts a few unused global slots so this ends up hitting
-  // every single module even if it doesn't have any explicit slots.
-  // TODO: Support global slots in backends.
-  pm.addPass(createSymbolDCEPass());
-  // Currently, our shape inference is not powerful enough to deal with
-  // calls, so inline everything.
-  // TODO: Improve shape inference.
-  pm.addPass(createInlinerPass());
-  // Incorporate user annotations and remove signature Python-isms.
-  pm.addPass(createAdjustCallingConventionsPass());
+void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
+    OpPassManager &pm) {
+  // Recognize ATen kernels.
+  pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
 
-  createLowerToNpcompBackendPipeline(pm);
-}
-
-void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(OpPassManager &pm) {
-  // Recognize ATen kernels.
-  pm.addNestedPass<FuncOp>(aten::createRecognizeKernelsPass());
-
-  // Convert the bulk of the program to ranked tensors with known dtype.
-  // This is the input to the backend layer that we are aiming for.
-
-  // First, unilaterally convert public functions to tensor.
-  // The way this pass is currently written, this implies that
-  // as pipeline authors, we are restricting our users to not be able to see
-  // updates to "out params" on their public functions.
-  // This is deemed ok for now.
-  pm.addPass(Numpy::createPublicFunctionsToTensorPass());
-  // Inline global slots, which for most inference scenarios deletes them.
-  // This also exposes more information to intraprocedural transformations
-  // below like ArrayToTensor and RefineTypes.
-  // TODO: Don't rely on this pass to "lower" global slots by deleting.
-  // This pass should eventually be "just an optimization".
-  pm.addPass(createInlineGlobalSlotsPass());
-  // Convert the bulk of non-ABI-visible arrays to tensors.
-  pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
-  // Do shape and dtype refinement.
-  // We could do it sooner, but the pass currently doesn't have transfer
-  // functions for array ops.
-  pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
-  // Propagate to ABI return types the shape/dtype information discovered by
-  // the previous pass. Doing this is ABI-compatible for our backends.
-  pm.addPass(Numpy::createRefinePublicReturnPass());
-  // Clean up a few stray array/tensor conversion remnants.
-  pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
-
-  // Lower to TCP (+ guards) which is the input to codegen backends.
-  // Most of this should be subsumed by aten->linalg+guards conversions.
-  // (the guard generation will be automated from the linalg Op DSL).
-  pm.addNestedPass<FuncOp>(createConvertATenToLinalgPass());
-  pm.addNestedPass<FuncOp>(createConvertATenToTCFPass());
-  pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
-  pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
-
-  // Verify that we have lowered to the form that backends expect.
-  // This fails compilation (signalPassFailure) if the IR is not in the
-  // correct form.
-  pm.addPass(CommonBackend::createVerifyBackendContractPass());
+  // Convert the bulk of the program to ranked tensors with known dtype.
+  // This is the input to the backend layer that we are aiming for.
+
+  // First, unilaterally convert public functions to tensor.
+  // The way this pass is currently written, this implies that
+  // as pipeline authors, we are restricting our users to not be able to see
+  // updates to "out params" on their public functions.
+  // This is deemed ok for now.
+  pm.addPass(Numpy::createPublicFunctionsToTensorPass());
+  // Inline global slots, which for most inference scenarios deletes them.
+  // This also exposes more information to intraprocedural transformations
+  // below like ArrayToTensor and RefineTypes.
+  // TODO: Don't rely on this pass to "lower" global slots by deleting.
+  // This pass should eventually be "just an optimization".
+  pm.addPass(createInlineGlobalSlotsPass());
+  // Convert the bulk of non-ABI-visible arrays to tensors.
+  pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
+  // Do shape and dtype refinement.
+  // We could do it sooner, but the pass currently doesn't have transfer
+  // functions for array ops.
+  pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
+  // Propagate to ABI return types the shape/dtype information discovered by
+  // the previous pass. Doing this is ABI-compatible for our backends.
+  pm.addPass(Numpy::createRefinePublicReturnPass());
+  // Clean up a few stray array/tensor conversion remnants.
+  pm.addNestedPass<FuncOp>(Numpy::createArrayToTensorPass());
+
+  // Lower to TCP (+ guards) which is the input to codegen backends.
+  // Most of this should be subsumed by aten->linalg+guards conversions.
+  // (the guard generation will be automated from the linalg Op DSL).
+  pm.addNestedPass<FuncOp>(createConvertATenToLinalgPass());
+  pm.addNestedPass<FuncOp>(createConvertATenToTCFPass());
+  pm.addNestedPass<FuncOp>(createConvertTCFToStdPass());
+  pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
+
+  // Verify that we have lowered to the form that backends expect.
+  // This fails compilation (signalPassFailure) if the IR is not in the
+  // correct form.
+  pm.addPass(CommonBackend::createVerifyBackendContractPass());
 }