2020-12-30 05:22:18 +08:00
|
|
|
//===- NpcompModule.cpp - MLIR Python bindings ----------------------------===//
|
2020-04-27 06:50:23 +08:00
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include <cstddef>
|
|
|
|
#include <unordered_map>
|
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
|
|
#include "mlir-c/Diagnostics.h"
|
|
|
|
#include "npcomp-c/InitLLVM.h"
|
2020-11-11 13:38:13 +08:00
|
|
|
#include "npcomp-c/Registration.h"
|
2020-12-30 05:22:18 +08:00
|
|
|
#include "npcomp-c/Types.h"
|
|
|
|
#include "npcomp/Python/PybindUtils.h"
|
2020-05-01 07:00:00 +08:00
|
|
|
|
2020-07-11 08:36:32 +08:00
|
|
|
#ifdef NPCOMP_ENABLE_REFJIT
|
|
|
|
#include "npcomp/Backend/RefJIT/PythonModule.h"
|
|
|
|
#endif
|
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
namespace {
|
2020-05-01 07:00:00 +08:00
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
MlirType shapedToNdArrayArrayType(MlirType shaped_type) {
|
|
|
|
if (!mlirTypeIsAShaped(shaped_type)) {
|
|
|
|
throw py::raiseValueError("type is not a shaped type");
|
|
|
|
}
|
|
|
|
return npcompNdArrayTypeGetFromShaped(shaped_type);
|
2020-05-01 07:00:00 +08:00
|
|
|
}
|
2020-04-27 06:50:23 +08:00
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
MlirType ndarrayToTensorType(MlirType ndarray_type) {
|
|
|
|
if (!npcompTypeIsANdArray(ndarray_type)) {
|
|
|
|
throw py::raiseValueError("type is not an ndarray type");
|
|
|
|
}
|
|
|
|
return npcompNdArrayTypeToTensor(ndarray_type);
|
2020-11-11 13:38:13 +08:00
|
|
|
}
|
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
MlirType slotObjectType(MlirContext context, const std::string &className,
|
|
|
|
const std::vector<MlirType> &slotTypes) {
|
|
|
|
MlirStringRef classNameSr{className.data(), className.size()};
|
|
|
|
return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(),
|
|
|
|
slotTypes.data());
|
|
|
|
}
|
2020-04-27 06:50:23 +08:00
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
// TODO: Move this upstream.
|
|
|
|
void emitError(MlirLocation loc, std::string message) {
|
|
|
|
::mlirEmitError(loc, message.c_str());
|
|
|
|
}
|
2020-05-01 07:00:00 +08:00
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
} // namespace
|
2020-05-07 09:24:51 +08:00
|
|
|
|
2020-12-30 05:22:18 +08:00
|
|
|
PYBIND11_MODULE(_npcomp, m) {
  m.doc() = "Npcomp native python bindings";

  // Global registration / initialization entry points.
  m.def("register_all_dialects", ::npcompRegisterAllDialects);
  m.def("_register_all_passes", ::npcompRegisterAllPasses);
  m.def("_initialize_llvm_codegen", ::npcompInitializeLLVMCodegen);

  // Type construction and conversion helpers (defined above in this file).
  m.def("shaped_to_ndarray_type", shapedToNdArrayArrayType);
  m.def("ndarray_to_tensor_type", ndarrayToTensorType);
  m.def("slot_object_type", slotObjectType);

  // Diagnostics.
  m.def("emit_error", emitError);

  // Optional backend modules. The `backend` submodule is always created, even
  // when no backend is compiled in; individual backends attach beneath it.
  auto backendModule = m.def_submodule("backend", "Backend support");
#ifdef NPCOMP_ENABLE_REFJIT
  auto refjitModule =
      backendModule.def_submodule("refjit", "Reference CPU Jit Backend");
  ::npcomp::python::defineBackendRefJitModule(refjitModule);
#else
  // Nothing attaches to the backend submodule in this configuration.
  (void)backendModule;
#endif
}
|