Create public API for torch_mlir python code.

* Adds a trampoline/loader 'torch_mlir' module.
* Plumbs through the MLIR python Context and Module creation, interoping with the MLIR Python API (resolves TODO on creating with own context and accessing the module being built).
* Inter-module Python API interop is still a bit rough but workable via the capsule mechanism. Can be evolved later.
* Exports the frontends/pytorch python sources to the project python/ build directory.
* Requires D89294 to land.
pull/81/head
Stella Laurenzo 2020-10-12 21:39:48 -07:00
parent 86df4cabeb
commit 30cfc6499f
13 changed files with 107 additions and 92 deletions

View File

@ -12,3 +12,33 @@ function(npcomp_python_target_compile_options target)
/EHsc /GR>
)
endfunction()
function(npcomp_python_create_symlinks binary_dir source_dir)
# Do nothing if building in-source
if (${binary_dir} STREQUAL ${source_dir})
return()
endif()
file(GLOB_RECURSE python_files RELATIVE ${source_dir} *.py)
foreach (path_file ${python_files})
get_filename_component(folder ${path_file} PATH)
# Create REAL folder
file(MAKE_DIRECTORY "${binary_dir}/${folder}")
# Get OS dependent path to use in `execute_process`
file(TO_NATIVE_PATH "${binary_dir}/${path_file}" link)
file(TO_NATIVE_PATH "${source_dir}/${path_file}" target)
# TODO: Switch to copy on windows if symlink still not supported by
# then.
set(cmake_verb create_symlink)
execute_process(COMMAND ${CMAKE_COMMAND} -E ${cmake_verb} ${target} ${link}
RESULT_VARIABLE result
ERROR_VARIABLE output)
if (NOT ${result} EQUAL 0)
message(FATAL_ERROR "Could not create symbolic link for: ${target} --> ${output}")
endif()
endforeach(path_file)
endfunction(npcomp_python_create_symlinks)

@ -1 +1 @@
Subproject commit 75ae846de69cccd6ed66357f3ee3ad3301849d95
Subproject commit ad958f648e46680966375a93a3f2f1f5ee870671

View File

@ -20,4 +20,5 @@ if(NPCOMP_ENABLE_TORCH_TYPE_DISPATCH)
endif()
add_subdirectory(csrc)
add_subdirectory(python)
add_subdirectory(test)

View File

