2020-09-29 09:36:00 +08:00
|
|
|
//===- module_builder.cpp -------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file is licensed under a pytorch-style license
|
|
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "module_builder.h"
|
|
|
|
|
2020-11-21 09:03:23 +08:00
|
|
|
#include "graph_importer.h"
|
2021-01-28 08:35:44 +08:00
|
|
|
#include "ivalue_importer.h"
|
2020-11-21 09:03:23 +08:00
|
|
|
|
2020-10-13 12:39:48 +08:00
|
|
|
#include "mlir-c/Bindings/Python/Interop.h"
|
2020-12-12 06:43:38 +08:00
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
|
|
#include "mlir-c/BuiltinTypes.h"
|
2021-02-18 07:50:13 +08:00
|
|
|
#include "mlir-c/Diagnostics.h"
|
2020-09-29 09:36:00 +08:00
|
|
|
#include "mlir-c/Registration.h"
|
|
|
|
#include "npcomp-c/Registration.h"
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
using namespace torch_mlir;
|
|
|
|
|
2020-10-13 12:39:48 +08:00
|
|
|
/// Looks up a class (e.g. "Context", "Module") exported by the upstream MLIR
/// Python bindings from the `mlir.ir` submodule.
///
/// The top-level "mlir" package may be a loader that sets up its child
/// modules internally, so the lookup walks "mlir" -> "ir" -> className one
/// attribute at a time instead of importing "mlir.ir" directly.
static py::object getMlirIrClass(const char *className) {
  py::object mlirPackage = py::module::import("mlir");
  py::object irModule = mlirPackage.attr("ir");
  return irModule.attr(className);
}
|
2020-09-29 09:36:00 +08:00
|
|
|
|
2020-10-13 12:39:48 +08:00
|
|
|
/// Returns `contextObj` unchanged unless it is Python `None`. In that case a
/// freshly constructed `mlir.ir.Context` instance is returned instead.
static py::object createPythonContextIfNone(py::object contextObj) {
  if (!contextObj.is_none())
    return contextObj;
  py::object contextClass = getMlirIrClass("Context");
  return contextClass();
}
|
2020-09-29 09:36:00 +08:00
|
|
|
|
2020-10-13 12:39:48 +08:00
|
|
|
/// Extracts the C-API `MlirContext` wrapped by a Python `mlir.ir.Context`,
/// going through the capsule stored under `MLIR_PYTHON_CAPI_PTR_ATTR`.
///
/// \param contextObj A non-None Python context object (asserted in debug).
/// \throws py::error_already_set if the capsule yields a null context.
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
  assert(!contextObj.is_none() && "context cannot be None");
  py::object capsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
  MlirContext extracted = mlirPythonCapsuleToContext(capsule.ptr());
  if (!mlirContextIsNull(extracted))
    return extracted;
  // The failed capsule conversion above has already set the Python error;
  // surface it to the caller.
  throw py::error_already_set();
}
|
|
|
|
|
|
|
|
/// Wraps a C-API `MlirModule` in a Python `mlir.ir.Module`. The module is
/// packed into a capsule and passed to the class's
/// `MLIR_PYTHON_CAPI_FACTORY_ATTR` factory.
static py::object castMlirModuleToPythonObject(MlirModule module) {
  py::object moduleClass = getMlirIrClass("Module");
  // The capsule is adopted as an owned (new) reference, so it is released
  // when `capsule` goes out of scope.
  py::object capsule =
      py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(module));
  py::object factory = moduleClass.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR);
  return factory(capsule);
}
|
2020-09-29 09:36:00 +08:00
|
|
|
|
2020-10-13 12:39:48 +08:00
|
|
|
/// Creates a new, empty MLIR module in `context`, located at an unknown
/// location.
static MlirModule createEmptyModule(MlirContext context) {
  // TODO: Extract location from backtrace.
  return mlirModuleCreateEmpty(mlirLocationUnknownGet(context));
}
|
|
|
|
|
2021-02-18 07:50:13 +08:00
|
|
|
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
|
|
|
// instead of a C/C++-level file abstraction. This ensures, for example,
|
|
|
|
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
|
|
|
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
|
|
|
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
|
|
|
void *) -> MlirLogicalResult {
|
|
|
|
std::stringstream ss;
|
|
|
|
auto stringCallback =
|
|
|
|
[](MlirStringRef s, void *stringCallbackUserData) {
|
|
|
|
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
|
|
|
ssp->write(s.data, s.length);
|
|
|
|
};
|
|
|
|
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
|
|
|
// Use pybind11's print:
|
|
|
|
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
|
|
|
using namespace pybind11::literals;
|
|
|
|
py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr"));
|
|
|
|
return mlirLogicalResultSuccess();
|
|
|
|
};
|
|
|
|
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
|
|
|
context, diagnosticHandler, nullptr, [](void *) { return; });
|
|
|
|
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
|
|
|
// of this context.
|
|
|
|
(void)id;
|
|
|
|
}
|
|
|
|
|
2020-10-13 12:39:48 +08:00
|
|
|
/// Creates a builder that populates a fresh, empty MLIR module and exposes it
/// to Python.
///
/// \param contextObj A Python `mlir.ir.Context`, or `None` to have a new
///   context created (see createPythonContextIfNone).
///
/// NOTE(review): several initializers read members initialized earlier in the
/// list (`this->contextObj`, `this->context`, `module`). This is only correct
/// if the members are declared in this same order in module_builder.h.
/// Confirm before reordering either the declarations or this list.
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
    : contextObj(createPythonContextIfNone(std::move(contextObj))),
      context(castPythonObjectToMlirContext(this->contextObj)),
      module(createEmptyModule(this->context)),
      // Python-side `mlir.ir.Module` wrapper around `module`.
      moduleObj(castMlirModuleToPythonObject(module)),
      unknownLoc(mlirLocationUnknownGet(context)), typeMapper(this->context) {
  // TODO: Rework this once dialect registration C-APIs are in place.
  // https://reviews.llvm.org/D88162
  // Register the upstream MLIR dialects and the npcomp dialects on the
  // context.
  mlirRegisterAllDialects(context);
  npcompRegisterAllDialects(context);

  // Route this context's diagnostics to Python's `sys.stderr`.
  registerPythonSysStderrDiagnosticHandler(context);

  // Terminator will always be the first op of an empty module.
  // It is cached so that new ops can be inserted before it
  // (see createInserter / importFunction).
  terminator = mlirBlockGetFirstOperation(getBodyBlock());
}
|
|
|
|
|
|
|
|
std::shared_ptr<AcapController>
|
2020-10-02 09:59:58 +08:00
|
|
|
ModuleBuilder::startCaptureFunction(std::string &name,
|
|
|
|
std::vector<at::Tensor> args) {
|
|
|
|
// TODO: Verify that arguments do not alias each other.
|
2020-12-15 00:42:42 +08:00
|
|
|
std::vector<MlirType> inputTypes;
|
2020-10-02 09:59:58 +08:00
|
|
|
for (auto &arg : args) {
|
|
|
|
inputTypes.push_back(typeMapper.forwardTensorToType(arg));
|
|
|
|
}
|
2020-09-29 09:36:00 +08:00
|
|
|
|
2020-10-02 09:59:58 +08:00
|
|
|
// TODO: Extract a traceback and use in place of unknownLoc.
|
2020-11-21 09:03:23 +08:00
|
|
|
auto inserter = createInserter();
|
2020-10-02 09:59:58 +08:00
|
|
|
auto funcBuilder =
|
2020-11-21 09:03:23 +08:00
|
|
|
FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes);
|
2020-10-02 09:59:58 +08:00
|
|
|
// Map block arguments.
|
|
|
|
MlirBlock entryBlock = funcBuilder->getEntryBlock();
|
|
|
|
assert(mlirBlockGetNumArguments(entryBlock) ==
|
|
|
|
static_cast<intptr_t>(args.size()) &&
|
|
|
|
"entry block incorrect arg arity");
|
2020-12-15 00:42:42 +08:00
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
|
|
funcBuilder->mapTensor(args[i], mlirBlockGetArgument(entryBlock, i));
|
2020-09-29 09:36:00 +08:00
|
|
|
}
|
2020-10-06 14:21:21 +08:00
|
|
|
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
|
2020-10-02 09:59:58 +08:00
|
|
|
}
|
2020-09-29 09:36:00 +08:00
|
|
|
|
2020-11-24 06:41:30 +08:00
|
|
|
/// Imports the TorchScript graph of `function` as a func op (named after the
/// function) and inserts it before the module terminator.
///
/// \returns `function` unchanged, so Python callers can keep using it.
torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
  // Snapshot the insertion point (body block + terminator) before importing.
  auto insertBeforeTerminator = createInserter();
  MlirOperation funcOp = importGraphAsFuncOp(
      context, function.function_->graph().get(), function.function_->name());
  insertBeforeTerminator(funcOp);
  return function;
}
|
|
|
|
|
2021-02-20 08:21:21 +08:00
|
|
|
/// Imports the TorchScript module `jitModule` (via its underlying IValue)
/// into the body of the MLIR module.
///
/// \param maybeClassAnnotator Either `None` or a Python-wrapped
///   `ClassAnnotator`. When `None`, a default-constructed local annotator is
///   used.
void ModuleBuilder::importModule(torch::jit::Module jitModule,
                                 py::object maybeClassAnnotator) {
  ClassAnnotator defaultAnnotator;
  ClassAnnotator &annotator =
      maybeClassAnnotator.is_none()
          ? defaultAnnotator
          : *py::cast<ClassAnnotator *>(maybeClassAnnotator);
  importIValue(jitModule._ivalue(), mlirModuleGetBody(module),
               mlirModuleGetContext(module), annotator);
}
|
|
|
|
|
2020-11-21 09:03:23 +08:00
|
|
|
/// Returns a callback that inserts an operation into the module body,
/// immediately before the cached `terminator`.
///
/// The body block and terminator are snapshotted by value when this is
/// called.
FuncBuilder::Inserter ModuleBuilder::createInserter() {
  MlirBlock body = getBodyBlock();
  MlirOperation insertionPoint = terminator;
  return [body, insertionPoint](MlirOperation op) {
    mlirBlockInsertOwnedOperationBefore(body, insertionPoint, op);
  };
}
|
|
|
|
|
2020-10-02 09:59:58 +08:00
|
|
|
/// Returns the first block of the first region of the module's operation,
/// i.e. the module body.
MlirBlock ModuleBuilder::getBodyBlock() {
  MlirRegion bodyRegion =
      mlirOperationGetRegion(mlirModuleGetOperation(module), 0);
  return mlirRegionGetFirstBlock(bodyRegion);
}
|
|
|
|
|
|
|
|
/// Registers the `ModuleBuilder` Python class on module `m`.
void ModuleBuilder::bind(py::module &m) {
  py::class_<ModuleBuilder> cls(m, "ModuleBuilder");
  // Optional `mlir.ir.Context`. None (the default) means a new context is
  // created by the constructor.
  cls.def(py::init<py::object>(), py::arg("context") = py::none());
  cls.def_property_readonly("context", &ModuleBuilder::getContextObj);
  cls.def_property_readonly("module", &ModuleBuilder::getModuleObj);
  // keep_alive<0, 1>: the returned controller keeps the builder (self) alive.
  cls.def("capture_function", &ModuleBuilder::startCaptureFunction,
          py::keep_alive<0, 1>());
  cls.def("import_function", &ModuleBuilder::importFunction);
  cls.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
          py::arg("classAnnotator") = py::none());
}
|