Add python binding for running passes.

pull/1/head
Stella Laurenzo 2020-06-03 01:29:59 -07:00
parent bb871e7601
commit fddf41ca92
8 changed files with 163 additions and 0 deletions

View File

@ -0,0 +1,43 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Test for the MLIR Pass Python bindings"""
from _npcomp.mlir import ir
from _npcomp.mlir import passes
from npcomp.utils import test_utils
test_utils.start_filecheck_test()
c = ir.MLIRContext()
pm = passes.PassManager(c)
# CHECK-LABEL: module @parseSuccess
m = c.parse_asm(r"""
module @parseSuccess {
func @notUsed() attributes { sym_visibility = "private" }
func @f() {
return
}
}
""")
# CHECK: func @notUsed
# CHECK: func @f
print(m.to_asm())
# CHECK: PASS COUNT: 0
print("PASS COUNT:", len(pm))
pm.addPassPipelines("canonicalize", "symbol-dce")
# Note: not checking the actual count since these may expand to more than
# two passes.
# CHECK: PASS COUNT:
print("PASS COUNT:", len(pm))
# CHECK: PASSES: canonicalize, symbol-dce
print("PASSES:", str(pm))
pm.run(m)
print(m.to_asm())
# CHECK-NOT: func @notUsed
test_utils.end_filecheck_test(__file__)

View File

@ -7,6 +7,7 @@ import sys
TEST_MODULES = ( TEST_MODULES = (
"npcomp.mlir_ir_test", "npcomp.mlir_ir_test",
"npcomp.mlir_pass_test",
"npcomp.dialect.Basicpy", "npcomp.dialect.Basicpy",
"npcomp.dialect.Numpy", "npcomp.dialect.Numpy",
"npcomp.tracing.context", "npcomp.tracing.context",

View File

@ -27,6 +27,7 @@ set(extension_target NPCOMPNativePyExt)
set(extension_pybind_sources set(extension_pybind_sources
MlirInit.cpp MlirInit.cpp
MlirIr.cpp MlirIr.cpp
MlirPass.cpp
NpcompDialect.cpp NpcompDialect.cpp
NpcompModule.cpp NpcompModule.cpp
PybindUtils.cpp PybindUtils.cpp

View File

@ -8,6 +8,7 @@
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/InitAllDialects.h" #include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/NumpyDialect.h" #include "npcomp/Dialect/Numpy/NumpyDialect.h"
@ -35,6 +36,7 @@ bool npcompMlirInitialize() {
// Global registration. // Global registration.
::mlir::registerAllDialects(); ::mlir::registerAllDialects();
::mlir::registerAllPasses();
// Local registration. // Local registration.
registerDialect<NPCOMP::Basicpy::BasicpyDialect>(); registerDialect<NPCOMP::Basicpy::BasicpyDialect>();

View File

@ -0,0 +1,80 @@
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "MlirPass.h"
#include "NpcompModule.h"
#include "mlir/Pass/PassRegistry.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// Module initialization
//===----------------------------------------------------------------------===//
void defineMlirPassModule(py::module m) {
m.doc() = "Python bindings for mlir pass infra";
PyPassManager::bind(m);
}
//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
void PyPassManager::bind(py::module m) {
py::class_<PyPassManager>(m, "PassManager")
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
py::arg("verifyModules") = true)
.def("enableCrashReproducerGeneration",
[](PyPassManager &self, std::string outputFile,
bool genLocalReproducer) {
self.passManager.enableCrashReproducerGeneration(
outputFile, genLocalReproducer);
},
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
.def("__len__",
[](PyPassManager &self) { return self.passManager.size(); })
.def("__str__",
[](PyPassManager &self) {
std::string spec;
llvm::raw_string_ostream stream(spec);
self.passManager.printAsTextualPipeline(stream);
return spec;
})
.def("run",
[](PyPassManager &self, PyModuleOp &module) {
if (module.context.get() != self.context.get()) {
throw py::raiseValueError(
"Expected a module with the same context "
"as the PassManager");
}
if (failed(self.passManager.run(module.moduleOp))) {
// TODO: Wrap propagate context diagnostics
throw py::raisePyError(PyExc_RuntimeError,
"Could not run passes");
}
})
.def("addPassPipelines", [](PyPassManager &self, py::args passPipelines) {
std::string error;
llvm::raw_string_ostream error_stream(error);
for (auto pyPassPipeline : passPipelines) {
auto passPipeline = pyPassPipeline.cast<std::string>();
if (failed(mlir::parsePassPipeline(passPipeline, self.passManager,
error_stream))) {
std::string message = "failed to parse pass pipeline '";
message.append(passPipeline);
message.append("': ");
message.append(error);
throw py::raiseValueError(message);
}
}
});
}
} // namespace mlir

View File

@ -0,0 +1,32 @@
//===- MlirPass.h - MLIR Pass Bindings ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
#define NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
#include "MlirIr.h"
#include "PybindUtils.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
struct PyPassManager {
PyPassManager(std::shared_ptr<PyContext> context, bool verifyModules)
: context(std::move(context)),
passManager(&context->context, verifyModules) {}
static void bind(py::module m);
PassManager passManager;
private:
std::shared_ptr<PyContext> context;
};
} // namespace mlir
#endif // NPCOMP_PYTHON_NATIVE_MLIR_PASS_H

View File

@ -62,6 +62,9 @@ PYBIND11_MODULE(_npcomp, m) {
auto mlir_m = m.def_submodule("mlir", "MLIR interop"); auto mlir_m = m.def_submodule("mlir", "MLIR interop");
auto mlir_ir_m = mlir_m.def_submodule("ir"); auto mlir_ir_m = mlir_m.def_submodule("ir");
defineMlirIrModule(mlir_ir_m); defineMlirIrModule(mlir_ir_m);
// Note: not "pass" because it is a reserved word
auto mlir_pass_m = mlir_m.def_submodule("passes");
defineMlirPassModule(mlir_pass_m);
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects"); auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
defineNpcompDialect(npcomp_dialect); defineNpcompDialect(npcomp_dialect);

View File

@ -13,6 +13,7 @@
namespace mlir { namespace mlir {
void defineMlirIrModule(py::module m); void defineMlirIrModule(py::module m);
void defineMlirPassModule(py::module m);
namespace npcomp { namespace npcomp {
namespace python { namespace python {