@ -7,6 +7,7 @@
#include "module_builder.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
@ -15,38 +16,49 @@
namespace py = pybind11;
using namespace torch_mlir;
namespace {
/// Accumulates into a python string from a method that accepts an
/// MlirStringCallback.
/// TODO: Remove this once the MLIR Python objects are exposed directly.
struct PyPrintAccumulator {
py::list parts;
static py::object getMlirIrClass(const char *className) {
// Note that the "mlir" module may be a loader which internally sets up
// the child modules, so it must be resolved incrementally (vs "mlir.ir").
return py::module::import("mlir").attr("ir").attr(className);
}
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
py::str pyPart(part, size); // Decodes as UTF-8 by default.
printAccum->parts.append(std::move(pyPart));
};
static py::object createPythonContextIfNone(py::object contextObj) {
if (contextObj.is_none()) {
contextObj = getMlirIrClass("Context")();
}
return contextObj;
}
py::str join() {
py::str delim("", 0);
return delim.attr("join")(parts);
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
assert(!contextObj.is_none() && "context cannot be None");
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
if (mlirContextIsNull(context)) {
// An error will have already been set by the above.
throw py::error_already_set();
}
};
} // namespace
return context;
}
ModuleBuilder::ModuleBuilder()
// TODO: Once the MLIR C/Python capsule API is in place, these should be
// derived from Python level objects (which will provide better lifetime
// semantics and interop). Until then, they are just scoped to this instance
// and must not escape.
: context(mlirContextCreate()), unknownLoc(mlirLocationUnknownGet(context)),
module(mlirModuleCreateEmpty(unknownLoc)), typeMapper(context) {
static py::object castMlirModuleToPythonObject(MlirModule module) {
auto moduleClass = getMlirIrClass("Module");
auto moduleCapsule =
py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(module));
return moduleClass.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(moduleCapsule);
}
static MlirModule createEmptyModule(MlirContext context) {
// TODO: Extract location from backtrace.
MlirLocation loc = mlirLocationUnknownGet(context);
return mlirModuleCreateEmpty(loc);
}
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
: contextObj(createPythonContextIfNone(std::move(contextObj))),
context(castPythonObjectToMlirContext(this->contextObj)),
module(createEmptyModule(this->context)),
moduleObj(castMlirModuleToPythonObject(module)),
unknownLoc(mlirLocationUnknownGet(context)), typeMapper(this->context) {
// TODO: Rework this once dialect registration C-APIs are in place.
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(context);
@ -56,19 +68,6 @@ ModuleBuilder::ModuleBuilder()
terminator = mlirBlockGetFirstOperation(getBodyBlock());
}
ModuleBuilder::~ModuleBuilder() {
mlirModuleDestroy(module);
mlirContextDestroy(context);
}
py::str ModuleBuilder::getAsm() {
MlirOperation operation = mlirModuleGetOperation(module);
PyPrintAccumulator printAccum;
mlirOperationPrint(operation, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
}
std::shared_ptr<AcapController>
ModuleBuilder::startCaptureFunction(std::string &name,
std::vector<at::Tensor> args) {
@ -103,8 +102,9 @@ MlirBlock ModuleBuilder::getBodyBlock() {
void ModuleBuilder::bind(py::module &m) {
py::class_<ModuleBuilder>(m, "ModuleBuilder")
.def(py::init<>())
.def("__str__", &ModuleBuilder::getAsm)
.def(py::init<py::object>(), py::arg("context") = py::none())
.def_property_readonly("context", &ModuleBuilder::getContextObj)
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
.def("capture_function", &ModuleBuilder::startCaptureFunction,
py::keep_alive<0, 1>());
}

View File

@ -8,7 +8,6 @@
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
// TODO: Remove this dep once the getAsm() method is removed.
#include "../pybind.h"
#include "acap_dispatch.h"
@ -24,14 +23,13 @@ namespace torch_mlir {
/// of PyTorch programs/execution.
class ModuleBuilder {
public:
ModuleBuilder();
~ModuleBuilder();
ModuleBuilder(pybind11::object contextObj);
/// Creates Python bindings for the class.
static void bind(pybind11::module &m);
// TODO: Remove this once the MLIR Python objects are exposed directly.
pybind11::str getAsm();
pybind11::object getContextObj() { return contextObj; }
pybind11::object getModuleObj() { return moduleObj; }
// Starts a device-capture based function.
// TODO: Add inputs.
@ -41,10 +39,15 @@ public:
private:
MlirBlock getBodyBlock();
// Capture references to the python-owned context and module. Ownership
// is delegated to python for these, and the C-API types are extracted via
// the capsule API.
pybind11::object contextObj;
MlirContext context;
MlirLocation unknownLoc;
MlirModule module;
pybind11::object moduleObj;
MlirOperation terminator;
MlirLocation unknownLoc;
TypeMapper typeMapper;
};

View File

@ -0,0 +1,6 @@
################################################################################
# Manage python source files
# Collapse all local python sources to the project level python/ directory.
################################################################################
npcomp_python_create_symlinks(${CMAKE_BINARY_DIR}/python ${CMAKE_CURRENT_SOURCE_DIR})

View File

@ -0,0 +1,14 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This is a trampoline module which loads the _torch_mlir native module
# and binds names locally. It exists to allow for customization of behavior
# prior to loading shared objects.
from _torch_mlir import ModuleBuilder
__all__ = [
"ModuleBuilder",
]

View File

@ -6,12 +6,12 @@
# TODO: Once stabilized, expand tests to include all argument dtypes.
import torch
import _torch_mlir as m
import torch_mlir
t0 = torch.randn((1,4))
t1 = torch.randn((4,1))
mb = m.ModuleBuilder()
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("foobar", [t0, t1]) as f:
result = t0 + t1
f.returns([result])
@ -23,7 +23,7 @@ with mb.capture_function("foobar", [t0, t1]) as f:
# CHECK: return %0 : !numpy.ndarray<[4,4]:f32>
# CHECK: }
# CHECK: }
print(mb)
print(mb.module)
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
for line in f.get_debug_log(): print(line)

View File

@ -19,51 +19,12 @@ add_dependencies(NPCOMPPythonResources
################################################################################
# Manage python source files
################################################################################
function (create_symlinks)
# Do nothing if building in-source
if (${CMAKE_CURRENT_BINARY_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR})
return()
endif()
foreach (path_file ${ARGN})
get_filename_component(folder ${path_file} PATH)
# Create REAL folder
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${folder}")
# Delete symlink if it exists
file(REMOVE "${CMAKE_CURRENT_BINARY_DIR}/${path_file}")
# Get OS dependent path to use in `execute_process`
file(TO_NATIVE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${path_file}" link)
file(TO_NATIVE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${path_file}" target)
if (UNIX)
set(command ln -s ${target} ${link})
else()
set(command cmd.exe /c mklink ${link} ${target})
endif()
execute_process(COMMAND ${command}
RESULT_VARIABLE result
ERROR_VARIABLE output)
if (NOT ${result} EQUAL 0)
message(FATAL_ERROR "Could not create symbolic link for: ${target} --> ${output}")
endif()
endforeach(path_file)
endfunction(create_symlinks)
file(GLOB_RECURSE python_files RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.py)
create_symlinks(${python_files})
npcomp_python_create_symlinks(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
################################################################################
# Native extensions
################################################################################
include(NpcompPython)
# Normally on unix-like platforms, extensions are built as "MODULE" libraries
# and do not explicitly link to the python shared object. This allows for
# come greater deployment flexibility since the extension will bind to