mirror of https://github.com/llvm/torch-mlir
Merge npcomp and mlir python namespaces.
* Now the parts of the MLIR API are directly exported under the npcomp module (i.e. `npcomp.ir`, etc). * Has required fixes for https://reviews.llvm.org/D108489 * Deletes npcomp.tracing vs fixing it because it was a very early experiment that will not be carried forward. * This makes the npcomp python distribution completely standalone and separate from an mlir installation. * Makes most of npcomp itself relocatable for future use as a library. * Most things are a namespace package now. In the future we can s/torch_mlir/npcomp.frontends.torch/ and have it layer properly.pull/289/head
parent
177ccdd55b
commit
4148f88576
|
@ -1,5 +1,9 @@
|
||||||
include(NpcompPython)
|
include(NpcompPython)
|
||||||
|
|
||||||
|
# TODO: Add this to an npcomp header of some kind so it doesn't need to be
|
||||||
|
# passed loose.
|
||||||
|
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=npcomp.")
|
||||||
|
|
||||||
# Sharp edge: Torch extensions need to use the same pybind11 that torch
|
# Sharp edge: Torch extensions need to use the same pybind11 that torch
|
||||||
# was compiled with, or else there will be issues in cross module exception
|
# was compiled with, or else there will be issues in cross module exception
|
||||||
# handling (which will abort instead of raise). We circumvent the possibility
|
# handling (which will abort instead of raise). We circumvent the possibility
|
||||||
|
|
|
@ -21,7 +21,7 @@ namespace py = pybind11;
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
static py::object getMlirIrClass(const char *className) {
|
static py::object getMlirIrClass(const char *className) {
|
||||||
return py::module::import("mlir.ir").attr(className);
|
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr(className);
|
||||||
}
|
}
|
||||||
|
|
||||||
static py::object createPythonContextIfNone(py::object contextObj) {
|
static py::object createPythonContextIfNone(py::object contextObj) {
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
# prior to loading shared objects.
|
# prior to loading shared objects.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import npcomp
|
import npcomp._mlir_libs._npcomp
|
||||||
|
|
||||||
# Our native extension is not self-contained. It references libraries which
|
# Our native extension is not self-contained. It references libraries which
|
||||||
# must come in via the above first.
|
# must come in via the above first.
|
||||||
|
|
|
@ -10,9 +10,9 @@ import tempfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mlir.passmanager import PassManager
|
|
||||||
|
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
from npcomp.passmanager import PassManager
|
||||||
from npcomp.compiler.pytorch.backend import refjit
|
from npcomp.compiler.pytorch.backend import refjit
|
||||||
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
||||||
from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
|
@ -22,6 +22,9 @@ extension_sources = [str(p) for p in this_dir.joinpath("csrc").rglob("*.cpp")]
|
||||||
include_dirs = npcomp_build.get_include_dirs()
|
include_dirs = npcomp_build.get_include_dirs()
|
||||||
lib_dirs = npcomp_build.get_lib_dirs()
|
lib_dirs = npcomp_build.get_lib_dirs()
|
||||||
npcomp_libs = [npcomp_build.get_capi_link_library_name()]
|
npcomp_libs = [npcomp_build.get_capi_link_library_name()]
|
||||||
|
# TODO: Export this in some way from an npcomp config file include vs needing
|
||||||
|
# it loose here.
|
||||||
|
compile_args = ["-DMLIR_PYTHON_PACKAGE_PREFIX=npcomp."]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="npcomp-torch",
|
name="npcomp-torch",
|
||||||
|
@ -31,7 +34,8 @@ setup(
|
||||||
sources=extension_sources,
|
sources=extension_sources,
|
||||||
include_dirs=include_dirs,
|
include_dirs=include_dirs,
|
||||||
library_dirs=lib_dirs,
|
library_dirs=lib_dirs,
|
||||||
libraries=npcomp_libs),
|
libraries=npcomp_libs,
|
||||||
|
extra_compile_args=compile_args),
|
||||||
],
|
],
|
||||||
cmdclass={
|
cmdclass={
|
||||||
"build_ext": cpp_extension.BuildExtension
|
"build_ext": cpp_extension.BuildExtension
|
||||||
|
|
|
@ -10,8 +10,7 @@ import sys
|
||||||
|
|
||||||
print(f"PYTHONPATH={sys.path}")
|
print(f"PYTHONPATH={sys.path}")
|
||||||
|
|
||||||
import mlir
|
import npcomp.ir
|
||||||
import npcomp
|
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
|
||||||
print("Extensions all loaded")
|
print("Extensions all loaded")
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
include(AddMLIRPython)
|
include(AddMLIRPython)
|
||||||
include(MLIRDetectPythonEnv)
|
include(MLIRDetectPythonEnv)
|
||||||
|
|
||||||
|
# Specifies that all MLIR packages are co-located under npcomp.
|
||||||
|
# TODO: Add an upstream cmake param for this vs having a global here.
|
||||||
|
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=npcomp.")
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Resources that must be packaged into the python tree
|
# Resources that must be packaged into the python tree
|
||||||
################################################################################
|
################################################################################
|
||||||
|
@ -28,21 +32,18 @@ declare_mlir_python_sources(NPCOMPPythonSources)
|
||||||
|
|
||||||
declare_mlir_python_sources(NPCOMPPythonSources.Core
|
declare_mlir_python_sources(NPCOMPPythonSources.Core
|
||||||
ADD_TO_PARENT NPCOMPPythonSources
|
ADD_TO_PARENT NPCOMPPythonSources
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/npcomp"
|
||||||
SOURCES
|
SOURCES
|
||||||
npcomp/__init__.py
|
build.py
|
||||||
npcomp/build.py
|
decorators.py
|
||||||
npcomp/decorators.py
|
exporter.py
|
||||||
npcomp/exporter.py
|
smoketest.py
|
||||||
npcomp/smoketest.py
|
types.py
|
||||||
npcomp/types.py
|
|
||||||
npcomp/dialects/_ods_common.py
|
|
||||||
SOURCES_GLOB
|
SOURCES_GLOB
|
||||||
npcomp/compiler/*.py
|
compiler/*.py
|
||||||
npcomp/frontends/*.py
|
frontends/*.py
|
||||||
npcomp/torch/*.py
|
torch/*.py
|
||||||
npcomp/tracing/*.py
|
utils/*.py
|
||||||
npcomp/utils/*.py
|
|
||||||
)
|
)
|
||||||
|
|
||||||
declare_mlir_python_sources(NPCOMPPythonSources.Dialects
|
declare_mlir_python_sources(NPCOMPPythonSources.Dialects
|
||||||
|
@ -83,23 +84,23 @@ declare_mlir_python_extension(NPCOMPPythonExtensions.Core
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT NPCOMPPythonSources.Dialects
|
ADD_TO_PARENT NPCOMPPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/npcomp"
|
||||||
TD_FILE npcomp/dialects/BasicpyBind.td
|
TD_FILE dialects/BasicpyBind.td
|
||||||
SOURCES npcomp/dialects/basicpy.py
|
SOURCES dialects/basicpy.py
|
||||||
DIALECT_NAME basicpy)
|
DIALECT_NAME basicpy)
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT NPCOMPPythonSources.Dialects
|
ADD_TO_PARENT NPCOMPPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/npcomp"
|
||||||
TD_FILE npcomp/dialects/NumpyBind.td
|
TD_FILE dialects/NumpyBind.td
|
||||||
SOURCES npcomp/dialects/numpy.py
|
SOURCES dialects/numpy.py
|
||||||
DIALECT_NAME numpy)
|
DIALECT_NAME numpy)
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
declare_mlir_dialect_python_bindings(
|
||||||
ADD_TO_PARENT NPCOMPPythonSources.Dialects
|
ADD_TO_PARENT NPCOMPPythonSources.Dialects
|
||||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/npcomp"
|
||||||
TD_FILE npcomp/dialects/TorchBind.td
|
TD_FILE dialects/TorchBind.td
|
||||||
SOURCES npcomp/dialects/torch.py
|
SOURCES dialects/torch.py
|
||||||
DIALECT_NAME torch)
|
DIALECT_NAME torch)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
|
@ -109,44 +110,30 @@ declare_mlir_dialect_python_bindings(
|
||||||
# Bundle our own, self-contained CAPI library with all of our deps.
|
# Bundle our own, self-contained CAPI library with all of our deps.
|
||||||
add_mlir_python_common_capi_library(NPCOMPPythonCAPI
|
add_mlir_python_common_capi_library(NPCOMPPythonCAPI
|
||||||
INSTALL_COMPONENT NPCOMPPythonModules
|
INSTALL_COMPONENT NPCOMPPythonModules
|
||||||
INSTALL_DESTINATION python_packages/npcomp_core/mlir/_mlir_libs
|
INSTALL_DESTINATION python_packages/npcomp_core/npcomp/_mlir_libs
|
||||||
# NOTE: When the MLIR API is relocated under npcomp, this would change to
|
OUTPUT_DIRECTORY "${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core/npcomp/_mlir_libs"
|
||||||
# .../npcomp/_mlir_libs
|
|
||||||
OUTPUT_DIRECTORY "${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core/mlir/_mlir_libs"
|
|
||||||
RELATIVE_INSTALL_ROOT "../../../.."
|
RELATIVE_INSTALL_ROOT "../../../.."
|
||||||
DECLARED_SOURCES
|
DECLARED_SOURCES
|
||||||
# TODO: This can be chopped down significantly for size.
|
# TODO: Common MLIR deps can be reduced substantially.
|
||||||
MLIRPythonSources
|
MLIRPythonSources.Core
|
||||||
|
MLIRPythonSources.Dialects
|
||||||
|
MLIRPythonSources.ExecutionEngine
|
||||||
MLIRPythonExtension.AllPassesRegistration
|
MLIRPythonExtension.AllPassesRegistration
|
||||||
NPCOMPPythonSources
|
NPCOMPPythonSources
|
||||||
NPCOMPPythonExtensions
|
NPCOMPPythonExtensions
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bundle the MLIR python sources into our package.
|
# Bundle MLIR and NPCOMP into the top-level npcomp package.
|
||||||
# The MLIR API is position independent, so we explicitly output it to the mlir/
|
add_mlir_python_modules(NPCOMPPythonModules
|
||||||
# folder as a temporary measure. It will eventually migrate under the npcomp/
|
ROOT_PREFIX "${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core/npcomp"
|
||||||
# folder and be accessible under the unified "import npcomp..." namespace.
|
INSTALL_PREFIX "python_packages/npcomp_core/npcomp"
|
||||||
add_mlir_python_modules(NPCOMPMLIRPythonModules
|
|
||||||
ROOT_PREFIX "${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core/mlir"
|
|
||||||
INSTALL_PREFIX "python_packages/npcomp_core/mlir"
|
|
||||||
DECLARED_SOURCES
|
DECLARED_SOURCES
|
||||||
MLIRPythonSources
|
MLIRPythonSources
|
||||||
MLIRPythonExtension.AllPassesRegistration
|
MLIRPythonExtension.AllPassesRegistration
|
||||||
# We need the npcomp extensions co-located with the MLIR extensions. When
|
|
||||||
# the namespace is unified, this moves to the below.
|
|
||||||
MLIRPythonCAPIHeaderSources
|
MLIRPythonCAPIHeaderSources
|
||||||
|
NPCOMPPythonSources
|
||||||
NPCOMPPythonExtensions
|
NPCOMPPythonExtensions
|
||||||
NPCOMPPythonCAPIHeaderSources
|
NPCOMPPythonCAPIHeaderSources
|
||||||
COMMON_CAPI_LINK_LIBS
|
COMMON_CAPI_LINK_LIBS
|
||||||
NPCOMPPythonCAPI
|
NPCOMPPythonCAPI
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bundle the NPCOMP python sources into our package.
|
|
||||||
add_mlir_python_modules(NPCOMPPythonModules
|
|
||||||
ROOT_PREFIX "${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core"
|
|
||||||
INSTALL_PREFIX "python_packages/npcomp_core"
|
|
||||||
DECLARED_SOURCES
|
|
||||||
NPCOMPPythonSources
|
|
||||||
COMMON_CAPI_LINK_LIBS
|
|
||||||
NPCOMPPythonCAPI
|
|
||||||
)
|
|
||||||
|
|
|
@ -52,10 +52,10 @@ void emitError(MlirLocation loc, std::string message) {
|
||||||
|
|
||||||
PYBIND11_MODULE(_npcomp, m) {
|
PYBIND11_MODULE(_npcomp, m) {
|
||||||
m.doc() = "Npcomp native python bindings";
|
m.doc() = "Npcomp native python bindings";
|
||||||
|
::npcompRegisterAllPasses();
|
||||||
|
::npcompInitializeLLVMCodegen();
|
||||||
|
|
||||||
m.def("register_all_dialects", ::npcompRegisterAllDialects);
|
m.def("register_all_dialects", ::npcompRegisterAllDialects);
|
||||||
m.def("_register_all_passes", ::npcompRegisterAllPasses);
|
|
||||||
m.def("_initialize_llvm_codegen", ::npcompInitializeLLVMCodegen);
|
|
||||||
m.def("shaped_to_ndarray_type", shapedToNdArrayArrayType);
|
m.def("shaped_to_ndarray_type", shapedToNdArrayArrayType);
|
||||||
m.def("ndarray_to_tensor_type", ndarrayToTensorType);
|
m.def("ndarray_to_tensor_type", ndarrayToTensorType);
|
||||||
m.def("slot_object_type", slotObjectType);
|
m.def("slot_object_type", slotObjectType);
|
||||||
|
|
|
@ -20,153 +20,11 @@
|
||||||
#include "mlir-c/Bindings/Python/Interop.h"
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
#include "mlir-c/Pass.h"
|
#include "mlir-c/Pass.h"
|
||||||
|
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||||
#include "llvm/ADT/Optional.h"
|
#include "llvm/ADT/Optional.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace pybind11 {
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
|
|
||||||
|
|
||||||
/// Helper to convert a presumed MLIR API object to a capsule, accepting either
|
|
||||||
/// an explicit Capsule (which can happen when two C APIs are communicating
|
|
||||||
/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
|
|
||||||
/// attribute (through which supported MLIR Python API objects export their
|
|
||||||
/// contained API pointer as a capsule). This is intended to be used from
|
|
||||||
/// type casters, which are invoked with a raw handle (unowned). The returned
|
|
||||||
/// object's lifetime may not extend beyond the apiObject handle without
|
|
||||||
/// explicitly having its refcount increased (i.e. on return).
|
|
||||||
static py::object mlirApiObjectToCapsule(py::handle apiObject) {
|
|
||||||
if (PyCapsule_CheckExact(apiObject.ptr()))
|
|
||||||
return py::reinterpret_borrow<py::object>(apiObject);
|
|
||||||
return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: Currently all of the following support cast from py::object to the
|
|
||||||
// Mlir* C-API type, but only a few light-weight, context-bound ones
|
|
||||||
// implicitly cast the other way because the use case has not yet emerged and
|
|
||||||
// ownership is unclear.
|
|
||||||
|
|
||||||
/// Casts object -> MlirAttribute.
|
|
||||||
template <> struct type_caster<MlirAttribute> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToAttribute(capsule.ptr());
|
|
||||||
if (mlirAttributeIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
static handle cast(MlirAttribute v, return_value_policy, handle) {
|
|
||||||
auto capsule =
|
|
||||||
py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
|
|
||||||
return py::module::import("mlir.ir")
|
|
||||||
.attr("Attribute")
|
|
||||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
|
||||||
.release();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Casts object -> MlirContext.
|
|
||||||
template <> struct type_caster<MlirContext> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToContext(capsule.ptr());
|
|
||||||
if (mlirContextIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Casts object -> MlirLocation.
|
|
||||||
template <> struct type_caster<MlirLocation> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToLocation(capsule.ptr());
|
|
||||||
if (mlirLocationIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
static handle cast(MlirLocation v, return_value_policy, handle) {
|
|
||||||
auto capsule =
|
|
||||||
py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
|
|
||||||
return py::module::import("mlir.ir")
|
|
||||||
.attr("Location")
|
|
||||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
|
||||||
.release();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Casts object -> MlirModule.
|
|
||||||
template <> struct type_caster<MlirModule> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToModule(capsule.ptr());
|
|
||||||
if (mlirModuleIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Casts object -> MlirOperation.
|
|
||||||
template <> struct type_caster<MlirOperation> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToOperation(capsule.ptr());
|
|
||||||
if (mlirOperationIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Casts object -> MlirPassManager.
|
|
||||||
template <> struct type_caster<MlirPassManager> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToPassManager(capsule.ptr());
|
|
||||||
if (mlirPassManagerIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Casts object -> MlirType.
|
|
||||||
template <> struct type_caster<MlirType> {
|
|
||||||
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
|
|
||||||
bool load(handle src, bool) {
|
|
||||||
auto capsule = mlirApiObjectToCapsule(src);
|
|
||||||
value = mlirPythonCapsuleToType(capsule.ptr());
|
|
||||||
if (mlirTypeIsNull(value)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
static handle cast(MlirType t, return_value_policy, handle) {
|
|
||||||
auto capsule =
|
|
||||||
py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
|
|
||||||
return py::module::import("mlir.ir")
|
|
||||||
.attr("Type")
|
|
||||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
|
||||||
.release();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace pybind11
|
|
||||||
|
|
||||||
namespace pybind11 {
|
namespace pybind11 {
|
||||||
|
|
||||||
/// Raises a python exception with the given message.
|
/// Raises a python exception with the given message.
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
from mlir import _cext_loader
|
|
||||||
_cext_loader._cext.globals.append_dialect_search_prefix("npcomp.dialects")
|
|
||||||
|
|
||||||
_cext = _cext_loader._load_extension("_npcomp")
|
|
||||||
_cext._register_all_passes()
|
|
||||||
_cext._initialize_llvm_codegen()
|
|
||||||
|
|
||||||
# Top-level symbols.
|
|
||||||
from .exporter import *
|
|
||||||
from .types import *
|
|
||||||
|
|
||||||
from . import tracing
|
|
||||||
from . import utils
|
|
|
@ -7,7 +7,7 @@
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from mlir._mlir_libs import get_include_dirs, get_lib_dirs
|
from ._mlir_libs import get_include_dirs, get_lib_dirs
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_include_dirs",
|
"get_include_dirs",
|
||||||
|
|
|
@ -18,7 +18,7 @@ def get_refjit():
|
||||||
global _refjit
|
global _refjit
|
||||||
if _refjit is not None:
|
if _refjit is not None:
|
||||||
return _refjit
|
return _refjit
|
||||||
from .... import _cext
|
from ...._mlir_libs import _npcomp as _cext
|
||||||
try:
|
try:
|
||||||
imported_refjit = _cext.backend.refjit
|
imported_refjit = _cext.backend.refjit
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
@ -8,15 +8,15 @@ from typing import Callable, Iterator, Sequence, Tuple
|
||||||
import functools
|
import functools
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
|
|
||||||
from npcomp import _cext
|
|
||||||
from npcomp.dialects import numpy as numpy_ops
|
|
||||||
|
|
||||||
from ....utils import logging
|
|
||||||
from ...interfaces import *
|
from ...interfaces import *
|
||||||
from ...partial_eval_base import *
|
from ...partial_eval_base import *
|
||||||
|
|
||||||
|
from ....utils import logging
|
||||||
|
|
||||||
|
from ..... import ir as _ir
|
||||||
|
from ....._mlir_libs import _npcomp as _cext
|
||||||
|
from .....dialects import numpy as numpy_ops
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_ufuncs_from_module",
|
"get_ufuncs_from_module",
|
||||||
"bind_ufuncs",
|
"bind_ufuncs",
|
||||||
|
|
|
@ -6,14 +6,13 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from mlir import ir as _ir
|
from ...interfaces import *
|
||||||
from mlir.dialects import std as std_ops
|
|
||||||
|
|
||||||
from npcomp import _cext
|
|
||||||
from npcomp.dialects import numpy as numpy_ops
|
|
||||||
|
|
||||||
from ....utils import logging
|
from ....utils import logging
|
||||||
from ...interfaces import *
|
|
||||||
|
from ..... import ir as _ir
|
||||||
|
from .....dialects import std as std_ops, numpy as numpy_ops
|
||||||
|
from ....._mlir_libs import _npcomp as _cext
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CreateNumpyValueCoder",
|
"CreateNumpyValueCoder",
|
||||||
|
|
|
@ -10,15 +10,17 @@ import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .importer import *
|
from .importer import *
|
||||||
from .interfaces import *
|
from .interfaces import *
|
||||||
from .name_resolver_base import *
|
from .name_resolver_base import *
|
||||||
from .value_coder_base import *
|
from .value_coder_base import *
|
||||||
from .target import *
|
from .target import *
|
||||||
|
|
||||||
from ..utils.mlir_utils import *
|
from ..utils.mlir_utils import *
|
||||||
|
|
||||||
|
from ... import ir as _ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ImportFrontend",
|
"ImportFrontend",
|
||||||
]
|
]
|
||||||
|
|
|
@ -8,14 +8,12 @@ import ast
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
from mlir.dialects import std as std_ops
|
|
||||||
|
|
||||||
from npcomp import _cext
|
|
||||||
from npcomp.dialects import basicpy as basicpy_ops
|
|
||||||
|
|
||||||
from ..utils import logging
|
|
||||||
from .interfaces import *
|
from .interfaces import *
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
|
from ... import ir as _ir
|
||||||
|
from ...dialects import std as std_ops, basicpy as basicpy_ops
|
||||||
|
from ..._mlir_libs import _npcomp as _cext
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FunctionContext",
|
"FunctionContext",
|
||||||
|
|
|
@ -8,10 +8,9 @@ from enum import Enum
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional, Sequence, Tuple, Union
|
from typing import List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
|
|
||||||
from .target import *
|
from .target import *
|
||||||
from ..utils.mlir_utils import *
|
from ..utils.mlir_utils import *
|
||||||
|
from ... import ir as _ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Configuration",
|
"Configuration",
|
||||||
|
|
|
@ -5,9 +5,10 @@
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
from .interfaces import *
|
from .interfaces import *
|
||||||
|
|
||||||
|
from ... import ir as _ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConstModuleNameResolver",
|
"ConstModuleNameResolver",
|
||||||
"LocalNameResolver",
|
"LocalNameResolver",
|
||||||
|
|
|
@ -3,10 +3,11 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
from typing import *
|
from typing import *
|
||||||
from mlir import ir as _ir
|
|
||||||
|
|
||||||
from ..utils.mlir_utils import *
|
from ..utils.mlir_utils import *
|
||||||
|
|
||||||
|
from ... import ir as _ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GenericTarget32",
|
"GenericTarget32",
|
||||||
"GenericTarget64",
|
"GenericTarget64",
|
||||||
|
|
|
@ -6,13 +6,13 @@
|
||||||
import ast
|
import ast
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from ..utils import logging
|
|
||||||
from .frontend import *
|
from .frontend import *
|
||||||
from .interfaces import *
|
from .interfaces import *
|
||||||
from .partial_eval_base import *
|
from .partial_eval_base import *
|
||||||
from .target import *
|
from .target import *
|
||||||
from .value_coder_base import *
|
from .value_coder_base import *
|
||||||
from .extensions import numpy as npc
|
from .extensions import numpy as npc
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
|
|
||||||
def create_import_dump_decorator(*,
|
def create_import_dump_decorator(*,
|
||||||
|
|
|
@ -5,12 +5,11 @@
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
from mlir.dialects import std as std_ops
|
|
||||||
from npcomp.dialects import basicpy as basicpy_ops
|
|
||||||
|
|
||||||
from .interfaces import *
|
from .interfaces import *
|
||||||
|
|
||||||
|
from ... import ir as _ir
|
||||||
|
from ...dialects import std as std_ops, basicpy as basicpy_ops
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BuiltinsValueCoder",
|
"BuiltinsValueCoder",
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,7 +7,7 @@ from typing import TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mlir.ir import Module
|
from npcomp.ir import Module
|
||||||
|
|
||||||
# A type shared between the result of `NpcompBackend.compile` and the input
|
# A type shared between the result of `NpcompBackend.compile` and the input
|
||||||
# to `NpcompBackend.load`. Each backend will likely have a different definition
|
# to `NpcompBackend.load`. Each backend will likely have a different definition
|
||||||
|
|
|
@ -6,8 +6,8 @@ import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mlir.ir import *
|
from npcomp.ir import *
|
||||||
from mlir.passmanager import *
|
from npcomp.passmanager import *
|
||||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||||
from npcomp.compiler.utils import logging
|
from npcomp.compiler.utils import logging
|
||||||
from .abc import NpcompBackend
|
from .abc import NpcompBackend
|
||||||
|
|
|
@ -5,9 +5,9 @@
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from mlir import ir as _ir
|
from ... import ir as _ir
|
||||||
from mlir.dialects import builtin as builtin_ops
|
from ...dialects import builtin as builtin_ops
|
||||||
from npcomp import _cext
|
from ..._mlir_libs import _npcomp as _cext
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ImportContext",
|
"ImportContext",
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
# Generated tablegen dialects expect to be able to find some symbols from
|
|
||||||
# the mlir.dialects package.
|
|
||||||
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context
|
|
|
@ -1,3 +0,0 @@
|
||||||
# Module level symbols.
|
|
||||||
from .context import *
|
|
||||||
from .mlir_trace import *
|
|
|
@ -1,189 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class TracingError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TraceContext:
|
|
||||||
"""Context for intercepting array traces.
|
|
||||||
|
|
||||||
Context manager:
|
|
||||||
----------------
|
|
||||||
Instances act as context managers, the inner-most of which can be
|
|
||||||
queried with current() or optional_current().
|
|
||||||
|
|
||||||
>>> with TraceContext(desc=1) as tc:
|
|
||||||
... print(tc)
|
|
||||||
... print(TraceContext.current())
|
|
||||||
<TraceContext 1>
|
|
||||||
<TraceContext 1>
|
|
||||||
>>> print(TraceContext.optional_current())
|
|
||||||
None
|
|
||||||
>>> TraceContext.current()
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
RuntimeError: No active TraceContext
|
|
||||||
|
|
||||||
Unique ids:
|
|
||||||
-----------
|
|
||||||
Many things in tracing require a context-local id.
|
|
||||||
|
|
||||||
>>> with TraceContext() as tc:
|
|
||||||
... print(tc.get_next_id())
|
|
||||||
... print(tc.get_next_id())
|
|
||||||
1
|
|
||||||
2
|
|
||||||
|
|
||||||
"""
|
|
||||||
_local = threading.local()
|
|
||||||
__slots__ = [
|
|
||||||
"_desc",
|
|
||||||
"_next_id",
|
|
||||||
"active",
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, desc=None):
|
|
||||||
_check_numpy_version()
|
|
||||||
self._desc = desc
|
|
||||||
self._next_id = 1
|
|
||||||
self.active = False
|
|
||||||
|
|
||||||
def _handle_array_getitem(self, array, key):
|
|
||||||
"""Handles a call to __getitem__ on a traced array."""
|
|
||||||
raise NotImplementedError("Array slicing not implemented")
|
|
||||||
|
|
||||||
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
|
||||||
"""Handles a ufunc invocation involving at least one TracedArray."""
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def _handle_array_func(self, func, types, inputs, kwargs):
|
|
||||||
"""Handles an __array_func__ hook involving at least on TracedArray."""
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def get_next_id(self):
|
|
||||||
"""Gets the next unique id for the context."""
|
|
||||||
rv = self._next_id
|
|
||||||
self._next_id += 1
|
|
||||||
return rv
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_context_stack(cls):
|
|
||||||
try:
|
|
||||||
return cls._local.s
|
|
||||||
except AttributeError:
|
|
||||||
cls._local.s = []
|
|
||||||
return cls._local.s
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def optional_current(cls) -> Optional["TraceContext"]:
|
|
||||||
s = cls._get_context_stack()
|
|
||||||
if s:
|
|
||||||
return s[-1]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def current(cls) -> "TraceContext":
|
|
||||||
c = cls.optional_current()
|
|
||||||
if c is None:
|
|
||||||
raise RuntimeError("No active TraceContext")
|
|
||||||
return c
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
s = self._get_context_stack()
|
|
||||||
if s:
|
|
||||||
s[-1].active = False
|
|
||||||
s.append(self)
|
|
||||||
self.active = True
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
s = self._get_context_stack()
|
|
||||||
s.pop()
|
|
||||||
self.active = False
|
|
||||||
if s:
|
|
||||||
s[-1].active = True
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<TraceContext %r>" % self._desc
|
|
||||||
|
|
||||||
|
|
||||||
def _assert_active(tc: TraceContext):
|
|
||||||
assert tc.active, (
|
|
||||||
"Attempt to trace an action on an inactive trace context: %r" % tc)
|
|
||||||
|
|
||||||
|
|
||||||
class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
|
||||||
"""An array that traces its operations.
|
|
||||||
|
|
||||||
Unique ids:
|
|
||||||
-----------
|
|
||||||
>>> tc = TraceContext()
|
|
||||||
>>> TracedArray(tc=tc)
|
|
||||||
<TracedArray 1>
|
|
||||||
>>> TracedArray(tc=tc)
|
|
||||||
<TracedArray 2>
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, tc: Optional[TraceContext] = None):
|
|
||||||
self._tc = tc if tc is not None else TraceContext.current()
|
|
||||||
self._uid = self._tc.get_next_id()
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return id(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def uid(self):
|
|
||||||
return self._uid
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<TracedArray %d>" % self._uid
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
tc = self._tc
|
|
||||||
_assert_active(tc)
|
|
||||||
return tc._handle_array_getitem(self, key)
|
|
||||||
|
|
||||||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
|
||||||
tc = self._tc
|
|
||||||
_assert_active(tc)
|
|
||||||
return tc._handle_ufunc(ufunc, method, inputs, kwargs)
|
|
||||||
|
|
||||||
def __array_function__(self, func, types, args, kwargs):
|
|
||||||
tc = self._tc
|
|
||||||
_assert_active(tc)
|
|
||||||
return tc._handle_array_func(func, types, args, kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def T(self):
|
|
||||||
"""Shortcut for transpose."""
|
|
||||||
tc = self._tc
|
|
||||||
_assert_active(tc)
|
|
||||||
return tc._handle_array_func(np.transpose, [TracedArray], [self], {})
|
|
||||||
|
|
||||||
|
|
||||||
def _check_numpy_version():
|
|
||||||
version = np.lib.NumpyVersion(np.__version__)
|
|
||||||
if version < "1.16.0":
|
|
||||||
raise RuntimeError("Numpy version >= 1.16 is required")
|
|
||||||
if version > "1.17.0":
|
|
||||||
return
|
|
||||||
if os.environ.get("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION") != "1":
|
|
||||||
raise RuntimeError("For numpy 1.16, the environment variable "
|
|
||||||
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
|
@ -1,296 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
|
|
||||||
from npcomp.dialects import numpy as numpy_ops
|
|
||||||
|
|
||||||
|
|
||||||
class Protocol(Enum):
|
|
||||||
UFUNC = 1
|
|
||||||
ARRAY_FUNC = 2
|
|
||||||
|
|
||||||
|
|
||||||
class TraceValueType(Enum):
|
|
||||||
NDARRAY = 1
|
|
||||||
|
|
||||||
|
|
||||||
class TraceValue(namedtuple("TraceValue", ["value", "type"])):
|
|
||||||
__slots__ = ()
|
|
||||||
"""A Python value and the trace type that it should correspond to."""
|
|
||||||
|
|
||||||
|
|
||||||
TraceValue.__new__.__defaults__ = (TraceValueType.NDARRAY,)
|
|
||||||
|
|
||||||
|
|
||||||
class TraceInvocation(
|
|
||||||
namedtuple("TraceInvocation", ["inputs", "kwargs", "protocol", "method"])):
|
|
||||||
"""An invocation of a single functions.
|
|
||||||
|
|
||||||
This abstracts over both ufuncs and array_funcs, differentiating by the
|
|
||||||
protocol and method.
|
|
||||||
"""
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
|
|
||||||
TraceInvocation.__new__.__defaults__ = (Protocol.ARRAY_FUNC, "__call__")
|
|
||||||
|
|
||||||
|
|
||||||
class EmissionRequest(
|
|
||||||
namedtuple("EmissionRequest", ["input_ssa_values", "ic", "extra"])):
|
|
||||||
"""Represents the result of processing inputs from an invocation.
|
|
||||||
|
|
||||||
The `input_ssa_values` are mlir.ir.Value instances corresponding to
|
|
||||||
input_trace_values in TraceValueMap.
|
|
||||||
|
|
||||||
The `extra` value is only relevant to the producer and can be used as a
|
|
||||||
blackbox mechanism to transfer un-tracked state from an invocation to
|
|
||||||
emission.
|
|
||||||
|
|
||||||
The `dialect_helper` fields correspond to mlir.ir.DialectHelper.
|
|
||||||
"""
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
|
|
||||||
EmissionRequest.__new__.__defaults__ = (None,)
|
|
||||||
|
|
||||||
|
|
||||||
class TraceValueMap(
|
|
||||||
namedtuple("TraceValueMap",
|
|
||||||
["input_trace_values", "result_trace_value_types", "extra"])):
|
|
||||||
"""The result of mapping an invocation to corresponding op structure.
|
|
||||||
|
|
||||||
This type associates:
|
|
||||||
- Python (object, TraceValueType) representing invocation inputs that
|
|
||||||
correspond to SSA values in the IR.
|
|
||||||
- TraceValueTypes that are the expected logical result types from the
|
|
||||||
invocation.
|
|
||||||
- 'extra' object that is passed to followon Emitter methods.
|
|
||||||
"""
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
|
|
||||||
TraceValueMap.__new__.__defaults__ = (None)
|
|
||||||
|
|
||||||
|
|
||||||
class FuncEmitter:
|
|
||||||
"""An emitter for an op-like function invocation."""
|
|
||||||
|
|
||||||
def map_invocation(self, trace_invocation: TraceInvocation) -> TraceValueMap:
|
|
||||||
"""Maps from an invocation to EmissionRequest.
|
|
||||||
|
|
||||||
This hook is also responsible for validating the invocation and should
|
|
||||||
raise appropriate user-visible exceptions (i.e. when invoked with incorrect
|
|
||||||
arguments).
|
|
||||||
|
|
||||||
This hook is used to prepare for emission in a define-by-run scenario.
|
|
||||||
Static emission from an AST needs to be prepared via another mechanism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trace_invocation: An Invocation instance to map.
|
|
||||||
Returns:
|
|
||||||
A TraceValueMap describing the structure of the invocation as mapped
|
|
||||||
to/from IR.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def map_results(self, py_results, extra):
|
|
||||||
"""Maps a list of python results to actual function return values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
py_results: List of python results corresponding to the emitted op
|
|
||||||
results.
|
|
||||||
extra: The extra object returned by map_invocation.
|
|
||||||
Returns:
|
|
||||||
Actual function result. Typically this requires special handling to
|
|
||||||
unpack the result of functions that return 1 item.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def emit(self, request: EmissionRequest):
|
|
||||||
"""Emits IR using the provided ops and types factories.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emission_inputs: An EmissionRequest produced by tracing each TraceValue
|
|
||||||
from a previous call to map_invocation and the corresponding extra
|
|
||||||
value.
|
|
||||||
Returns:
|
|
||||||
An iterable of mlir.ir.Value instances representing the outputs of the
|
|
||||||
operation. The `builder` on `ops` must be positioned to consume these
|
|
||||||
values.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class GenericCallUfuncEmitter(FuncEmitter):
|
|
||||||
"""A FuncEmitter for generic ufuncs requiring no special behavior.
|
|
||||||
|
|
||||||
Representation:
|
|
||||||
>>> emitter = GenericCallUfuncEmitter("numpy.add")
|
|
||||||
>>> emitter
|
|
||||||
<ufunc emitter 'numpy.add'>
|
|
||||||
>>> inv = TraceInvocation([1, 2], {}, protocol=Protocol.UFUNC)
|
|
||||||
>>> inputs = emitter.map_invocation(inv)
|
|
||||||
>>> inputs
|
|
||||||
TraceValueMap(input_trace_values=[TraceValue(value=1, type=<TraceValueType.NDARRAY: 1>), TraceValue(value=2, type=<TraceValueType.NDARRAY: 1>)], result_trace_value_types=[<TraceValueType.NDARRAY: 1>], extra=None)
|
|
||||||
|
|
||||||
Error on unsupported kwargs:
|
|
||||||
>>> inv = TraceInvocation([1, 2], {"foobar": 1}, protocol=Protocol.UFUNC)
|
|
||||||
>>> emitter.map_invocation(inv)
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
ValueError: Unexpected keyword args for ufunc numpy.add: foobar
|
|
||||||
|
|
||||||
"""
|
|
||||||
__slots__ = ("_ufunc_name")
|
|
||||||
|
|
||||||
def __init__(self, ufunc_name: str):
|
|
||||||
self._ufunc_name = ufunc_name
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<ufunc emitter '%s'>" % self._ufunc_name
|
|
||||||
|
|
||||||
def map_invocation(self,
|
|
||||||
trace_invocation: TraceInvocation) -> EmissionRequest:
|
|
||||||
assert trace_invocation.protocol == Protocol.UFUNC
|
|
||||||
assert trace_invocation.method == "__call__"
|
|
||||||
if trace_invocation.kwargs:
|
|
||||||
raise ValueError(
|
|
||||||
"Unexpected keyword args for ufunc %s: %s" %
|
|
||||||
(self._ufunc_name, ", ".join(trace_invocation.kwargs.keys())))
|
|
||||||
# Without above special cases, any positional args map to emission
|
|
||||||
# inputs.
|
|
||||||
return TraceValueMap([
|
|
||||||
TraceValue(i, TraceValueType.NDARRAY) for i in trace_invocation.inputs
|
|
||||||
], [TraceValueType.NDARRAY],
|
|
||||||
extra=None)
|
|
||||||
|
|
||||||
def map_results(self, py_results, extra):
|
|
||||||
# Ufuncs always return one result, so just unpack it.
|
|
||||||
return py_results[0]
|
|
||||||
|
|
||||||
def emit(self, request: EmissionRequest):
|
|
||||||
ic = request.ic
|
|
||||||
name_attr = _ir.StringAttr.get(self._ufunc_name)
|
|
||||||
result_type = ic.unknown_tensor_type
|
|
||||||
call_op = numpy_ops.BuiltinUfuncCallOp(result_type,
|
|
||||||
qualified_name=name_attr,
|
|
||||||
inputs=request.input_ssa_values,
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip)
|
|
||||||
return call_op.results
|
|
||||||
|
|
||||||
|
|
||||||
class GenericArrayFuncEmitter(FuncEmitter):
|
|
||||||
"""Emitter for array funcs that don't do anything 'special'."""
|
|
||||||
__slots__ = ("_op_name", "_nresults")
|
|
||||||
|
|
||||||
def __init__(self, op_name: str, nresults: int = 1):
|
|
||||||
self._op_name = op_name
|
|
||||||
self._nresults = nresults
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<array_func emitter '%s'>" % self._op_name
|
|
||||||
|
|
||||||
def map_invocation(self,
|
|
||||||
trace_invocation: TraceInvocation) -> EmissionRequest:
|
|
||||||
assert trace_invocation.protocol == Protocol.ARRAY_FUNC
|
|
||||||
if trace_invocation.method != "__call__":
|
|
||||||
raise NotImplementedError("Only __call__ is supported for %s (got '%s')" %
|
|
||||||
(
|
|
||||||
self._op_name,
|
|
||||||
trace_invocation.method,
|
|
||||||
))
|
|
||||||
if trace_invocation.kwargs:
|
|
||||||
raise ValueError(
|
|
||||||
"Unexpected keyword args for %s: %s" %
|
|
||||||
(self._op_name, ", ".join(trace_invocation.kwargs.keys())))
|
|
||||||
# Without above special cases, any positional args map to emission
|
|
||||||
# inputs.
|
|
||||||
return TraceValueMap([
|
|
||||||
TraceValue(i, TraceValueType.NDARRAY) for i in trace_invocation.inputs
|
|
||||||
], [TraceValueType.NDARRAY] * self._nresults,
|
|
||||||
extra=None)
|
|
||||||
|
|
||||||
def map_results(self, py_results, extra):
|
|
||||||
if self._nresults == 1:
|
|
||||||
return py_results[0]
|
|
||||||
else:
|
|
||||||
return tuple(py_results)
|
|
||||||
|
|
||||||
def emit(self, request: EmissionRequest):
|
|
||||||
ic = request.ic
|
|
||||||
op_result_types = [ic.unknown_tensor_type] * self._nresults
|
|
||||||
op = _ir.Operation.create(self._op_name,
|
|
||||||
results=op_result_types,
|
|
||||||
operands=request.input_ssa_values,
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip)
|
|
||||||
return op.results
|
|
||||||
|
|
||||||
|
|
||||||
class EmitterRegistry:
|
|
||||||
"""Registry of known Emitter instances mapped to source function.
|
|
||||||
|
|
||||||
>>> r = EmitterRegistry.create_default()
|
|
||||||
>>> r.lookup_ufunc(np.add, "__call__")
|
|
||||||
<ufunc emitter 'numpy.add'>
|
|
||||||
>>> r.lookup_array_func(np.dot)
|
|
||||||
<array_func emitter 'numpy.dot'>
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._ufunc_map = {} # Dictionary of (f, method) -> Emitter
|
|
||||||
self._arrayfunc_map = {} # Dictionary of f -> Emitter
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_default(cls):
|
|
||||||
registry = cls()
|
|
||||||
registry.register_defaults()
|
|
||||||
return registry
|
|
||||||
|
|
||||||
def register_ufunc(self, ufunc, method, emitter):
|
|
||||||
# Last registration wins.
|
|
||||||
self._ufunc_map[(ufunc, method)] = emitter
|
|
||||||
|
|
||||||
def register_array_func(self, f, emitter):
|
|
||||||
# Last registration wins.
|
|
||||||
self._arrayfunc_map[f] = emitter
|
|
||||||
|
|
||||||
def lookup_ufunc(self, ufunc, method):
|
|
||||||
return self._ufunc_map.get((ufunc, method))
|
|
||||||
|
|
||||||
def lookup_array_func(self, f):
|
|
||||||
return self._arrayfunc_map.get(f)
|
|
||||||
|
|
||||||
def register_defaults(self):
|
|
||||||
# Find all ufuncs in the numpy module and register by name.
|
|
||||||
for member in sorted(dir(np)):
|
|
||||||
ufunc = getattr(np, member)
|
|
||||||
if isinstance(ufunc, np.ufunc):
|
|
||||||
self.register_ufunc(ufunc, "__call__",
|
|
||||||
GenericCallUfuncEmitter("numpy." + member))
|
|
||||||
# Register generic 1-result array funcs.
|
|
||||||
GENERIC_FUNCS = (
|
|
||||||
(np.inner, "numpy.inner"),
|
|
||||||
(np.outer, "numpy.outer"),
|
|
||||||
(np.dot, "numpy.dot"),
|
|
||||||
(np.vdot, "numpy.vdot"),
|
|
||||||
(np.linalg.det, "numpy.linalg.det"),
|
|
||||||
# TODO: This needs a custom implementation to differentiate when
|
|
||||||
# axes is specified (this version will fail).
|
|
||||||
(np.transpose, "numpy.transpose"),
|
|
||||||
)
|
|
||||||
for f, op_name in GENERIC_FUNCS:
|
|
||||||
self.register_array_func(f, GenericArrayFuncEmitter(op_name))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
|
@ -1,296 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Iterable, Optional
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mlir import ir as _ir
|
|
||||||
from mlir.dialects import std as std_ops
|
|
||||||
|
|
||||||
from npcomp import _cext
|
|
||||||
from npcomp.dialects import basicpy as basicpy_ops
|
|
||||||
from npcomp.dialects import numpy as numpy_ops
|
|
||||||
|
|
||||||
from ..exporter import *
|
|
||||||
from ..types import *
|
|
||||||
from ..compiler.utils.mlir_utils import *
|
|
||||||
|
|
||||||
from .context import *
|
|
||||||
from .emitters import *
|
|
||||||
|
|
||||||
|
|
||||||
class ModuleBuilder:
|
|
||||||
"""Builds an MLIR module by tracing functions."""
|
|
||||||
|
|
||||||
__slots__ = [
|
|
||||||
"emitters",
|
|
||||||
"ic",
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
mlir_context: Optional[_ir.Context] = None,
|
|
||||||
emitter_registry=None):
|
|
||||||
ic = self.ic = ImportContext(mlir_context)
|
|
||||||
ic.module = _ir.Module.create(loc=ic.loc)
|
|
||||||
self.emitters = (emitter_registry
|
|
||||||
if emitter_registry else EmitterRegistry.create_default())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def module(self):
|
|
||||||
return self.ic.module
|
|
||||||
|
|
||||||
def trace(self, *export_py_funcs: ExportPyFunction):
|
|
||||||
"""Traces exported py functions."""
|
|
||||||
for export_py_func in export_py_funcs:
|
|
||||||
assert isinstance(export_py_func, ExportPyFunction), (
|
|
||||||
"Expected an exported python function (from the Exporter class)")
|
|
||||||
tracer = FunctionTracer(self, export_py_func)
|
|
||||||
with tracer:
|
|
||||||
tracer.trace()
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionTracer(TraceContext):
|
|
||||||
"""A trace of a single function."""
|
|
||||||
__slots__ = [
|
|
||||||
"module_builder",
|
|
||||||
"epf",
|
|
||||||
"_args_array_params",
|
|
||||||
"_f",
|
|
||||||
"_f_types",
|
|
||||||
"_ic",
|
|
||||||
"_python_args",
|
|
||||||
"_result_array_params",
|
|
||||||
"_traced_arrays",
|
|
||||||
"_external_arrays",
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
|
||||||
super().__init__(desc="[trace of %s]" % epf.__name__)
|
|
||||||
self.module_builder = module_builder
|
|
||||||
self.epf = epf
|
|
||||||
self._traced_arrays = {} # Mapping of TracedArray to current consumer value
|
|
||||||
self._external_arrays = {} # Mapping of id to (ndarray, ir.Value)
|
|
||||||
self._validate()
|
|
||||||
|
|
||||||
# Alias some parent members for convenience.
|
|
||||||
self._ic = module_builder.ic
|
|
||||||
with self._ic.context:
|
|
||||||
# Extract ArrayParams for all args and results.
|
|
||||||
self._args_array_params = [
|
|
||||||
ArrayParams.from_constraints(arg.constraints)
|
|
||||||
for arg in self.epf.sig.args
|
|
||||||
]
|
|
||||||
self._python_args = [None] * len(self._args_array_params)
|
|
||||||
self._result_array_params = ArrayParams.from_constraints(
|
|
||||||
self.epf.sig.result.constraints)
|
|
||||||
|
|
||||||
# Create the MLIR function.
|
|
||||||
self._f, self._f_types = self._create_mlir_function()
|
|
||||||
self._create_trace_roots()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def entry_block(self) -> _ir.Block:
|
|
||||||
return self._f.regions[0].blocks[0]
|
|
||||||
|
|
||||||
def trace(self):
|
|
||||||
# Invoke the python function with placeholders.
|
|
||||||
# TODO: More sophisticated signature merging
|
|
||||||
# TODO: Multiple results
|
|
||||||
# TODO: Error reporting
|
|
||||||
ic = self._ic
|
|
||||||
ic.insert_end_of_block(self.entry_block)
|
|
||||||
with ic.context:
|
|
||||||
py_results = (self.epf.pyfunc(*self._python_args),)
|
|
||||||
if len(py_results) != len(self._f_types):
|
|
||||||
raise TracingError("Traced function returned != %d results: %r" % (
|
|
||||||
len(self._f_types),
|
|
||||||
py_results,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Narrow all results to the declared return types.
|
|
||||||
return_operands = []
|
|
||||||
for py_result, mlir_result_type in zip(py_results, self._f_types):
|
|
||||||
mlir_result = self.get_traced_array_value(py_result)
|
|
||||||
# narrow to declared result type.
|
|
||||||
return_operands.extend(
|
|
||||||
numpy_ops.NarrowOp(mlir_result_type,
|
|
||||||
mlir_result,
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip).results)
|
|
||||||
std_ops.ReturnOp(return_operands, loc=ic.loc, ip=ic.ip)
|
|
||||||
ic.pop_ip()
|
|
||||||
|
|
||||||
def set_traced_array(self, traced_array, value):
|
|
||||||
"""Sets the current SSA value for a traced_array."""
|
|
||||||
assert isinstance(traced_array, TracedArray)
|
|
||||||
self._traced_arrays[traced_array] = value
|
|
||||||
|
|
||||||
def get_traced_array_value(self, traced_array):
|
|
||||||
if not isinstance(traced_array, TracedArray):
|
|
||||||
# Generic import of external value. For now, we just treat these as
|
|
||||||
# local consts.
|
|
||||||
return self._get_external_array_value(traced_array)
|
|
||||||
|
|
||||||
traced_value = self._traced_arrays.get(traced_array)
|
|
||||||
if traced_value is None:
|
|
||||||
raise TracingError("Unregistered traced array: %r", (traced_array,))
|
|
||||||
return traced_value
|
|
||||||
|
|
||||||
def _get_external_array_value(self, external_array):
|
|
||||||
ic = self._ic
|
|
||||||
if not isinstance(external_array, np.ndarray):
|
|
||||||
raise TracingError("Expected ndarray but got: %r" % (external_array,))
|
|
||||||
found_it = self._external_arrays.get(id(external_array))
|
|
||||||
if found_it:
|
|
||||||
return found_it[1]
|
|
||||||
# Import it.
|
|
||||||
dense_attr = _ir.DenseElementsAttr.get(external_array, context=ic.context)
|
|
||||||
const_value = std_ops.ConstantOp(dense_attr.type,
|
|
||||||
dense_attr,
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
self._external_arrays[id(external_array)] = (external_array, const_value)
|
|
||||||
return const_value
|
|
||||||
|
|
||||||
def _validate(self):
|
|
||||||
if not all(
|
|
||||||
arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
|
|
||||||
raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,))
|
|
||||||
if not self.epf.sig.result.type_class == TypeClass.NdArray:
|
|
||||||
raise NotImplementedError("Non NdArray result: %r" %
|
|
||||||
(self.epf.sig.result,))
|
|
||||||
|
|
||||||
def _create_mlir_function(self):
|
|
||||||
ic = self._ic
|
|
||||||
epf = self.epf
|
|
||||||
f_args = [
|
|
||||||
_ir.Type.parse(ap.mlir_tensor_type_asm)
|
|
||||||
for ap in self._args_array_params
|
|
||||||
]
|
|
||||||
f_types = [_ir.Type.parse(self._result_array_params.mlir_tensor_type_asm)]
|
|
||||||
ic.insert_end_of_block(ic.module.body)
|
|
||||||
f_type = _ir.FunctionType.get(f_args, f_types)
|
|
||||||
f, _ = ic.FuncOp(epf.__name__, f_type, create_entry_block=True)
|
|
||||||
return f, f_types
|
|
||||||
|
|
||||||
def _create_trace_roots(self):
|
|
||||||
entry_block = self.entry_block
|
|
||||||
for index, ap in enumerate(self._args_array_params):
|
|
||||||
if ap is not None:
|
|
||||||
ta = TracedArray(self)
|
|
||||||
self.set_traced_array(ta, entry_block.arguments[index])
|
|
||||||
self._python_args[index] = ta
|
|
||||||
|
|
||||||
def _resolve_input_ssa_values(self, trace_values: Iterable[TraceValue]):
|
|
||||||
"""Resolves input python values to SSA values."""
|
|
||||||
ssa_values = []
|
|
||||||
for tv in trace_values:
|
|
||||||
assert tv.type == TraceValueType.NDARRAY, (
|
|
||||||
"Unsupported TraceValueType: %r" % tv.type)
|
|
||||||
ssa_value = self.get_traced_array_value(tv.value)
|
|
||||||
ssa_values.append(ssa_value)
|
|
||||||
return ssa_values
|
|
||||||
|
|
||||||
def _resolve_result_py_values(self,
|
|
||||||
trace_value_types: Iterable[TraceValueType],
|
|
||||||
ssa_values):
|
|
||||||
"""Resolves result SSA values to runtime python values."""
|
|
||||||
assert len(trace_value_types) == len(ssa_values), (
|
|
||||||
"Mismatched emitter declared result types and results")
|
|
||||||
py_values = []
|
|
||||||
for trace_value_type, ssa_value in zip(trace_value_types, ssa_values):
|
|
||||||
assert trace_value_type == TraceValueType.NDARRAY, (
|
|
||||||
"Unsupported TraceValueType: %r" % trace_value_type)
|
|
||||||
py_value = TracedArray(self)
|
|
||||||
self.set_traced_array(py_value, ssa_value)
|
|
||||||
py_values.append(py_value)
|
|
||||||
return py_values
|
|
||||||
|
|
||||||
def _emit_invocation(self, emitter: FuncEmitter, invocation: TraceInvocation):
|
|
||||||
tv_map = emitter.map_invocation(invocation)
|
|
||||||
input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
|
|
||||||
request = EmissionRequest(input_ssa_values, ic=self._ic, extra=tv_map.extra)
|
|
||||||
result_ssa_values = emitter.emit(request)
|
|
||||||
py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,
|
|
||||||
result_ssa_values)
|
|
||||||
return emitter.map_results(py_values, tv_map.extra)
|
|
||||||
|
|
||||||
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
|
||||||
emitter = self.module_builder.emitters.lookup_ufunc(ufunc, method)
|
|
||||||
if not emitter:
|
|
||||||
return NotImplemented
|
|
||||||
invocation = TraceInvocation(inputs, kwargs, Protocol.UFUNC, method)
|
|
||||||
return self._emit_invocation(emitter, invocation)
|
|
||||||
|
|
||||||
def _handle_array_func(self, func, types, inputs, kwargs):
|
|
||||||
emitter = self.module_builder.emitters.lookup_array_func(func)
|
|
||||||
if not emitter:
|
|
||||||
return NotImplemented
|
|
||||||
invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
|
|
||||||
return self._emit_invocation(emitter, invocation)
|
|
||||||
|
|
||||||
def _emit_slice_value(self, slice_element):
|
|
||||||
ic = self._ic
|
|
||||||
if slice_element == None:
|
|
||||||
return basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc, ip=ic.ip).result
|
|
||||||
elif slice_element == Ellipsis:
|
|
||||||
return basicpy_ops.SingletonOp(ic.ellipsis_type, loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
elif isinstance(slice_element, int):
|
|
||||||
return std_ops.ConstantOp(ic.index_type,
|
|
||||||
_ir.IntegerAttr.get(ic.index_type,
|
|
||||||
slice_element),
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
elif isinstance(slice_element, slice):
|
|
||||||
return self._emit_slice_object(slice_element)
|
|
||||||
else:
|
|
||||||
# Assume array convertible.
|
|
||||||
raise NotImplementedError(
|
|
||||||
"TODO: Slicing with generic arrays not yet implemented")
|
|
||||||
|
|
||||||
def _emit_slice_object(self, slice_object: slice):
|
|
||||||
ic = self._ic
|
|
||||||
|
|
||||||
def emit_index(index):
|
|
||||||
if index is None:
|
|
||||||
return basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
else:
|
|
||||||
return std_ops.ConstantOp(ic.index_type,
|
|
||||||
_ir.IntegerAttr.get(ic.index_type,
|
|
||||||
int(index)),
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
|
|
||||||
start = emit_index(slice_object.start)
|
|
||||||
stop = emit_index(slice_object.stop)
|
|
||||||
step = emit_index(slice_object.step)
|
|
||||||
result_type = _cext.slot_object_type(ic.context, "slice",
|
|
||||||
[start.type, stop.type, step.type])
|
|
||||||
return basicpy_ops.SlotObjectMakeOp(result_type, [start, stop, step],
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
|
|
||||||
def _handle_array_getitem(self, array, key):
|
|
||||||
ic = self._ic
|
|
||||||
array_value = self.get_traced_array_value(array)
|
|
||||||
# Array slicing is always based on a tuple.
|
|
||||||
slice_tuple = key if isinstance(key, tuple) else (key,)
|
|
||||||
# Resolve and emit each slice element.
|
|
||||||
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
|
||||||
result_value = numpy_ops.GetSliceOp(ic.unknown_array_type,
|
|
||||||
array_value,
|
|
||||||
slice_values,
|
|
||||||
loc=ic.loc,
|
|
||||||
ip=ic.ip).result
|
|
||||||
result_array = TracedArray(self)
|
|
||||||
self.set_traced_array(result_array, result_value)
|
|
||||||
return result_array
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
36
setup.py
36
setup.py
|
@ -17,11 +17,21 @@ import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from setuptools import find_packages, setup, Extension
|
from distutils.command.build import build as _build
|
||||||
|
from setuptools import find_namespace_packages, setup, Extension
|
||||||
from setuptools.command.build_ext import build_ext
|
from setuptools.command.build_ext import build_ext
|
||||||
from setuptools.command.build_py import build_py
|
from setuptools.command.build_py import build_py
|
||||||
|
|
||||||
|
|
||||||
|
# Build phase discovery is unreliable. Just tell it what phases to run.
|
||||||
|
class CustomBuild(_build):
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.run_command("build_py")
|
||||||
|
self.run_command("build_ext")
|
||||||
|
self.run_command("build_scripts")
|
||||||
|
|
||||||
|
|
||||||
class CMakeExtension(Extension):
|
class CMakeExtension(Extension):
|
||||||
|
|
||||||
def __init__(self, name, sourcedir=""):
|
def __init__(self, name, sourcedir=""):
|
||||||
|
@ -79,21 +89,25 @@ setup(
|
||||||
long_description="",
|
long_description="",
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CMakeExtension("mlir._mlir_libs._mlir"),
|
CMakeExtension("npcomp._mlir_libs._mlir"),
|
||||||
CMakeExtension("mlir._mlir_libs._npcomp"),
|
CMakeExtension("npcomp._mlir_libs._npcomp"),
|
||||||
# TODO: We don't really want these but they are along for the ride.
|
# TODO: We don't really want these but they are along for the ride.
|
||||||
CMakeExtension("mlir._mlir_libs._mlirAsyncPasses"),
|
CMakeExtension("npcomp._mlir_libs._mlirAsyncPasses"),
|
||||||
CMakeExtension("mlir._mlir_libs._mlirConversions"),
|
CMakeExtension("npcomp._mlir_libs._mlirConversions"),
|
||||||
CMakeExtension("mlir._mlir_libs._mlirTransforms"),
|
CMakeExtension("npcomp._mlir_libs._mlirTransforms"),
|
||||||
CMakeExtension("mlir._mlir_libs._mlirSparseTensorPasses"),
|
CMakeExtension("npcomp._mlir_libs._mlirSparseTensorPasses"),
|
||||||
CMakeExtension("mlir._mlir_libs._mlirAllPassesRegisration"),
|
CMakeExtension("npcomp._mlir_libs._mlirAllPassesRegisration"),
|
||||||
CMakeExtension("mlir._mlir_libs._mlirLinalgPasses"),
|
CMakeExtension("npcomp._mlir_libs._mlirLinalgPasses"),
|
||||||
CMakeExtension("mlir._mlir_libs._mlirGPUPasses"),
|
CMakeExtension("npcomp._mlir_libs._mlirGPUPasses"),
|
||||||
],
|
],
|
||||||
cmdclass={
|
cmdclass={
|
||||||
|
"build": CustomBuild,
|
||||||
"built_ext": NoopBuildExtension,
|
"built_ext": NoopBuildExtension,
|
||||||
"build_py": CMakeBuild,
|
"build_py": CMakeBuild,
|
||||||
},
|
},
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
packages=find_packages(),
|
packages=find_namespace_packages(include=[
|
||||||
|
"npcomp",
|
||||||
|
"npcomp.*",
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,7 +17,6 @@ set(NPCOMP_TEST_DEPENDS
|
||||||
npcomp-opt
|
npcomp-opt
|
||||||
refback-run
|
refback-run
|
||||||
NPCOMPPythonModules
|
NPCOMPPythonModules
|
||||||
NPCOMPMLIRPythonModules
|
|
||||||
)
|
)
|
||||||
|
|
||||||
add_lit_testsuite(check-npcomp-lit "Running the npcomp regression tests"
|
add_lit_testsuite(check-npcomp-lit "Running the npcomp regression tests"
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import npcomp as npc
|
|
||||||
from npcomp.types import *
|
|
||||||
|
|
||||||
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
|
|
||||||
bias = np.random.uniform(size=(4,)).astype(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def constants(a: np.ndarray) -> np.ndarray:
|
|
||||||
return np.dot(a, weights) + bias
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement subclassing and deriving constraints by run
|
|
||||||
exp = npc.Exporter()
|
|
||||||
exp.constants = constants
|
|
||||||
|
|
||||||
mb = npc.tracing.ModuleBuilder()
|
|
||||||
mb.trace(exp.constants)
|
|
||||||
# CHECK-LABEL: func @constants(
|
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
|
||||||
# CHECK: %[[VAL_1:.*]] = constant dense<{{.*}}> : tensor<16x4xf32>
|
|
||||||
# CHECK: %[[VAL_2:.*]] = numpy.dot %[[VAL_0]], %[[VAL_1]] : (tensor<*x!numpy.any_dtype>, tensor<16x4xf32>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %[[VAL_3:.*]] = constant dense<{{.*}}> : tensor<4xf32>
|
|
||||||
# CHECK: %[[VAL_4:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[VAL_2]], %[[VAL_3]]) : (tensor<*x!basicpy.UnknownType>, tensor<4xf32>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %[[VAL_5:.*]] = numpy.narrow %[[VAL_4]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: return %[[VAL_5]] : tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: }
|
|
||||||
print(mb.module)
|
|
|
@ -1,38 +0,0 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import npcomp as npc
|
|
||||||
from npcomp.types import *
|
|
||||||
|
|
||||||
|
|
||||||
def dot2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
|
||||||
return np.dot(a, b)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement subclassing and deriving constraints by run
|
|
||||||
exp = npc.Exporter()
|
|
||||||
exp.dot2d = dot2d
|
|
||||||
exp.dot2d.sig.args["a"] += Shape(4, 16)
|
|
||||||
exp.dot2d.sig.args["a"] += DynamicDim(0)
|
|
||||||
exp.dot2d.sig.args["a"] += DType(np.float32)
|
|
||||||
exp.dot2d.sig.args["b"] += Shape(16, 32)
|
|
||||||
exp.dot2d.sig.args["b"] += DType(np.float32)
|
|
||||||
exp.dot2d.sig.result += Shape(4, 32)
|
|
||||||
exp.dot2d.sig.result += DynamicDim(0)
|
|
||||||
exp.dot2d.sig.result += DType(np.float32)
|
|
||||||
|
|
||||||
mb = npc.tracing.ModuleBuilder()
|
|
||||||
mb.trace(exp.dot2d)
|
|
||||||
|
|
||||||
# CHECK-LABEL: func @dot2d(
|
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<?x16xf32>,
|
|
||||||
# CHECK-SAME: %[[VAL_1:.*]]: tensor<16x32xf32>) -> tensor<?x32xf32> {
|
|
||||||
# CHECK: %[[VAL_2:.*]] = numpy.dot %[[VAL_0]], %[[VAL_1]] : (tensor<?x16xf32>, tensor<16x32xf32>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %[[VAL_3:.*]] = numpy.narrow %[[VAL_2]] : (tensor<*x!basicpy.UnknownType>) -> tensor<?x32xf32>
|
|
||||||
# CHECK: return %[[VAL_3]] : tensor<?x32xf32>
|
|
||||||
# CHECK: }
|
|
||||||
print(mb.module)
|
|
|
@ -1,41 +0,0 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"] = "1"
|
|
||||||
|
|
||||||
from npcomp.types import *
|
|
||||||
from npcomp.exporter import *
|
|
||||||
from npcomp.tracing.mlir_trace import *
|
|
||||||
|
|
||||||
|
|
||||||
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
|
||||||
return a * b + a + b
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement subclassing and deriving constraints by run
|
|
||||||
exp = Exporter()
|
|
||||||
exp.simple_mul = simple_mul
|
|
||||||
exp.simple_mul.sig.args["a"] += Shape(1, 4)
|
|
||||||
exp.simple_mul.sig.args["a"] += DynamicDim(0)
|
|
||||||
exp.simple_mul.sig.args["a"] += DType(np.float32)
|
|
||||||
exp.simple_mul.sig.args["b"] += Shape(1)
|
|
||||||
exp.simple_mul.sig.args["b"] += DType(np.float32)
|
|
||||||
exp.simple_mul.sig.result += Shape(1, 4)
|
|
||||||
exp.simple_mul.sig.result += DynamicDim(0)
|
|
||||||
exp.simple_mul.sig.result += DType(np.float32)
|
|
||||||
|
|
||||||
mb = ModuleBuilder()
|
|
||||||
mb.trace(exp.simple_mul)
|
|
||||||
# This test exercises the full tracing path and incidentally checks the ops.
|
|
||||||
# CHECK: func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
|
|
||||||
# CHECK: %0 = numpy.builtin_ufunc_call<"numpy.multiply"> (%arg0, %arg1) : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %1 = numpy.builtin_ufunc_call<"numpy.add"> (%0, %arg0) : (tensor<*x!basicpy.UnknownType>, tensor<?x4xf32>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %2 = numpy.builtin_ufunc_call<"numpy.add"> (%1, %arg1) : (tensor<*x!basicpy.UnknownType>, tensor<1xf32>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %3 = numpy.narrow %2 : (tensor<*x!basicpy.UnknownType>) -> tensor<?x4xf32>
|
|
||||||
# CHECK: return %3 : tensor<?x4xf32>
|
|
||||||
# CHECK: }
|
|
||||||
print(str(mb.module))
|
|
|
@ -1,48 +0,0 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import npcomp as npc
|
|
||||||
from npcomp.types import *
|
|
||||||
|
|
||||||
|
|
||||||
def slice_array1(a: np.ndarray) -> np.ndarray:
|
|
||||||
return a[1, 2:10:2, 3:4, ..., :, 0]
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement subclassing and deriving constraints by run
|
|
||||||
exp = npc.Exporter()
|
|
||||||
exp.slice_array1 = slice_array1
|
|
||||||
|
|
||||||
mb = npc.tracing.ModuleBuilder()
|
|
||||||
mb.trace(exp.slice_array1)
|
|
||||||
|
|
||||||
# TODO: The numpy.get_slice op emission should be analyzed: it probably
|
|
||||||
# needs to both accept and produce either arrays or tensors and the following
|
|
||||||
# narrow should do likewise.
|
|
||||||
# CHECK-LABEL: func @slice_array1(
|
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
|
||||||
# CHECK: %[[VAL_1:.*]] = constant 1 : index
|
|
||||||
# CHECK: %[[VAL_2:.*]] = constant 2 : index
|
|
||||||
# CHECK: %[[VAL_3:.*]] = constant 10 : index
|
|
||||||
# CHECK: %[[VAL_4:.*]] = constant 2 : index
|
|
||||||
# CHECK: %[[VAL_5:.*]] = basicpy.slot_object_make(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) -> !basicpy.SlotObject<slice, index, index, index>
|
|
||||||
# CHECK: %[[VAL_6:.*]] = constant 3 : index
|
|
||||||
# CHECK: %[[VAL_7:.*]] = constant 4 : index
|
|
||||||
# CHECK: %[[VAL_8:.*]] = basicpy.singleton : !basicpy.NoneType
|
|
||||||
# CHECK: %[[VAL_9:.*]] = basicpy.slot_object_make(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) -> !basicpy.SlotObject<slice, index, index, !basicpy.NoneType>
|
|
||||||
# CHECK: %[[VAL_10:.*]] = basicpy.singleton : !basicpy.EllipsisType
|
|
||||||
# CHECK: %[[VAL_11:.*]] = basicpy.singleton : !basicpy.NoneType
|
|
||||||
# CHECK: %[[VAL_12:.*]] = basicpy.singleton : !basicpy.NoneType
|
|
||||||
# CHECK: %[[VAL_13:.*]] = basicpy.singleton : !basicpy.NoneType
|
|
||||||
# CHECK: %[[VAL_14:.*]] = basicpy.slot_object_make(%[[VAL_11]], %[[VAL_12]], %[[VAL_13]]) -> !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>
|
|
||||||
# CHECK: %[[VAL_15:.*]] = constant 0 : index
|
|
||||||
# CHECK: %[[VAL_16:.*]] = numpy.get_slice %[[VAL_0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_9]], %[[VAL_10]], %[[VAL_14]], %[[VAL_15]] : (tensor<*x!numpy.any_dtype>, index, !basicpy.SlotObject<slice, index, index, index>, !basicpy.SlotObject<slice, index, index, !basicpy.NoneType>, !basicpy.EllipsisType, !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>, index) -> !numpy.ndarray<*:?>
|
|
||||||
# CHECK: %[[VAL_17:.*]] = numpy.narrow %[[VAL_16]] : (!numpy.ndarray<*:?>) -> tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: return %[[VAL_17]] : tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: }
|
|
||||||
|
|
||||||
print(mb.module)
|
|
|
@ -1,42 +0,0 @@
|
||||||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import npcomp as npc
|
|
||||||
from npcomp.types import *
|
|
||||||
|
|
||||||
|
|
||||||
def transpose_attribute(a: np.ndarray) -> np.ndarray:
|
|
||||||
return a.T
|
|
||||||
|
|
||||||
|
|
||||||
def transpose(a: np.ndarray) -> np.ndarray:
|
|
||||||
return np.transpose(a)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement subclassing and deriving constraints by run
|
|
||||||
exp = npc.Exporter()
|
|
||||||
exp.transpose_attribute = transpose_attribute
|
|
||||||
exp.transpose = transpose
|
|
||||||
|
|
||||||
mb = npc.tracing.ModuleBuilder()
|
|
||||||
mb.trace(exp.transpose_attribute, exp.transpose)
|
|
||||||
|
|
||||||
# TODO: Consolidate any_dtype -> UnknownType.
|
|
||||||
# CHECK-LABEL: func @transpose_attribute(
|
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
|
||||||
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: }
|
|
||||||
|
|
||||||
# CHECK-LABEL: func @transpose(
|
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
|
||||||
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
|
|
||||||
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
|
|
||||||
# CHECK: }
|
|
||||||
print(mb.module)
|
|
|
@ -24,9 +24,6 @@ def run_doctest(mod):
|
||||||
|
|
||||||
TEST_MODULES = (
|
TEST_MODULES = (
|
||||||
"npcomp.compiler.numpy.py_value_utils",
|
"npcomp.compiler.numpy.py_value_utils",
|
||||||
"npcomp.tracing.context",
|
|
||||||
"npcomp.tracing.emitters",
|
|
||||||
"npcomp.tracing.mlir_trace",
|
|
||||||
"npcomp.types",
|
"npcomp.types",
|
||||||
"npcomp.exporter",
|
"npcomp.exporter",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
# Some checks that we can import the various extensions and libraries and
|
|
||||||
# not have symbol collisions or other goings on.
|
|
||||||
# RUN: %PYTHON %s
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
print(f"PYTHONPATH={sys.path}")
|
|
||||||
|
|
||||||
import mlir.ir
|
|
||||||
import npcomp
|
|
||||||
|
|
||||||
print("Extensions all loaded")
|
|
Loading…
Reference in New Issue