diff --git a/include/torch-mlir/InitAll.h b/include/torch-mlir/InitAll.h index 42eb3c6a1..19b2c474d 100644 --- a/include/torch-mlir/InitAll.h +++ b/include/torch-mlir/InitAll.h @@ -18,6 +18,9 @@ namespace torch { // Registers all dialects that this project produces and any dependencies. void registerAllDialects(mlir::DialectRegistry ®istry); +// Registers all necessary dialect extensions for this project +void registerAllExtensions(mlir::DialectRegistry ®istry); + // Registers dialects that may be needed to parse torch-mlir inputs and // test cases. void registerOptionalInputDialects(mlir::DialectRegistry ®istry); diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c0b622005..249a8ad4f 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -13,6 +13,7 @@ set(LinkedLibs MLIRMemRefDialect MLIRSCFDialect MLIRTensorDialect + MLIRTensorInferTypeOpInterfaceImpl MLIRTosaDialect MLIRSupport diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index e8f9622c3..3b8b4ba04 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Dialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" @@ -39,7 +40,11 @@ void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); +} + +void mlir::torch::registerAllExtensions(mlir::DialectRegistry ®istry) { mlir::func::registerInlinerExtension(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); } // TODO: Break this up when backends are separated. diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index 2750ee2b7..0fa392de4 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -7,6 +7,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" @@ -33,8 +34,13 @@ int main(int argc, char **argv) { registerStripDebugInfoPass(); registerSymbolDCEPass(); + // memref passes used in torch-backend-to-linalg-on-tensors-backend-pipeline + memref::registerExpandOpsPass(); + memref::registerResolveShapedTypeResultDimsPass(); + DialectRegistry registry; mlir::torch::registerAllDialects(registry); + mlir::torch::registerAllExtensions(registry); mlir::torch::registerOptionalInputDialects(registry); #ifdef TORCH_MLIR_ENABLE_STABLEHLO