mirror of https://github.com/llvm/torch-mlir
Create public API for torch_mlir python code.
* Adds a trampoline/loader 'torch_mlir' module. * Plumbs through the MLIR python Context and Module creation, interoping with the MLIR Python API (resolves TODO on creating with own context and accessing the module being built). * Inter-module Python API interop is still a bit rough but workable via the capsule mechanism. Can be evolved later. * Exports the frontends/pytorch python sources to the project python/ build directory. * Requires D89294 to land.pull/81/head
parent
86df4cabeb
commit
30cfc6499f
|
@ -12,3 +12,33 @@ function(npcomp_python_target_compile_options target)
|
||||||
/EHsc /GR>
|
/EHsc /GR>
|
||||||
)
|
)
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
|
function(npcomp_python_create_symlinks binary_dir source_dir)
|
||||||
|
# Do nothing if building in-source
|
||||||
|
if (${binary_dir} STREQUAL ${source_dir})
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
file(GLOB_RECURSE python_files RELATIVE ${source_dir} *.py)
|
||||||
|
foreach (path_file ${python_files})
|
||||||
|
get_filename_component(folder ${path_file} PATH)
|
||||||
|
|
||||||
|
# Create REAL folder
|
||||||
|
file(MAKE_DIRECTORY "${binary_dir}/${folder}")
|
||||||
|
|
||||||
|
# Get OS dependent path to use in `execute_process`
|
||||||
|
file(TO_NATIVE_PATH "${binary_dir}/${path_file}" link)
|
||||||
|
file(TO_NATIVE_PATH "${source_dir}/${path_file}" target)
|
||||||
|
|
||||||
|
# TODO: Switch to copy on windows if symlink still not supported by
|
||||||
|
# then.
|
||||||
|
set(cmake_verb create_symlink)
|
||||||
|
execute_process(COMMAND ${CMAKE_COMMAND} -E ${cmake_verb} ${target} ${link}
|
||||||
|
RESULT_VARIABLE result
|
||||||
|
ERROR_VARIABLE output)
|
||||||
|
|
||||||
|
if (NOT ${result} EQUAL 0)
|
||||||
|
message(FATAL_ERROR "Could not create symbolic link for: ${target} --> ${output}")
|
||||||
|
endif()
|
||||||
|
endforeach(path_file)
|
||||||
|
endfunction(npcomp_python_create_symlinks)
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 75ae846de69cccd6ed66357f3ee3ad3301849d95
|
Subproject commit ad958f648e46680966375a93a3f2f1f5ee870671
|
|
@ -20,4 +20,5 @@ if(NPCOMP_ENABLE_TORCH_TYPE_DISPATCH)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(csrc)
|
add_subdirectory(csrc)
|
||||||
|
add_subdirectory(python)
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
#include "module_builder.h"
|
#include "module_builder.h"
|
||||||
|
|
||||||
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
#include "mlir-c/Registration.h"
|
#include "mlir-c/Registration.h"
|
||||||
#include "mlir-c/StandardAttributes.h"
|
#include "mlir-c/StandardAttributes.h"
|
||||||
#include "mlir-c/StandardTypes.h"
|
#include "mlir-c/StandardTypes.h"
|
||||||
|
@ -15,38 +16,49 @@
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
namespace {
|
static py::object getMlirIrClass(const char *className) {
|
||||||
/// Accumulates into a python string from a method that accepts an
|
// Note that the "mlir" module may be a loader which internally sets up
|
||||||
/// MlirStringCallback.
|
// the child modules, so it must be resolved incrementally (vs "mlir.ir").
|
||||||
/// TODO: Remove this once the MLIR Python objects are exposed directly.
|
return py::module::import("mlir").attr("ir").attr(className);
|
||||||
struct PyPrintAccumulator {
|
|
||||||
py::list parts;
|
|
||||||
|
|
||||||
void *getUserData() { return this; }
|
|
||||||
|
|
||||||
MlirStringCallback getCallback() {
|
|
||||||
return [](const char *part, intptr_t size, void *userData) {
|
|
||||||
PyPrintAccumulator *printAccum =
|
|
||||||
static_cast<PyPrintAccumulator *>(userData);
|
|
||||||
py::str pyPart(part, size); // Decodes as UTF-8 by default.
|
|
||||||
printAccum->parts.append(std::move(pyPart));
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
py::str join() {
|
static py::object createPythonContextIfNone(py::object contextObj) {
|
||||||
py::str delim("", 0);
|
if (contextObj.is_none()) {
|
||||||
return delim.attr("join")(parts);
|
contextObj = getMlirIrClass("Context")();
|
||||||
|
}
|
||||||
|
return contextObj;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
ModuleBuilder::ModuleBuilder()
|
static MlirContext castPythonObjectToMlirContext(py::object &contextObj) {
|
||||||
// TODO: Once the MLIR C/Python capsule API is in place, these should be
|
assert(!contextObj.is_none() && "context cannot be None");
|
||||||
// derived from Python level objects (which will provide better lifetime
|
auto contextCapsule = contextObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||||
// semantics and interop). Until then, they are just scoped to this instance
|
MlirContext context = mlirPythonCapsuleToContext(contextCapsule.ptr());
|
||||||
// and must not escape.
|
if (mlirContextIsNull(context)) {
|
||||||
: context(mlirContextCreate()), unknownLoc(mlirLocationUnknownGet(context)),
|
// An error will have already been set by the above.
|
||||||
module(mlirModuleCreateEmpty(unknownLoc)), typeMapper(context) {
|
throw py::error_already_set();
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
static py::object castMlirModuleToPythonObject(MlirModule module) {
|
||||||
|
auto moduleClass = getMlirIrClass("Module");
|
||||||
|
auto moduleCapsule =
|
||||||
|
py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(module));
|
||||||
|
return moduleClass.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(moduleCapsule);
|
||||||
|
}
|
||||||
|
|
||||||
|
static MlirModule createEmptyModule(MlirContext context) {
|
||||||
|
// TODO: Extract location from backtrace.
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
return mlirModuleCreateEmpty(loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
|
: contextObj(createPythonContextIfNone(std::move(contextObj))),
|
||||||
|
context(castPythonObjectToMlirContext(this->contextObj)),
|
||||||
|
module(createEmptyModule(this->context)),
|
||||||
|
moduleObj(castMlirModuleToPythonObject(module)),
|
||||||
|
unknownLoc(mlirLocationUnknownGet(context)), typeMapper(this->context) {
|
||||||
// TODO: Rework this once dialect registration C-APIs are in place.
|
// TODO: Rework this once dialect registration C-APIs are in place.
|
||||||
// https://reviews.llvm.org/D88162
|
// https://reviews.llvm.org/D88162
|
||||||
mlirRegisterAllDialects(context);
|
mlirRegisterAllDialects(context);
|
||||||
|
@ -56,19 +68,6 @@ ModuleBuilder::ModuleBuilder()
|
||||||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||||
}
|
}
|
||||||
|
|
||||||
ModuleBuilder::~ModuleBuilder() {
|
|
||||||
mlirModuleDestroy(module);
|
|
||||||
mlirContextDestroy(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
py::str ModuleBuilder::getAsm() {
|
|
||||||
MlirOperation operation = mlirModuleGetOperation(module);
|
|
||||||
PyPrintAccumulator printAccum;
|
|
||||||
mlirOperationPrint(operation, printAccum.getCallback(),
|
|
||||||
printAccum.getUserData());
|
|
||||||
return printAccum.join();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<AcapController>
|
std::shared_ptr<AcapController>
|
||||||
ModuleBuilder::startCaptureFunction(std::string &name,
|
ModuleBuilder::startCaptureFunction(std::string &name,
|
||||||
std::vector<at::Tensor> args) {
|
std::vector<at::Tensor> args) {
|
||||||
|
@ -103,8 +102,9 @@ MlirBlock ModuleBuilder::getBodyBlock() {
|
||||||
|
|
||||||
void ModuleBuilder::bind(py::module &m) {
|
void ModuleBuilder::bind(py::module &m) {
|
||||||
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
py::class_<ModuleBuilder>(m, "ModuleBuilder")
|
||||||
.def(py::init<>())
|
.def(py::init<py::object>(), py::arg("context") = py::none())
|
||||||
.def("__str__", &ModuleBuilder::getAsm)
|
.def_property_readonly("context", &ModuleBuilder::getContextObj)
|
||||||
|
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
|
||||||
.def("capture_function", &ModuleBuilder::startCaptureFunction,
|
.def("capture_function", &ModuleBuilder::startCaptureFunction,
|
||||||
py::keep_alive<0, 1>());
|
py::keep_alive<0, 1>());
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
||||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
||||||
|
|
||||||
// TODO: Remove this dep once the getAsm() method is removed.
|
|
||||||
#include "../pybind.h"
|
#include "../pybind.h"
|
||||||
|
|
||||||
#include "acap_dispatch.h"
|
#include "acap_dispatch.h"
|
||||||
|
@ -24,14 +23,13 @@ namespace torch_mlir {
|
||||||
/// of PyTorch programs/execution.
|
/// of PyTorch programs/execution.
|
||||||
class ModuleBuilder {
|
class ModuleBuilder {
|
||||||
public:
|
public:
|
||||||
ModuleBuilder();
|
ModuleBuilder(pybind11::object contextObj);
|
||||||
~ModuleBuilder();
|
|
||||||
|
|
||||||
/// Creates Python bindings for the class.
|
/// Creates Python bindings for the class.
|
||||||
static void bind(pybind11::module &m);
|
static void bind(pybind11::module &m);
|
||||||
|
|
||||||
// TODO: Remove this once the MLIR Python objects are exposed directly.
|
pybind11::object getContextObj() { return contextObj; }
|
||||||
pybind11::str getAsm();
|
pybind11::object getModuleObj() { return moduleObj; }
|
||||||
|
|
||||||
// Starts a device-capture based function.
|
// Starts a device-capture based function.
|
||||||
// TODO: Add inputs.
|
// TODO: Add inputs.
|
||||||
|
@ -41,10 +39,15 @@ public:
|
||||||
private:
|
private:
|
||||||
MlirBlock getBodyBlock();
|
MlirBlock getBodyBlock();
|
||||||
|
|
||||||
|
// Capture references to the python-owned context and module. Ownership
|
||||||
|
// is delegated to python for these, and the C-API types are extracted via
|
||||||
|
// the capsule API.
|
||||||
|
pybind11::object contextObj;
|
||||||
MlirContext context;
|
MlirContext context;
|
||||||
MlirLocation unknownLoc;
|
|
||||||
MlirModule module;
|
MlirModule module;
|
||||||
|
pybind11::object moduleObj;
|
||||||
MlirOperation terminator;
|
MlirOperation terminator;
|
||||||
|
MlirLocation unknownLoc;
|
||||||
|
|
||||||
TypeMapper typeMapper;
|
TypeMapper typeMapper;
|
||||||
};
|
};
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
################################################################################
|
||||||
|
# Manage python source files
|
||||||
|
# Collapse all local python sources to the project level python/ directory.
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
npcomp_python_create_symlinks(${CMAKE_BINARY_DIR}/python ${CMAKE_CURRENT_SOURCE_DIR})
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
# This is a trampoline module which loads the _torch_mlir native module
|
||||||
|
# and binds names locally. It exists to allow for customization of behavior
|
||||||
|
# prior to loading shared objects.
|
||||||
|
|
||||||
|
from _torch_mlir import ModuleBuilder
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModuleBuilder",
|
||||||
|
]
|
|
@ -6,12 +6,12 @@
|
||||||
# TODO: Once stabilized, expand tests to include all argument dtypes.
|
# TODO: Once stabilized, expand tests to include all argument dtypes.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import _torch_mlir as m
|
import torch_mlir
|
||||||
|
|
||||||
t0 = torch.randn((1,4))
|
t0 = torch.randn((1,4))
|
||||||
t1 = torch.randn((4,1))
|
t1 = torch.randn((4,1))
|
||||||
|
|
||||||
mb = m.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
with mb.capture_function("foobar", [t0, t1]) as f:
|
with mb.capture_function("foobar", [t0, t1]) as f:
|
||||||
result = t0 + t1
|
result = t0 + t1
|
||||||
f.returns([result])
|
f.returns([result])
|
||||||
|
@ -23,7 +23,7 @@ with mb.capture_function("foobar", [t0, t1]) as f:
|
||||||
# CHECK: return %0 : !numpy.ndarray<[4,4]:f32>
|
# CHECK: return %0 : !numpy.ndarray<[4,4]:f32>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
print(mb)
|
print(mb.module)
|
||||||
|
|
||||||
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
|
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
|
||||||
for line in f.get_debug_log(): print(line)
|
for line in f.get_debug_log(): print(line)
|
||||||
|
|
|
@ -19,51 +19,12 @@ add_dependencies(NPCOMPPythonResources
|
||||||
################################################################################
|
################################################################################
|
||||||
# Manage python source files
|
# Manage python source files
|
||||||
################################################################################
|
################################################################################
|
||||||
function (create_symlinks)
|
npcomp_python_create_symlinks(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
# Do nothing if building in-source
|
|
||||||
if (${CMAKE_CURRENT_BINARY_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR})
|
|
||||||
return()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
foreach (path_file ${ARGN})
|
|
||||||
get_filename_component(folder ${path_file} PATH)
|
|
||||||
|
|
||||||
# Create REAL folder
|
|
||||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${folder}")
|
|
||||||
|
|
||||||
# Delete symlink if it exists
|
|
||||||
file(REMOVE "${CMAKE_CURRENT_BINARY_DIR}/${path_file}")
|
|
||||||
|
|
||||||
# Get OS dependent path to use in `execute_process`
|
|
||||||
file(TO_NATIVE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${path_file}" link)
|
|
||||||
file(TO_NATIVE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${path_file}" target)
|
|
||||||
|
|
||||||
if (UNIX)
|
|
||||||
set(command ln -s ${target} ${link})
|
|
||||||
else()
|
|
||||||
set(command cmd.exe /c mklink ${link} ${target})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
execute_process(COMMAND ${command}
|
|
||||||
RESULT_VARIABLE result
|
|
||||||
ERROR_VARIABLE output)
|
|
||||||
|
|
||||||
if (NOT ${result} EQUAL 0)
|
|
||||||
message(FATAL_ERROR "Could not create symbolic link for: ${target} --> ${output}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
endforeach(path_file)
|
|
||||||
endfunction(create_symlinks)
|
|
||||||
|
|
||||||
file(GLOB_RECURSE python_files RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.py)
|
|
||||||
create_symlinks(${python_files})
|
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Native extensions
|
# Native extensions
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
include(NpcompPython)
|
|
||||||
|
|
||||||
# Normally on unix-like platforms, extensions are built as "MODULE" libraries
|
# Normally on unix-like platforms, extensions are built as "MODULE" libraries
|
||||||
# and do not explicitly link to the python shared object. This allows for
|
# and do not explicitly link to the python shared object. This allows for
|
||||||
# come greater deployment flexibility since the extension will bind to
|
# come greater deployment flexibility since the extension will bind to
|
||||||
|
|
Loading…
Reference in New Issue