mirror of https://github.com/llvm/torch-mlir
Add python binding for running passes.
parent
bb871e7601
commit
fddf41ca92
|
@ -0,0 +1,43 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
"""Test for the MLIR Pass Python bindings"""
|
||||||
|
|
||||||
|
from _npcomp.mlir import ir
|
||||||
|
from _npcomp.mlir import passes
|
||||||
|
from npcomp.utils import test_utils
|
||||||
|
|
||||||
|
test_utils.start_filecheck_test()
|
||||||
|
c = ir.MLIRContext()
|
||||||
|
|
||||||
|
pm = passes.PassManager(c)
|
||||||
|
|
||||||
|
# CHECK-LABEL: module @parseSuccess
|
||||||
|
m = c.parse_asm(r"""
|
||||||
|
module @parseSuccess {
|
||||||
|
func @notUsed() attributes { sym_visibility = "private" }
|
||||||
|
func @f() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""")
|
||||||
|
# CHECK: func @notUsed
|
||||||
|
# CHECK: func @f
|
||||||
|
print(m.to_asm())
|
||||||
|
|
||||||
|
# CHECK: PASS COUNT: 0
|
||||||
|
print("PASS COUNT:", len(pm))
|
||||||
|
|
||||||
|
pm.addPassPipelines("canonicalize", "symbol-dce")
|
||||||
|
# Note: not checking the actual count since these may expand to more than
|
||||||
|
# two passes.
|
||||||
|
# CHECK: PASS COUNT:
|
||||||
|
print("PASS COUNT:", len(pm))
|
||||||
|
# CHECK: PASSES: canonicalize, symbol-dce
|
||||||
|
print("PASSES:", str(pm))
|
||||||
|
pm.run(m)
|
||||||
|
print(m.to_asm())
|
||||||
|
# CHECK-NOT: func @notUsed
|
||||||
|
|
||||||
|
test_utils.end_filecheck_test(__file__)
|
|
@ -7,6 +7,7 @@ import sys
|
||||||
|
|
||||||
TEST_MODULES = (
|
TEST_MODULES = (
|
||||||
"npcomp.mlir_ir_test",
|
"npcomp.mlir_ir_test",
|
||||||
|
"npcomp.mlir_pass_test",
|
||||||
"npcomp.dialect.Basicpy",
|
"npcomp.dialect.Basicpy",
|
||||||
"npcomp.dialect.Numpy",
|
"npcomp.dialect.Numpy",
|
||||||
"npcomp.tracing.context",
|
"npcomp.tracing.context",
|
||||||
|
|
|
@ -27,6 +27,7 @@ set(extension_target NPCOMPNativePyExt)
|
||||||
set(extension_pybind_sources
|
set(extension_pybind_sources
|
||||||
MlirInit.cpp
|
MlirInit.cpp
|
||||||
MlirIr.cpp
|
MlirIr.cpp
|
||||||
|
MlirPass.cpp
|
||||||
NpcompDialect.cpp
|
NpcompDialect.cpp
|
||||||
NpcompModule.cpp
|
NpcompModule.cpp
|
||||||
PybindUtils.cpp
|
PybindUtils.cpp
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/InitAllDialects.h"
|
#include "mlir/InitAllDialects.h"
|
||||||
|
#include "mlir/InitAllPasses.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||||
|
@ -35,6 +36,7 @@ bool npcompMlirInitialize() {
|
||||||
|
|
||||||
// Global registration.
|
// Global registration.
|
||||||
::mlir::registerAllDialects();
|
::mlir::registerAllDialects();
|
||||||
|
::mlir::registerAllPasses();
|
||||||
|
|
||||||
// Local registration.
|
// Local registration.
|
||||||
registerDialect<NPCOMP::Basicpy::BasicpyDialect>();
|
registerDialect<NPCOMP::Basicpy::BasicpyDialect>();
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "MlirPass.h"
|
||||||
|
#include "NpcompModule.h"
|
||||||
|
|
||||||
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Module initialization
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void defineMlirPassModule(py::module m) {
|
||||||
|
m.doc() = "Python bindings for mlir pass infra";
|
||||||
|
|
||||||
|
PyPassManager::bind(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PassManager
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void PyPassManager::bind(py::module m) {
|
||||||
|
py::class_<PyPassManager>(m, "PassManager")
|
||||||
|
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
|
||||||
|
py::arg("verifyModules") = true)
|
||||||
|
.def("enableCrashReproducerGeneration",
|
||||||
|
[](PyPassManager &self, std::string outputFile,
|
||||||
|
bool genLocalReproducer) {
|
||||||
|
self.passManager.enableCrashReproducerGeneration(
|
||||||
|
outputFile, genLocalReproducer);
|
||||||
|
},
|
||||||
|
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
||||||
|
.def("__len__",
|
||||||
|
[](PyPassManager &self) { return self.passManager.size(); })
|
||||||
|
.def("__str__",
|
||||||
|
[](PyPassManager &self) {
|
||||||
|
std::string spec;
|
||||||
|
llvm::raw_string_ostream stream(spec);
|
||||||
|
self.passManager.printAsTextualPipeline(stream);
|
||||||
|
return spec;
|
||||||
|
})
|
||||||
|
.def("run",
|
||||||
|
[](PyPassManager &self, PyModuleOp &module) {
|
||||||
|
if (module.context.get() != self.context.get()) {
|
||||||
|
throw py::raiseValueError(
|
||||||
|
"Expected a module with the same context "
|
||||||
|
"as the PassManager");
|
||||||
|
}
|
||||||
|
if (failed(self.passManager.run(module.moduleOp))) {
|
||||||
|
// TODO: Wrap propagate context diagnostics
|
||||||
|
throw py::raisePyError(PyExc_RuntimeError,
|
||||||
|
"Could not run passes");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.def("addPassPipelines", [](PyPassManager &self, py::args passPipelines) {
|
||||||
|
std::string error;
|
||||||
|
llvm::raw_string_ostream error_stream(error);
|
||||||
|
for (auto pyPassPipeline : passPipelines) {
|
||||||
|
auto passPipeline = pyPassPipeline.cast<std::string>();
|
||||||
|
if (failed(mlir::parsePassPipeline(passPipeline, self.passManager,
|
||||||
|
error_stream))) {
|
||||||
|
std::string message = "failed to parse pass pipeline '";
|
||||||
|
message.append(passPipeline);
|
||||||
|
message.append("': ");
|
||||||
|
message.append(error);
|
||||||
|
throw py::raiseValueError(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,32 @@
|
||||||
|
//===- MlirPass.h - MLIR Pass Bindings ------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
|
||||||
|
#define NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
|
||||||
|
|
||||||
|
#include "MlirIr.h"
|
||||||
|
#include "PybindUtils.h"
|
||||||
|
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
struct PyPassManager {
|
||||||
|
PyPassManager(std::shared_ptr<PyContext> context, bool verifyModules)
|
||||||
|
: context(std::move(context)),
|
||||||
|
passManager(&context->context, verifyModules) {}
|
||||||
|
static void bind(py::module m);
|
||||||
|
PassManager passManager;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<PyContext> context;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_PYTHON_NATIVE_MLIR_PASS_H
|
|
@ -62,6 +62,9 @@ PYBIND11_MODULE(_npcomp, m) {
|
||||||
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
||||||
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
||||||
defineMlirIrModule(mlir_ir_m);
|
defineMlirIrModule(mlir_ir_m);
|
||||||
|
// Note: not "pass" because it is a reserved word
|
||||||
|
auto mlir_pass_m = mlir_m.def_submodule("passes");
|
||||||
|
defineMlirPassModule(mlir_pass_m);
|
||||||
|
|
||||||
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
||||||
defineNpcompDialect(npcomp_dialect);
|
defineNpcompDialect(npcomp_dialect);
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
void defineMlirIrModule(py::module m);
|
void defineMlirIrModule(py::module m);
|
||||||
|
void defineMlirPassModule(py::module m);
|
||||||
|
|
||||||
namespace npcomp {
|
namespace npcomp {
|
||||||
namespace python {
|
namespace python {
|
||||||
|
|
Loading…
Reference in New Issue