diff --git a/include/torch-mlir-c/Registration.h b/include/torch-mlir-c/Registration.h index e83823b7e..4d582e61f 100644 --- a/include/torch-mlir-c/Registration.h +++ b/include/torch-mlir-c/Registration.h @@ -22,9 +22,6 @@ extern "C" { */ MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context); -/** Registers upstream (MLIR) dialects used in Torch-MLIR IRs. */ -MLIR_CAPI_EXPORTED void torchMlirRegisterRequiredDialects(MlirContext context); - /** Registers all passes for symbolic access with the global registry. */ MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(); diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index 4e4a058dd..87977a86f 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -13,23 +13,6 @@ add_mlir_public_c_api_library(TorchMLIRCAPI LINK_LIBS PUBLIC MLIRIR - MLIRAffineToStandard - MLIRArithmeticToLLVM - MLIRArithmeticTransforms - MLIRBufferizationTransforms - MLIRControlFlowToLLVM - MLIRFuncToLLVM - MLIRFuncTransforms - MLIRLinalgToLLVM - MLIRLinalgTransforms - MLIRMathToLLVM - MLIRMemRefToLLVM - MLIRReconcileUnrealizedCasts - MLIRSCFToControlFlow - MLIRSCFTransforms - MLIRTensorTransforms - MLIRTosaToArith - MLIRTosaToLinalg MLIRSupport TorchMLIRTorchDialect TorchMLIRInitAll diff --git a/lib/CAPI/Registration.cpp b/lib/CAPI/Registration.cpp index c83f3f43e..52cd10b38 100644 --- a/lib/CAPI/Registration.cpp +++ b/lib/CAPI/Registration.cpp @@ -10,27 +10,11 @@ #include "torch-mlir-c/Registration.h" #include "mlir/CAPI/IR.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/InitAllPasses.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Transforms/Passes.h" #include "torch-mlir/InitAll.h" -void torchMlirRegisterRequiredDialects(MlirContext context) { - mlir::DialectRegistry registry; - registry.insert(); - unwrap(context)->appendDialectRegistry(registry); -} - void torchMlirRegisterAllDialects(MlirContext context) { mlir::DialectRegistry registry; mlir::torch::registerAllDialects(registry); @@ -39,24 +23,4 @@ void torchMlirRegisterAllDialects(MlirContext context) { unwrap(context)->loadAllAvailableDialects(); } -void torchMlirRegisterAllPasses() { - mlir::arith::registerArithmeticPasses(); - mlir::bufferization::registerBufferizationPasses(); - mlir::func::registerFuncPasses(); - mlir::registerConvertAffineToStandardPass(); - mlir::registerConvertArithmeticToLLVMPass(); - mlir::registerConvertControlFlowToLLVMPass(); - mlir::registerConvertFuncToLLVMPass(); - mlir::registerConvertLinalgToLLVMPass(); - mlir::registerConvertMathToLLVMPass(); - mlir::registerConvertMemRefToLLVMPass(); - mlir::registerLinalgPasses(); - mlir::registerReconcileUnrealizedCastsPass(); - mlir::registerSCFPasses(); - mlir::registerSCFToControlFlowPass(); - mlir::registerTosaToArithPass(); - mlir::registerTosaToLinalgNamedPass(); - mlir::registerTosaToLinalgPass(); - mlir::tensor::registerTensorPasses(); - mlir::torch::registerAllPasses(); -} +void torchMlirRegisterAllPasses() { mlir::torch::registerAllPasses(); } diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 8bd781d55..77a31184c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -26,7 +26,6 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel SOURCES __init__.py compiler_utils.py - _mlir_libs/_site_initialize_0.py ) declare_mlir_python_sources(TorchMLIRPythonSources.Dialects @@ -103,10 +102,16 @@ add_subdirectory(torch_mlir/eager_mode) ################################################################################ set(_source_components - MLIRPythonSources.Core - MLIRPythonSources.Dialects.func - MLIRPythonSources.ExecutionEngine + # TODO: Core is now implicitly building/registering all dialects, increasing + # build burden by ~5x. Make it stop. + # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes + # for the reference backend, but logically they can be separate. But seemingly + # the only way to handle that is to create a separate mlir python package + # tree, which seems excessive. + MLIRPythonSources MLIRPythonExtension.Core + MLIRPythonExtension.RegisterEverything + MLIRPythonExtension.ExecutionEngine TorchMLIRPythonSources TorchMLIRPythonExtensions ) diff --git a/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp index 509fb3749..e0b045143 100644 --- a/python/TorchMLIRModule.cpp +++ b/python/TorchMLIRModule.cpp @@ -9,7 +9,6 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include "mlir/CAPI/IR.h" #include "torch-mlir-c/Dialects.h" #include "torch-mlir-c/Registration.h" @@ -20,6 +19,14 @@ PYBIND11_MODULE(_torchMlir, m) { m.doc() = "torch-mlir main python extension"; - m.def("register_required_dialects", torchMlirRegisterRequiredDialects, - py::arg("context")); + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle handle = mlirGetDialectHandle__torch__(); + mlirDialectHandleRegisterDialect(handle, context); + if (load) { + mlirDialectHandleLoadDialect(handle, context); + } + }, + py::arg("context"), py::arg("load") = true); } diff --git a/python/torch_mlir/_mlir_libs/_site_initialize_0.py b/python/torch_mlir/_mlir_libs/_site_initialize_0.py deleted file mode 100644 index 9978b72c7..000000000 --- a/python/torch_mlir/_mlir_libs/_site_initialize_0.py +++ /dev/null @@ -1,4 +0,0 @@ -from . import _torchMlir - -def context_init_hook(context): - _torchMlir.register_required_dialects(context) diff --git a/python/torch_mlir/dialects/torch/__init__.py b/python/torch_mlir/dialects/torch/__init__.py index b94a7bdab..bd362849a 100644 --- a/python/torch_mlir/dialects/torch/__init__.py +++ b/python/torch_mlir/dialects/torch/__init__.py @@ -4,4 +4,4 @@ # Also available under a BSD-style license. See LICENSE. from .._torch_ops_gen import * -from ..._mlir_libs._torchMlir import register_required_dialects +from ..._mlir_libs._torchMlir import register_dialect diff --git a/test/python/smoketest.py b/test/python/smoketest.py index c42335123..88e0a10f7 100644 --- a/test/python/smoketest.py +++ b/test/python/smoketest.py @@ -4,4 +4,4 @@ import torch_mlir.ir from torch_mlir.dialects import torch with torch_mlir.ir.Context() as ctx: - torch.register_required_dialects(ctx) + torch.register_dialect(ctx)