mirror of https://github.com/llvm/torch-mlir
Add python binding for running passes.
parent
bb871e7601
commit
fddf41ca92
|
@ -0,0 +1,43 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""Test for the MLIR Pass Python bindings"""
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from _npcomp.mlir import passes
|
||||
from npcomp.utils import test_utils
|
||||
|
||||
test_utils.start_filecheck_test()
|
||||
c = ir.MLIRContext()
|
||||
|
||||
pm = passes.PassManager(c)
|
||||
|
||||
# CHECK-LABEL: module @parseSuccess
|
||||
m = c.parse_asm(r"""
|
||||
module @parseSuccess {
|
||||
func @notUsed() attributes { sym_visibility = "private" }
|
||||
func @f() {
|
||||
return
|
||||
}
|
||||
}
|
||||
""")
|
||||
# CHECK: func @notUsed
|
||||
# CHECK: func @f
|
||||
print(m.to_asm())
|
||||
|
||||
# CHECK: PASS COUNT: 0
|
||||
print("PASS COUNT:", len(pm))
|
||||
|
||||
pm.addPassPipelines("canonicalize", "symbol-dce")
|
||||
# Note: not checking the actual count since these may expand to more than
|
||||
# two passes.
|
||||
# CHECK: PASS COUNT:
|
||||
print("PASS COUNT:", len(pm))
|
||||
# CHECK: PASSES: canonicalize, symbol-dce
|
||||
print("PASSES:", str(pm))
|
||||
pm.run(m)
|
||||
print(m.to_asm())
|
||||
# CHECK-NOT: func @notUsed
|
||||
|
||||
test_utils.end_filecheck_test(__file__)
|
|
@ -7,6 +7,7 @@ import sys
|
|||
|
||||
TEST_MODULES = (
|
||||
"npcomp.mlir_ir_test",
|
||||
"npcomp.mlir_pass_test",
|
||||
"npcomp.dialect.Basicpy",
|
||||
"npcomp.dialect.Numpy",
|
||||
"npcomp.tracing.context",
|
||||
|
|
|
@ -27,6 +27,7 @@ set(extension_target NPCOMPNativePyExt)
|
|||
set(extension_pybind_sources
|
||||
MlirInit.cpp
|
||||
MlirIr.cpp
|
||||
MlirPass.cpp
|
||||
NpcompDialect.cpp
|
||||
NpcompModule.cpp
|
||||
PybindUtils.cpp
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||
|
@ -35,6 +36,7 @@ bool npcompMlirInitialize() {
|
|||
|
||||
// Global registration.
|
||||
::mlir::registerAllDialects();
|
||||
::mlir::registerAllPasses();
|
||||
|
||||
// Local registration.
|
||||
registerDialect<NPCOMP::Basicpy::BasicpyDialect>();
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "MlirPass.h"
|
||||
#include "NpcompModule.h"
|
||||
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module initialization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void defineMlirPassModule(py::module m) {
|
||||
m.doc() = "Python bindings for mlir pass infra";
|
||||
|
||||
PyPassManager::bind(m);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassManager
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyPassManager::bind(py::module m) {
|
||||
py::class_<PyPassManager>(m, "PassManager")
|
||||
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
|
||||
py::arg("verifyModules") = true)
|
||||
.def("enableCrashReproducerGeneration",
|
||||
[](PyPassManager &self, std::string outputFile,
|
||||
bool genLocalReproducer) {
|
||||
self.passManager.enableCrashReproducerGeneration(
|
||||
outputFile, genLocalReproducer);
|
||||
},
|
||||
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
||||
.def("__len__",
|
||||
[](PyPassManager &self) { return self.passManager.size(); })
|
||||
.def("__str__",
|
||||
[](PyPassManager &self) {
|
||||
std::string spec;
|
||||
llvm::raw_string_ostream stream(spec);
|
||||
self.passManager.printAsTextualPipeline(stream);
|
||||
return spec;
|
||||
})
|
||||
.def("run",
|
||||
[](PyPassManager &self, PyModuleOp &module) {
|
||||
if (module.context.get() != self.context.get()) {
|
||||
throw py::raiseValueError(
|
||||
"Expected a module with the same context "
|
||||
"as the PassManager");
|
||||
}
|
||||
if (failed(self.passManager.run(module.moduleOp))) {
|
||||
// TODO: Wrap propagate context diagnostics
|
||||
throw py::raisePyError(PyExc_RuntimeError,
|
||||
"Could not run passes");
|
||||
}
|
||||
})
|
||||
.def("addPassPipelines", [](PyPassManager &self, py::args passPipelines) {
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
for (auto pyPassPipeline : passPipelines) {
|
||||
auto passPipeline = pyPassPipeline.cast<std::string>();
|
||||
if (failed(mlir::parsePassPipeline(passPipeline, self.passManager,
|
||||
error_stream))) {
|
||||
std::string message = "failed to parse pass pipeline '";
|
||||
message.append(passPipeline);
|
||||
message.append("': ");
|
||||
message.append(error);
|
||||
throw py::raiseValueError(message);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlir
|
|
@ -0,0 +1,32 @@
|
|||
//===- MlirPass.h - MLIR Pass Bindings ------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
|
||||
#define NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
|
||||
|
||||
#include "MlirIr.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
struct PyPassManager {
|
||||
PyPassManager(std::shared_ptr<PyContext> context, bool verifyModules)
|
||||
: context(std::move(context)),
|
||||
passManager(&context->context, verifyModules) {}
|
||||
static void bind(py::module m);
|
||||
PassManager passManager;
|
||||
|
||||
private:
|
||||
std::shared_ptr<PyContext> context;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
|
|
@ -62,6 +62,9 @@ PYBIND11_MODULE(_npcomp, m) {
|
|||
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
||||
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
||||
defineMlirIrModule(mlir_ir_m);
|
||||
// Note: not "pass" because it is a reserved word
|
||||
auto mlir_pass_m = mlir_m.def_submodule("passes");
|
||||
defineMlirPassModule(mlir_pass_m);
|
||||
|
||||
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
||||
defineNpcompDialect(npcomp_dialect);
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
namespace mlir {
|
||||
void defineMlirIrModule(py::module m);
|
||||
void defineMlirPassModule(py::module m);
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
|
Loading…
Reference in New Issue