torch-mlir/frontends/pytorch/csrc/builder/module_builder.cpp

193 lines
7.1 KiB
C++
Raw Normal View History

//===- module_builder.cpp -------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "module_builder.h"
Properly model "derefinement". In terms of IR structure, TorchScript allows types to vary in many circumstances where MLIR requires pointer-identical types. In particular, it is valid to pass any subtype in place of a type. For example, if an `Optional[int]` is required somewhere in the IR, it is legal to pass a value of just `int` (but not the other way around; see `torch.prim.unchecked_cast`). In effect, every *use* can have a different type. We introduce a new op `torch.derefine` that models that impedance mismatch. This op allows casting a value from one type to a type that it is a subtype of to model this behavior. Recommended review order: - TorchOps.td for new torch.derefine (and updated docs for `torch.prim.unchecked_cast`) - new test code in if.py, loop.py, function-derefine.py - new code in node_importer.cpp for handling derefinement insertion - function_importer.cpp and utils changes in torch_to_mlir_utils.cpp Properly handling derefinement on function boundaries required relayering the code so that graph_importer.cpp/.h is now function_importer.cpp/.h because only the `torch::jit::Function` (actually the `c10::FunctionSchema` it holds) knows the derefined types that are actually needed at the boundary (see `function-derefine.py` for a test). Annoyingly, this churns all the functions which are now prefixed with `__torch__.` but that is more correct anyway (that is their linkage name in the `torch::jit::CompilationUnit`; the previous `mb.import_function` was actually buggy in the case of functions calling each other as it would reference their unqualified name). With this change, we can import `resnet18` from `torchvision` :) IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
2021-03-02 09:24:15 +08:00
#include "function_importer.h"
#include "ivalue_importer.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Registration.h"
#include "npcomp-c/Registration.h"
namespace py = pybind11;
using namespace torch_mlir;
/// Imports the MLIR Python bindings' `ir` module and returns its attribute
/// named `className` (e.g. "Context", "Module").
static py::object getMlirIrClass(const char *className) {
  py::module irModule = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"));
  return irModule.attr(className);
}
/// Returns `contextObj` as-is unless it is Python `None`; in that case a new
/// `ir.Context` instance is constructed and returned instead.
static py::object createPythonContextIfNone(py::object contextObj) {
  if (!contextObj.is_none())
    return contextObj;
  return getMlirIrClass("Context")();
}
/// Extracts the C-API `MlirContext` from a Python context object through its
/// `MLIR_PYTHON_CAPI_PTR_ATTR` capsule.
///
/// Throws `py::error_already_set` when the capsule conversion yields a null
/// context.
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
  assert(!contextObj.is_none() && "context cannot be None");
  py::object capsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
  MlirContext result = mlirPythonCapsuleToContext(capsule.ptr());
  if (!mlirContextIsNull(result))
    return result;
  // An error will have already been set by mlirPythonCapsuleToContext.
  throw py::error_already_set();
}
/// Wraps `module` into a Python `ir.Module` by handing a capsule of it to the
/// class's `MLIR_PYTHON_CAPI_FACTORY_ATTR` factory.
static py::object castMlirModuleToPythonObject(MlirModule module) {
  py::object moduleClass = getMlirIrClass("Module");
  // The capsule is a new reference, so steal it rather than borrow it.
  py::object capsule =
      py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(module));
  py::object factory = moduleClass.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR);
  return factory(capsule);
}
/// Creates a new, empty MLIR module in `context` at an unknown location.
static MlirModule createEmptyModule(MlirContext context) {
  // TODO: Extract location from backtrace.
  return mlirModuleCreateEmpty(mlirLocationUnknownGet(context));
}
/// Maps an MLIR diagnostic severity to the lowercase label printed in front of
/// a diagnostic. Values outside the four known severities map to
/// "<unknown severity>".
static std::string
stringifyMlirDiagnosticSeverity(MlirDiagnosticSeverity severity) {
  const char *label = "<unknown severity>";
  switch (severity) {
  case MlirDiagnosticError:
    label = "error";
    break;
  case MlirDiagnosticWarning:
    label = "warning";
    break;
  case MlirDiagnosticNote:
    label = "note";
    break;
  case MlirDiagnosticRemark:
    label = "remark";
    break;
  default:
    break;
  }
  return label;
}
/// Formats `diagnostic` as "<severity>: <message>" and prints it to Python's
/// `sys.stderr` using pybind11's `print`.
static void printDiagnostic(MlirDiagnostic diagnostic) {
  std::stringstream buffer;
  buffer << stringifyMlirDiagnosticSeverity(
                mlirDiagnosticGetSeverity(diagnostic))
         << ": ";
  // The C API streams the message in chunks; append each one to `buffer`,
  // which is passed through as the callback's user data.
  mlirDiagnosticPrint(
      diagnostic,
      [](MlirStringRef chunk, void *userData) {
        static_cast<std::stringstream *>(userData)->write(chunk.data,
                                                          chunk.length);
      },
      static_cast<void *>(&buffer));
  // Use pybind11's print:
  // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
  py::object pyStderr = py::module_::import("sys").attr("stderr");
  py::print(buffer.str(), py::arg("file") = pyStderr);
}
// Register a diagnostic handler that will redirect output to `sys.stderr`
// instead of a C/C++-level file abstraction. This ensures, for example,
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
//
// The handler prints the diagnostic followed by each of its attached notes,
// and returns success for every diagnostic it receives.
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
  auto diagnosticHandler = [](MlirDiagnostic diagnostic,
                              void *) -> MlirLogicalResult {
    printDiagnostic(diagnostic);
    // The C API reports the note count and takes the note index as intptr_t;
    // keep that type instead of narrowing to int.
    const intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
    for (intptr_t i = 0; i < numNotes; ++i)
      printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
    return mlirLogicalResultSuccess();
  };
  // No user data is attached (nullptr below), so there is nothing to free.
  auto deleteUserData = [](void *) {};
  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
      context, diagnosticHandler, nullptr, deleteUserData);
  // Ignore the ID. We intend to keep this handler for the entire lifetime
  // of this context.
  (void)id;
}
/// Constructs a builder around a freshly created, empty MLIR module.
///
/// If `contextObj` is Python `None`, a new Python `ir.Context` is created
/// (see createPythonContextIfNone). The underlying MlirContext is extracted
/// from it, an empty module is created in that context and wrapped as a
/// Python module object. The constructor then calls mlirRegisterAllDialects
/// and npcompRegisterAllDialects on the context, installs the `sys.stderr`
/// diagnostic handler, and caches the body block's first operation as
/// `terminator` for later insertions.
// NOTE(review): several initializers below read members initialized before
// them (`this->contextObj`, `this->context`, `module`). Members are
// initialized in header declaration order, not list order, so this relies on
// module_builder.h declaring them in this same order — confirm if the header
// changes. Inside the init list, `context` refers to the member (the
// parameter is named `contextObj`).
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
: contextObj(createPythonContextIfNone(std::move(contextObj))),
context(castPythonObjectToMlirContext(this->contextObj)),
module(createEmptyModule(this->context)),
moduleObj(castMlirModuleToPythonObject(module)),
unknownLoc(mlirLocationUnknownGet(context)), typeMapper(this->context) {
// TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context);
npcompRegisterAllDialects(context);
registerPythonSysStderrDiagnosticHandler(context);
// Terminator will always be the first op of an empty module.
terminator = mlirBlockGetFirstOperation(getBodyBlock());
}
std::shared_ptr<AcapController>
ModuleBuilder::startCaptureFunction(std::string &name,
std::vector<at::Tensor> args) {
// TODO: Verify that arguments do not alias each other.
std::vector<MlirType> inputTypes;
for (auto &arg : args) {
inputTypes.push_back(typeMapper.forwardTensorToType(arg));
}
// TODO: Extract a traceback and use in place of unknownLoc.
auto inserter = createInserter();
auto funcBuilder =
FuncBuilder::createFunction(inserter, unknownLoc, name, inputTypes);
// Map block arguments.
MlirBlock entryBlock = funcBuilder->getEntryBlock();
assert(mlirBlockGetNumArguments(entryBlock) ==
static_cast<intptr_t>(args.size()) &&
"entry block incorrect arg arity");
for (size_t i = 0; i < args.size(); ++i) {
funcBuilder->mapTensor(args[i], mlirBlockGetArgument(entryBlock, i));
}
return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
}
torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator;
Properly model "derefinement". In terms of IR structure, TorchScript allows types to vary in many circumstances where MLIR requires pointer-identical types. In particular, it is valid to pass any subtype in place of a type. For example, if an `Optional[int]` is required somewhere in the IR, it is legal to pass a value of just `int` (but not the other way around; see `torch.prim.unchecked_cast`). In effect, every *use* can have a different type. We introduce a new op `torch.derefine` that models that impedance mismatch. This op allows casting a value from one type to a type that it is a subtype of to model this behavior. Recommended review order: - TorchOps.td for new torch.derefine (and updated docs for `torch.prim.unchecked_cast`) - new test code in if.py, loop.py, function-derefine.py - new code in node_importer.cpp for handling derefinement insertion - function_importer.cpp and utils changes in torch_to_mlir_utils.cpp Properly handling derefinement on function boundaries required relayering the code so that graph_importer.cpp/.h is now function_importer.cpp/.h because only the `torch::jit::Function` (actually the `c10::FunctionSchema` it holds) knows the derefined types that are actually needed at the boundary (see `function-derefine.py` for a test). Annoyingly, this churns all the functions which are now prefixed with `__torch__.` but that is more correct anyway (that is their linkage name in the `torch::jit::CompilationUnit`; the previous `mb.import_function` was actually buggy in the case of functions calling each other as it would reference their unqualified name). With this change, we can import `resnet18` from `torchvision` :) IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
2021-03-02 09:24:15 +08:00
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
return function;
}
/// Imports `jitModule` (via its underlying IValue) into this module's body.
///
/// `maybeClassAnnotator` may be Python `None`, in which case a
/// default-constructed ClassAnnotator local to this call is used. Otherwise it
/// is cast to a `ClassAnnotator *`.
void ModuleBuilder::importModule(torch::jit::Module jitModule,
                                 py::object maybeClassAnnotator) {
  ClassAnnotator fallbackAnnotator;
  ClassAnnotator *annotator =
      maybeClassAnnotator.is_none()
          ? &fallbackAnnotator
          : py::cast<ClassAnnotator *>(maybeClassAnnotator);
  MlirBlock moduleBody = mlirModuleGetBody(module);
  MlirContext moduleContext = mlirModuleGetContext(module);
  importIValue(jitModule._ivalue(), moduleBody, moduleContext, *annotator);
}
/// Returns a callback that inserts an operation into the module body directly
/// before the terminator. The body block and terminator are captured by value
/// when this is called.
FuncBuilder::Inserter ModuleBuilder::createInserter() {
  MlirBlock bodyBlock = getBodyBlock();
  MlirOperation insertionPoint = terminator;
  return [bodyBlock, insertionPoint](MlirOperation op) {
    mlirBlockInsertOwnedOperationBefore(bodyBlock, insertionPoint, op);
  };
}
/// Returns the first block of region 0 of the module's operation, i.e. the
/// block that holds the module's top-level operations.
MlirBlock ModuleBuilder::getBodyBlock() {
  MlirRegion bodyRegion =
      mlirOperationGetRegion(mlirModuleGetOperation(module), 0);
  return mlirRegionGetFirstBlock(bodyRegion);
}
/// Registers the `ModuleBuilder` Python class on `m`. It exposes a
/// constructor taking an optional context, read-only `context`/`module`
/// properties, and the `capture_function`, `import_function` and
/// `import_module` methods.
void ModuleBuilder::bind(py::module &m) {
  py::class_<ModuleBuilder> builderClass(m, "ModuleBuilder");
  builderClass.def(py::init<py::object>(), py::arg("context") = py::none());
  builderClass.def_property_readonly("context", &ModuleBuilder::getContextObj);
  builderClass.def_property_readonly("module", &ModuleBuilder::getModuleObj);
  // keep_alive<0, 1>: the returned controller (0) keeps the builder (1) alive.
  builderClass.def("capture_function", &ModuleBuilder::startCaptureFunction,
                   py::keep_alive<0, 1>());
  builderClass.def("import_function", &ModuleBuilder::importFunction);
  builderClass.def("import_module", &ModuleBuilder::importModule,
                   py::arg("module"), py::arg("classAnnotator") = py::none());
}