mirror of https://github.com/llvm/torch-mlir
NFC: Delete npcomp python API and switch to upstream.
* Most updates are mechanical except: * python/npcomp/__init__.py and python/NpcompModule.cpp: New init/registration bits to replace some automatic things being done in the old bindings. Also an annoying linkage hack that I'll need to triage next. * NpcompModule.cpp: New python helpers for custom types and other hard to reach items (for the new bindings). * PybindUtils.h: Extended type casting so that the local extension can directly exchange Mlir* C types. * python/npcomp/dialects/*: Build support and ODS bindings for local dialects. * mlir_utils.py: Defines an ImportContext to replace the old/bad "Helper" class that tracked locations, and insertion points. This has a number of methods on it that would be good candidates to think about better ways to do them upstream. * Also hoisted a few stand-alone samples to dedicated unit tests as they covered important things. * More cleanup can be done, but keeping this patch as mechanical as possible to stay in NFC land (this is big enough).pull/144/head
parent
951d7ff42c
commit
3f706473fd
|
@ -21,6 +21,9 @@ extern "C" {
|
|||
*/
|
||||
void npcompRegisterAllDialects(MlirContext context);
|
||||
|
||||
/** Registers all NPComp passes for symbolic access with the global registry. */
|
||||
void npcompRegisterAllPasses();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -75,6 +75,9 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
|||
/// Helper that gets an equivalent NdArrayType from a ShapedType.
|
||||
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
|
||||
|
||||
/// Helper that converts an NdArrayType to a TensorType.
|
||||
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* None type. */
|
||||
/*============================================================================*/
|
||||
|
@ -85,6 +88,14 @@ int npcompTypeIsANone(MlirType t);
|
|||
/** Gets the type of the singleton 'None'. */
|
||||
MlirType npcompNoneTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* SlotObject type. */
|
||||
/*============================================================================*/
|
||||
|
||||
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
|
||||
intptr_t slotTypeCount,
|
||||
const MlirType *slotTypes);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Tuple type. */
|
||||
/*============================================================================*/
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- BasicPyOps.td - Basic Python ops --------------------*- tablegen -*-===//
|
||||
//===- BasicpyOps.td - Basic Python ops --------------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -9,7 +9,7 @@
|
|||
#ifndef NPCOMP_DIALECT_BASICPY_IR_BASICPY_OPS
|
||||
#define NPCOMP_DIALECT_BASICPY_IR_BASICPY_OPS
|
||||
|
||||
include "BasicpyDialect.td"
|
||||
include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
@ -510,7 +510,6 @@ def Basicpy_SlotObjectMakeOp : Basicpy_Op<"slot_object_make", [
|
|||
!basicpy.SlotObject<slice, !basicpy.NoneType>
|
||||
}];
|
||||
let arguments = (ins
|
||||
StrAttr:$className,
|
||||
// TODO: Tighter constraints on allowable types.
|
||||
Variadic<AnyType>:$slots
|
||||
);
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#ifndef NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
|
||||
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
|
||||
|
||||
include "NumpyDialect.td"
|
||||
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
|
||||
include "npcomp/Typing/Analysis/CPA/Interfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
@ -167,7 +167,7 @@ def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
|
|||
// See: https://docs.scipy.org/doc/numpy/user/basics.indexing.html
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def Numpy_GetSlice : Numpy_Op<"get_slice", []> {
|
||||
def Numpy_GetSliceOp : Numpy_Op<"get_slice", []> {
|
||||
let summary = "Gets a slice of an array";
|
||||
let description = [{
|
||||
This op encapsulates all forms of indexing into an array by taking a
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
//===- MlirInit.h - MLIR config and init ----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_MLIRINIT_H
|
||||
#define NPCOMP_PYTHON_MLIRINIT_H
|
||||
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
// Note that the CPP module is compiled without RTTI or exceptions, unlike
|
||||
// the rest of the pybind code. Therefore, we also stash some trampolines
|
||||
// here for parts of the code that are not RTTI-compatible.
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class MLIRContext;
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
// One time initialization.
|
||||
bool npcompMlirInitialize();
|
||||
|
||||
// Loads globally registered dialects into the MLIRContext.
|
||||
// This is temporary until there is an upstream story for handling dialect
|
||||
// registration in python-based systems.
|
||||
void loadGlobalDialectsIntoContext(MLIRContext *context);
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_MLIRINIT_H
|
|
@ -1,202 +0,0 @@
|
|||
//===- MlirIr.h - MLIR IR Bindings ----------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_MLIR_IR_H
|
||||
#define NPCOMP_PYTHON_MLIR_IR_H
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
struct PyContext;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename ListTy, typename ItemWrapperTy> class PyIpListWrapper {
|
||||
public:
|
||||
using ThisTy = PyIpListWrapper<ListTy, ItemWrapperTy>;
|
||||
static void bind(py::module m, const char *className);
|
||||
PyIpListWrapper(ListTy &list) : list(list) {}
|
||||
|
||||
private:
|
||||
ListTy &list;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Wrapper types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Wrapper around an Operation*.
|
||||
struct PyBaseOperation {
|
||||
virtual ~PyBaseOperation();
|
||||
static void bind(py::module m);
|
||||
virtual Operation *getOperation() = 0;
|
||||
};
|
||||
|
||||
/// Wrapper around Module, capturing a PyContext reference.
|
||||
struct PyModuleOp : PyBaseOperation {
|
||||
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
|
||||
: context(context), moduleOp(moduleOp) {
|
||||
assert(moduleOp);
|
||||
}
|
||||
~PyModuleOp();
|
||||
static void bind(py::module m);
|
||||
Operation *getOperation() override;
|
||||
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
||||
int64_t largeElementLimit);
|
||||
|
||||
std::shared_ptr<PyContext> context;
|
||||
ModuleOp moduleOp;
|
||||
};
|
||||
|
||||
/// Wrapper around an Operation*.
|
||||
struct PyOperationRef : PyBaseOperation {
|
||||
PyOperationRef(Operation *operation) : operation(operation) {
|
||||
assert(operation);
|
||||
}
|
||||
PyOperationRef(Operation &operation) : operation(&operation) {}
|
||||
~PyOperationRef();
|
||||
static void bind(py::module m);
|
||||
Operation *getOperation() override;
|
||||
|
||||
Operation *operation;
|
||||
};
|
||||
|
||||
/// Wrapper around SymbolTable.
|
||||
struct PySymbolTable {
|
||||
PySymbolTable(SymbolTable &symbolTable) : symbolTable(symbolTable) {}
|
||||
static void bind(py::module m);
|
||||
SymbolTable &symbolTable;
|
||||
};
|
||||
|
||||
/// Wrapper around Value.
|
||||
struct PyValue {
|
||||
PyValue(Value value) : value(value) { assert(value); }
|
||||
static void bind(py::module m);
|
||||
operator Value() { return value; }
|
||||
Value value;
|
||||
};
|
||||
|
||||
/// Wrapper around Identifier.
|
||||
struct PyIdentifier {
|
||||
PyIdentifier(Identifier identifier) : identifier(identifier) {}
|
||||
static void bind(py::module m);
|
||||
Identifier identifier;
|
||||
};
|
||||
|
||||
/// Wrapper around Attribute.
|
||||
struct PyAttribute {
|
||||
PyAttribute(Attribute attr) : attr(attr) { assert(attr); }
|
||||
static void bind(py::module m);
|
||||
Attribute attr;
|
||||
};
|
||||
|
||||
/// Wrapper around MLIRContext.
|
||||
struct PyContext : std::enable_shared_from_this<PyContext> {
|
||||
PyContext();
|
||||
static void bind(py::module m);
|
||||
PyModuleOp parseAsm(const std::string &asm_text);
|
||||
MLIRContext context;
|
||||
};
|
||||
|
||||
/// Wrapper around a Block&.
|
||||
struct PyBlockRef {
|
||||
PyBlockRef(Block &block) : block(block) {}
|
||||
static void bind(py::module m);
|
||||
Block █
|
||||
};
|
||||
|
||||
/// Wrapper around a Region&.
|
||||
struct PyRegionRef {
|
||||
PyRegionRef(Region ®ion) : region(region) {}
|
||||
static void bind(py::module m);
|
||||
Region ®ion;
|
||||
};
|
||||
|
||||
struct PyType {
|
||||
PyType() = default;
|
||||
PyType(Type type) : type(type) {}
|
||||
static void bind(py::module m);
|
||||
operator Type() { return type; }
|
||||
Type type;
|
||||
};
|
||||
|
||||
/// Wrapper around an OpBuilder reference.
|
||||
/// This class is inherently dangerous because it does not track ownership
|
||||
/// of IR objects that it may be operating on and incorrect usage can cause
|
||||
/// memory access errors, just as it can in C++. It is intended for use by
|
||||
/// higher level constructs that are specifically coded to satisfy object
|
||||
/// lifetime needs.
|
||||
class PyBaseOpBuilder {
|
||||
public:
|
||||
virtual ~PyBaseOpBuilder();
|
||||
static void bind(py::module m);
|
||||
virtual OpBuilder &getBuilder(bool requirePosition = false) = 0;
|
||||
MLIRContext *getContext() { return getBuilder(false).getContext(); }
|
||||
|
||||
// For convenience, we track the current location at the builder level
|
||||
// to avoid lots of parameter passing.
|
||||
void setCurrentLoc(Location loc) { currentLoc = loc; }
|
||||
Location getCurrentLoc() {
|
||||
if (currentLoc) {
|
||||
return Location(currentLoc);
|
||||
} else {
|
||||
return UnknownLoc::get(getBuilder(false).getContext());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
LocationAttr currentLoc;
|
||||
};
|
||||
|
||||
/// Wrapper around an instance of an OpBuilder.
|
||||
class PyOpBuilder : public PyBaseOpBuilder {
|
||||
public:
|
||||
PyOpBuilder(PyContext &context) : builder(&context.context) {}
|
||||
~PyOpBuilder() override;
|
||||
static void bind(py::module m);
|
||||
OpBuilder &getBuilder(bool requirePosition = false) override;
|
||||
|
||||
private:
|
||||
OpBuilder builder;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Custom types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Helper for creating (possibly dialect specific) IR objects. This class
|
||||
/// is intended to be subclassed on the Python side (possibly with multiple
|
||||
/// inheritance) to provide Python level APIs for custom dialects. The base
|
||||
/// class contains helpers for std types and ops.
|
||||
class PyDialectHelper {
|
||||
public:
|
||||
PyDialectHelper(PyContext &context, PyOpBuilder &builder)
|
||||
: context(context), pyOpBuilder(builder) {}
|
||||
static void bind(py::module m);
|
||||
MLIRContext *getContext() { return pyOpBuilder.getContext(); }
|
||||
|
||||
protected:
|
||||
PyContext &context;
|
||||
PyOpBuilder &pyOpBuilder;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_MLIR_IR_H
|
|
@ -1,34 +0,0 @@
|
|||
//===- MlirPass.h - MLIR Pass Bindings ------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_MLIR_PASS_H
|
||||
#define NPCOMP_PYTHON_MLIR_PASS_H
|
||||
|
||||
#include "MlirIr.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
struct PyPassManager {
|
||||
PyPassManager(std::shared_ptr<PyContext> context, bool verifyModules)
|
||||
: passManager(&context->context, OpPassManager::Nesting::Implicit),
|
||||
context(std::move(context)) {
|
||||
passManager.enableVerifier(verifyModules);
|
||||
}
|
||||
static void bind(py::module m);
|
||||
PassManager passManager;
|
||||
|
||||
private:
|
||||
std::shared_ptr<PyContext> context;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_MLIR_PASS_H
|
|
@ -1,28 +0,0 @@
|
|||
//===- NpcompModule.h - Module registrations ------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_NPCOMP_MODULE_H
|
||||
#define NPCOMP_PYTHON_NPCOMP_MODULE_H
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
void defineMlirIrModule(py::module m);
|
||||
void defineMlirPassModule(py::module m);
|
||||
void defineMlirCoreDialects(py::module m);
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
void defineNpcompDialect(py::module m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_NPCOMP_MODULE_H
|
|
@ -28,11 +28,51 @@ namespace detail {
|
|||
template <typename T>
|
||||
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
|
||||
|
||||
/// Helper to convert a presumed MLIR API object to a capsule, accepting either
|
||||
/// an explicit Capsule (which can happen when two C APIs are communicating
|
||||
/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
|
||||
/// attribute (through which supported MLIR Python API objects export their
|
||||
/// contained API pointer as a capsule). This is intended to be used from
|
||||
/// type casters, which are invoked with a raw handle (unowned). The returned
|
||||
/// object's lifetime may not extend beyond the apiObject handle without
|
||||
/// explicitly having its refcount increased (i.e. on return).
|
||||
static py::object mlirApiObjectToCapsule(py::handle apiObject) {
|
||||
if (PyCapsule_CheckExact(apiObject.ptr()))
|
||||
return py::reinterpret_borrow<py::object>(apiObject);
|
||||
return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
}
|
||||
|
||||
// Note: Currently all of the following support cast from py::object to the
|
||||
// Mlir* C-API type, but only a few light-weight, context-bound ones
|
||||
// implicitly cast the other way because the use case has not yet emerged and
|
||||
// ownership is unclear.
|
||||
|
||||
/// Casts object -> MlirAttribute.
|
||||
template <> struct type_caster<MlirAttribute> {
|
||||
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToAttribute(capsule.ptr());
|
||||
if (mlirAttributeIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static handle cast(MlirAttribute v, return_value_policy, handle) {
|
||||
auto capsule =
|
||||
py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
|
||||
return py::module::import("mlir.ir")
|
||||
.attr("Attribute")
|
||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
||||
.release();
|
||||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirContext.
|
||||
template <> struct type_caster<MlirContext> {
|
||||
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToContext(capsule.ptr());
|
||||
if (mlirContextIsNull(value)) {
|
||||
return false;
|
||||
|
@ -41,11 +81,32 @@ template <> struct type_caster<MlirContext> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirLocation.
|
||||
template <> struct type_caster<MlirLocation> {
|
||||
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToLocation(capsule.ptr());
|
||||
if (mlirLocationIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static handle cast(MlirLocation v, return_value_policy, handle) {
|
||||
auto capsule =
|
||||
py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
|
||||
return py::module::import("mlir.ir")
|
||||
.attr("Location")
|
||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
||||
.release();
|
||||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirModule.
|
||||
template <> struct type_caster<MlirModule> {
|
||||
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToModule(capsule.ptr());
|
||||
if (mlirModuleIsNull(value)) {
|
||||
return false;
|
||||
|
@ -54,11 +115,24 @@ template <> struct type_caster<MlirModule> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirOperation.
|
||||
template <> struct type_caster<MlirOperation> {
|
||||
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToOperation(capsule.ptr());
|
||||
if (mlirOperationIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirPassManager.
|
||||
template <> struct type_caster<MlirPassManager> {
|
||||
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToPassManager(capsule.ptr());
|
||||
if (mlirPassManagerIsNull(value)) {
|
||||
return false;
|
||||
|
@ -67,6 +141,27 @@ template <> struct type_caster<MlirPassManager> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirType.
|
||||
template <> struct type_caster<MlirType> {
|
||||
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = mlirApiObjectToCapsule(src);
|
||||
value = mlirPythonCapsuleToType(capsule.ptr());
|
||||
if (mlirTypeIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
static handle cast(MlirType t, return_value_policy, handle) {
|
||||
auto capsule =
|
||||
py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
|
||||
return py::module::import("mlir.ir")
|
||||
.attr("Type")
|
||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
||||
.release();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ target_link_libraries(NPCOMPBackendRefJITPythonModule
|
|||
MLIRExecutionEngine
|
||||
MLIRTargetLLVMIR
|
||||
|
||||
NPCOMPPythonCommon
|
||||
NPCOMPRefBackendJITHelpers
|
||||
)
|
||||
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Pass.h"
|
||||
#include "npcomp/Python/MlirIr.h"
|
||||
#include "npcomp/Python/MlirPass.h"
|
||||
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
||||
|
||||
using llvm::SmallVector;
|
||||
|
@ -21,8 +19,6 @@ using llvm::StringRef;
|
|||
using llvm::Twine;
|
||||
|
||||
// Make namespaces consistent.
|
||||
using mlir::PyModuleOp;
|
||||
using mlir::PyPassManager;
|
||||
using refback::JITModule;
|
||||
using refbackrt::Ref;
|
||||
using refbackrt::Tensor;
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
#include "npcomp-c/Registration.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "npcomp/InitAll.h"
|
||||
|
||||
void npcompRegisterAllDialects(MlirContext context) {
|
||||
|
@ -16,3 +18,11 @@ void npcompRegisterAllDialects(MlirContext context) {
|
|||
// TODO: Don't eagerly load once D88162 is in and clients can do this.
|
||||
unwrap(context)->getDialectRegistry().loadAll(unwrap(context));
|
||||
}
|
||||
|
||||
void npcompRegisterAllPasses() {
|
||||
::mlir::NPCOMP::registerAllPasses();
|
||||
|
||||
// Upstream passes we depend on.
|
||||
::mlir::registerCanonicalizerPass();
|
||||
::mlir::registerSCFToStandardPass();
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "npcomp-c/Types.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
|
@ -85,6 +86,10 @@ MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
|
|||
unwrap(shapedType).cast<ShapedType>()));
|
||||
}
|
||||
|
||||
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType) {
|
||||
return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* None type. */
|
||||
/*============================================================================*/
|
||||
|
@ -97,6 +102,23 @@ MlirType npcompNoneTypeGet(MlirContext context) {
|
|||
return wrap(Basicpy::NoneType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* SlotObject type. */
|
||||
/*============================================================================*/
|
||||
|
||||
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
|
||||
intptr_t slotTypeCount,
|
||||
const MlirType *slotTypes) {
|
||||
MLIRContext *cppContext = unwrap(context);
|
||||
auto classNameAttr = StringAttr::get(unwrap(className), cppContext);
|
||||
SmallVector<Type> slotTypesCpp;
|
||||
slotTypesCpp.resize(slotTypeCount);
|
||||
for (intptr_t i = 0; i < slotTypeCount; ++i) {
|
||||
slotTypesCpp[i] = unwrap(slotTypes[i]);
|
||||
}
|
||||
return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Tuple type. */
|
||||
/*============================================================================*/
|
||||
|
|
|
@ -2,7 +2,6 @@ add_subdirectory(CAPI)
|
|||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(RefBackend)
|
||||
add_subdirectory(Python)
|
||||
add_subdirectory(Typing)
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
|
|
|
@ -1,54 +0,0 @@
|
|||
################################################################################
|
||||
# NPCOMPPythonCommon
|
||||
################################################################################
|
||||
|
||||
include(AddLLVM)
|
||||
include(NpcompPython)
|
||||
|
||||
# TODO: This should not be wired in at such a low/unconditional level.
|
||||
# It is done here to be kept with the other LLVM initialization until a better
|
||||
# place can be found for it.
|
||||
# set(ExtraInit_LIBADD)
|
||||
# if(NPCOMP_ENABLE_REFJIT)
|
||||
# llvm_map_components_to_libnames(refjit_llvm_libs
|
||||
# nativecodegen
|
||||
# )
|
||||
# message(STATUS "Including LLVM libs for RefJit: ${refjit_llvm_libs}")
|
||||
# list(APPEND ExtraInit_LIBADD
|
||||
# ${refjit_llvm_libs})
|
||||
# endif()
|
||||
|
||||
# include_directories(
|
||||
# ${PYTHON_INCLUDE_DIRS}
|
||||
# )
|
||||
|
||||
set(PYBIND_SOURCES
|
||||
MlirInit.cpp
|
||||
MlirIr.cpp
|
||||
MlirPass.cpp
|
||||
NpcompDialect.cpp
|
||||
CoreDialects.cpp
|
||||
)
|
||||
|
||||
add_library(NPCOMPPythonCommon
|
||||
${PYBIND_SOURCES}
|
||||
)
|
||||
|
||||
target_link_libraries(NPCOMPPythonCommon
|
||||
pybind11::module
|
||||
NPCOMPInitAll
|
||||
NPCOMPCAPI
|
||||
|
||||
# Core dialects
|
||||
MLIRSCF
|
||||
|
||||
# Upstream depends
|
||||
MLIRDialect
|
||||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
)
|
||||
|
||||
npcomp_python_target_compile_options(NPCOMPPythonCommon)
|
|
@ -1,64 +0,0 @@
|
|||
//===- NpcompDialect.cpp - Custom dialect classes -------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Python/MlirIr.h"
|
||||
#include "npcomp/Python/NpcompModule.h"
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
class ScfDialectHelper : public PyDialectHelper {
|
||||
public:
|
||||
using PyDialectHelper::PyDialectHelper;
|
||||
|
||||
static void bind(py::module m) {
|
||||
py::class_<ScfDialectHelper, PyDialectHelper>(m, "ScfDialectHelper")
|
||||
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
|
||||
py::keep_alive<1, 3>())
|
||||
.def("scf_yield_op",
|
||||
[](ScfDialectHelper &self,
|
||||
std::vector<PyValue> pyYields) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Value, 4> yields(pyYields.begin(),
|
||||
pyYields.end());
|
||||
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
|
||||
return op.getOperation();
|
||||
})
|
||||
.def(
|
||||
"scf_if_op",
|
||||
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
|
||||
PyValue cond, bool withElseRegion) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
|
||||
withElseRegion);
|
||||
if (withElseRegion) {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint(),
|
||||
op.getElseBodyBuilder().saveInsertionPoint());
|
||||
} else {
|
||||
return py::make_tuple(
|
||||
PyOperationRef(op),
|
||||
op.getThenBodyBuilder().saveInsertionPoint());
|
||||
}
|
||||
},
|
||||
py::arg("result_types"), py::arg("cond"),
|
||||
py::arg("with_else_region") = false);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::defineMlirCoreDialects(py::module m) { ScfDialectHelper::bind(m); }
|
|
@ -1,72 +0,0 @@
|
|||
//===- MlirInit.cpp -------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Python/MlirInit.h"
|
||||
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "npcomp-c/InitLLVM.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/InitAll.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
|
||||
bool mlir::npcomp::python::npcompMlirInitialize() {
|
||||
// Enable LLVM's signal handler to get nice stack traces.
|
||||
llvm::sys::SetOneShotPipeSignalFunction(
|
||||
llvm::sys::DefaultOneShotPipeSignalHandler);
|
||||
llvm::sys::PrintStackTraceOnErrorSignal("npcomp");
|
||||
|
||||
// Register any pass manager command line options.
|
||||
mlir::registerPassManagerCLOptions();
|
||||
mlir::registerMLIRContextCLOptions();
|
||||
|
||||
std::string program_name = "npcomp";
|
||||
std::vector<const char *> default_options = {program_name.c_str(), nullptr};
|
||||
llvm::cl::ParseCommandLineOptions(1, default_options.data());
|
||||
|
||||
// Pass registration.
|
||||
::mlir::registerAllPasses();
|
||||
::mlir::NPCOMP::registerAllPasses();
|
||||
|
||||
// Initialize code generation.
|
||||
npcompInitializeLLVMCodegen();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void mlir::npcomp::python::loadGlobalDialectsIntoContext(MLIRContext *context) {
|
||||
static mlir::DialectRegistry registry = ([]() {
|
||||
mlir::DialectRegistry registry;
|
||||
::mlir::registerAllDialects(registry);
|
||||
::mlir::NPCOMP::registerAllDialects(registry);
|
||||
return registry;
|
||||
})();
|
||||
registry.loadAll(context);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm,
|
||||
raw_ostream &errorStream = llvm::errs()) {
|
||||
return ::mlir::parsePassPipeline(pipeline, pm, errorStream);
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
|
@ -1,987 +0,0 @@
|
|||
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Python/MlirIr.h"
|
||||
#include "npcomp/Python/MlirInit.h"
|
||||
#include "npcomp/Python/NpcompModule.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Forward declarations
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct PyContext;
|
||||
|
||||
/// Parses an MLIR module from a string.
|
||||
/// For maximum efficiency, the `contents` should be zero terminated.
|
||||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||
MLIRContext *context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Direct type bindings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void bindInsertPoint(py::module m) {
|
||||
py::class_<OpBuilder::InsertPoint>(m, "InsertPoint");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Internal only template definitions
|
||||
// Since it is only legal to use explicit instantiations of templates in
|
||||
// mlir_ir.h, implementations are kept in this module to keep things scoped
|
||||
// well for the compiler.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename ListTy, typename ItemWrapperTy>
|
||||
void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
|
||||
const char *className) {
|
||||
struct PyItemIterator : public llvm::iterator_adaptor_base<
|
||||
PyItemIterator, typename ListTy::iterator,
|
||||
typename std::iterator_traits<
|
||||
typename ListTy::iterator>::iterator_category,
|
||||
typename ListTy::value_type> {
|
||||
PyItemIterator() = default;
|
||||
PyItemIterator(typename ListTy::iterator &&other)
|
||||
: PyItemIterator::iterator_adaptor_base(std::move(other)) {}
|
||||
ItemWrapperTy operator*() const { return ItemWrapperTy(*this->I); }
|
||||
};
|
||||
|
||||
py::class_<ThisTy>(m, className)
|
||||
.def_property_readonly(
|
||||
"front",
|
||||
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
|
||||
.def("__len__", [](ThisTy &self) { return self.list.size(); })
|
||||
.def(
|
||||
"__iter__",
|
||||
[](ThisTy &self) {
|
||||
PyItemIterator begin(self.list.begin());
|
||||
PyItemIterator end(self.list.end());
|
||||
return py::make_iterator(begin, end);
|
||||
},
|
||||
py::keep_alive<0, 1>());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Explicit template instantiations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template class PyIpListWrapper<Region::BlockListType, PyBlockRef>;
|
||||
using PyBlockList = PyIpListWrapper<Region::BlockListType, PyBlockRef>;
|
||||
|
||||
template class PyIpListWrapper<Block::OpListType, PyOperationRef>;
|
||||
using PyOperationList = PyIpListWrapper<Block::OpListType, PyOperationRef>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type mapBufferFormatToType(MLIRContext *context, const std::string &format,
|
||||
py::ssize_t itemSize) {
|
||||
// Floating point formats.
|
||||
if (format == "f")
|
||||
return FloatType::getF32(context);
|
||||
if (format == "d")
|
||||
return FloatType::getF64(context);
|
||||
if (format == "D")
|
||||
return ComplexType::get(FloatType::getF64(context));
|
||||
|
||||
// Signed integer formats.
|
||||
if (format == "b" || format == "h" || format == "i" || format == "l" ||
|
||||
format == "L") {
|
||||
unsigned width = itemSize * 8;
|
||||
return IntegerType::get(context, width,
|
||||
IntegerType::SignednessSemantics::Signed);
|
||||
}
|
||||
|
||||
// Unsigned integer format.
|
||||
if (format == "B" || format == "H" || format == "I" || format == "k" ||
|
||||
format == "K") {
|
||||
unsigned width = itemSize * 8;
|
||||
return IntegerType::get(context, width,
|
||||
IntegerType::SignednessSemantics::Unsigned);
|
||||
}
|
||||
|
||||
return Type();
|
||||
}
|
||||
|
||||
/// Creates a DenseElementsAttr from a python buffer which must have been
|
||||
/// requested to be C-Contiguous.
|
||||
Attribute createDenseElementsAttrFromBuffer(MLIRContext *context,
|
||||
py::buffer_info &array) {
|
||||
Type elementType =
|
||||
mapBufferFormatToType(context, array.format, array.itemsize);
|
||||
if (!elementType) {
|
||||
throw py::raiseValueError(
|
||||
"Unsupported buffer/array type for conversion to DenseElementsAttr");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> shape(array.shape.begin(),
|
||||
array.shape.begin() + array.ndim);
|
||||
RankedTensorType type = RankedTensorType::get(shape, elementType);
|
||||
const char *rawBufferPtr = reinterpret_cast<const char *>(array.ptr);
|
||||
ArrayRef<char> rawBuffer(rawBufferPtr, array.size * array.itemsize);
|
||||
return DenseElementsAttr::getFromRawBuffer(type, rawBuffer, false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Diagnostics
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// RAII class to capture diagnostics for later reporting back to the python
|
||||
/// layer.
|
||||
class DiagnosticCapture {
|
||||
public:
|
||||
DiagnosticCapture(mlir::MLIRContext *mlir_context)
|
||||
: mlir_context(mlir_context) {
|
||||
handler_id = mlir_context->getDiagEngine().registerHandler(
|
||||
[&](Diagnostic &d) -> LogicalResult {
|
||||
diagnostics.push_back(std::move(d));
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
~DiagnosticCapture() {
|
||||
if (mlir_context) {
|
||||
mlir_context->getDiagEngine().eraseHandler(handler_id);
|
||||
}
|
||||
}
|
||||
DiagnosticCapture(DiagnosticCapture &&other) {
|
||||
mlir_context = other.mlir_context;
|
||||
diagnostics.swap(other.diagnostics);
|
||||
handler_id = other.handler_id;
|
||||
other.mlir_context = nullptr;
|
||||
}
|
||||
|
||||
std::vector<mlir::Diagnostic> &getDiagnostics() { return diagnostics; }
|
||||
|
||||
// Consumes/clears diagnostics.
|
||||
std::string consumeDiagnosticsAsString(const char *error_message);
|
||||
void clearDiagnostics() { diagnostics.clear(); }
|
||||
|
||||
private:
|
||||
MLIRContext *mlir_context;
|
||||
std::vector<mlir::Diagnostic> diagnostics;
|
||||
mlir::DiagnosticEngine::HandlerID handler_id;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyDialectHelper
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyDialectHelper::bind(py::module m) {
|
||||
py::class_<PyDialectHelper>(m, "DialectHelper")
|
||||
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
|
||||
py::keep_alive<1, 3>())
|
||||
.def_property_readonly("builder",
|
||||
[](PyDialectHelper &self) -> PyBaseOpBuilder & {
|
||||
return self.pyOpBuilder;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"context",
|
||||
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
|
||||
return self.context.shared_from_this();
|
||||
})
|
||||
.def(
|
||||
"op",
|
||||
[](PyDialectHelper &self, const std::string &opNameStr,
|
||||
std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
|
||||
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
OperationName opName(opNameStr, opBuilder.getContext());
|
||||
SmallVector<Type, 4> types(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||
pyOperands.end());
|
||||
DictionaryAttr attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = dictAttrs;
|
||||
} else {
|
||||
attrList = DictionaryAttr::get({}, self.getContext());
|
||||
}
|
||||
Operation *op =
|
||||
Operation::create(loc, opName, types, operands, attrList);
|
||||
opBuilder.insert(op);
|
||||
return op;
|
||||
},
|
||||
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
||||
.def(
|
||||
"func_op",
|
||||
[](PyDialectHelper &self, const std::string &name, PyType type,
|
||||
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
|
||||
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
||||
if (!functionType) {
|
||||
throw py::raiseValueError("Illegal function type");
|
||||
}
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
// TODO: Dedup attr creation from op().
|
||||
DictionaryAttr attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = dictAttrs;
|
||||
} else {
|
||||
attrList = DictionaryAttr::get({}, self.getContext());
|
||||
}
|
||||
FuncOp op =
|
||||
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
||||
/*attrs=*/attrList.getValue());
|
||||
if (createEntryBlock) {
|
||||
Block *entryBlock = new Block();
|
||||
entryBlock->addArguments(functionType.getInputs());
|
||||
op.getBody().push_back(entryBlock);
|
||||
opBuilder.setInsertionPointToStart(entryBlock);
|
||||
}
|
||||
return PyOperationRef(op);
|
||||
},
|
||||
py::arg("name"), py::arg("type"),
|
||||
py::arg("create_entry_block") = false,
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>(),
|
||||
R"(Creates a new `func` op, optionally creating an entry block.
|
||||
If an entry block is created, the builder will be positioned
|
||||
to its start.)")
|
||||
.def(
|
||||
"select_op",
|
||||
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
|
||||
PyValue falseValue) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
return PyOperationRef(opBuilder.create<SelectOp>(
|
||||
loc, conditionValue, trueValue, falseValue));
|
||||
},
|
||||
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
|
||||
.def("return_op",
|
||||
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||
pyOperands.end());
|
||||
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
||||
})
|
||||
.def("constant_op",
|
||||
[](PyDialectHelper &self, PyType type, PyAttribute value) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
return PyOperationRef(
|
||||
opBuilder.create<ConstantOp>(loc, type.type, value.attr));
|
||||
})
|
||||
|
||||
// Types.
|
||||
.def_property_readonly("index_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IndexType::get(self.getContext());
|
||||
})
|
||||
.def(
|
||||
"integer_type",
|
||||
[](PyDialectHelper &self, unsigned width) -> PyType {
|
||||
return IntegerType::get(self.getContext(), width);
|
||||
},
|
||||
py::arg("width") = 32)
|
||||
.def_property_readonly("i1_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(self.getContext(), 1);
|
||||
})
|
||||
.def_property_readonly("i16_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(self.getContext(), 32);
|
||||
})
|
||||
.def_property_readonly("i32_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(self.getContext(), 32);
|
||||
})
|
||||
.def_property_readonly("i64_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(self.getContext(), 64);
|
||||
})
|
||||
.def_property_readonly("f32_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return FloatType::getF32(self.getContext());
|
||||
})
|
||||
.def_property_readonly("f64_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return FloatType::getF64(self.getContext());
|
||||
})
|
||||
.def(
|
||||
"tensor_type",
|
||||
[](PyDialectHelper &self, PyType elementType,
|
||||
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
||||
if (!elementType.type) {
|
||||
throw py::raiseValueError("Null element type");
|
||||
}
|
||||
if (shape) {
|
||||
return RankedTensorType::get(*shape, elementType.type);
|
||||
} else {
|
||||
return UnrankedTensorType::get(elementType.type);
|
||||
}
|
||||
},
|
||||
py::arg("element_type"),
|
||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||
.def("function_type",
|
||||
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
||||
std::vector<PyType> results) -> PyType {
|
||||
llvm::SmallVector<Type, 4> inputTypes;
|
||||
llvm::SmallVector<Type, 1> resultTypes;
|
||||
for (auto input : inputs) {
|
||||
inputTypes.push_back(input.type);
|
||||
}
|
||||
for (auto result : results) {
|
||||
resultTypes.push_back(result.type);
|
||||
}
|
||||
return FunctionType::get(self.getContext(), inputTypes,
|
||||
resultTypes);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module initialization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitDiagnostic(DiagnosticSeverity severity, PyAttribute loc,
|
||||
std::string &message) {
|
||||
auto locAttr = loc.attr.dyn_cast_or_null<LocationAttr>();
|
||||
if (!locAttr) {
|
||||
throw py::raiseValueError("Expected a LocationAttr");
|
||||
}
|
||||
auto &diagEngine = locAttr.getContext()->getDiagEngine();
|
||||
diagEngine.emit(Location(locAttr), severity) << message;
|
||||
}
|
||||
|
||||
void defineMlirIrModule(py::module m) {
|
||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||
|
||||
// Globals.
|
||||
m.def(
|
||||
"emit_error",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def(
|
||||
"emit_warning",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
m.def(
|
||||
"emit_remark",
|
||||
[](PyAttribute loc, std::string message) {
|
||||
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
|
||||
},
|
||||
py::arg("loc"), py::arg("message"));
|
||||
|
||||
// Python only types.
|
||||
PyDialectHelper::bind(m);
|
||||
|
||||
// Utility types.
|
||||
PyBlockList::bind(m, "BlockList");
|
||||
PyOperationList::bind(m, "OperationList");
|
||||
|
||||
// Wrapper types.
|
||||
PyAttribute::bind(m);
|
||||
PyBaseOperation::bind(m);
|
||||
PyBaseOpBuilder::bind(m);
|
||||
PyBlockRef::bind(m);
|
||||
PyContext::bind(m);
|
||||
PyIdentifier::bind(m);
|
||||
PyModuleOp::bind(m);
|
||||
PyOperationRef::bind(m);
|
||||
PyOpBuilder::bind(m);
|
||||
PyRegionRef::bind(m);
|
||||
PySymbolTable::bind(m);
|
||||
PyType::bind(m);
|
||||
PyValue::bind(m);
|
||||
|
||||
// Direct wrappings.
|
||||
bindInsertPoint(m);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyContext
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyContext::PyContext() {
|
||||
mlir::npcomp::python::loadGlobalDialectsIntoContext(&context);
|
||||
}
|
||||
|
||||
void PyContext::bind(py::module m) {
|
||||
py::class_<PyContext, std::shared_ptr<PyContext>>(m, "MLIRContext")
|
||||
.def(py::init<>([]() {
|
||||
// Need explicit make_shared to avoid UB with enable_shared_from_this.
|
||||
return std::make_shared<PyContext>();
|
||||
}))
|
||||
.def("new_module",
|
||||
[&](PyContext &self) -> PyModuleOp {
|
||||
Location loc = UnknownLoc::get(&self.context);
|
||||
auto m = ModuleOp::create(loc);
|
||||
return PyModuleOp(self.shared_from_this(), m);
|
||||
})
|
||||
.def("parse_asm", &PyContext::parseAsm)
|
||||
.def(
|
||||
"new_builder",
|
||||
[](PyContext &self) {
|
||||
// Note: we collapse the Builder and OpBuilder into one because
|
||||
// there is little reason to expose the inheritance hierarchy to
|
||||
// Python.
|
||||
return PyOpBuilder(self);
|
||||
},
|
||||
py::keep_alive<0, 1>())
|
||||
.def("identifier",
|
||||
[](PyContext &self, std::string s) -> PyIdentifier {
|
||||
return Identifier::get(s, &self.context);
|
||||
})
|
||||
.def(
|
||||
"file_line_col_loc_attr",
|
||||
[](PyContext &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) -> PyAttribute {
|
||||
return static_cast<LocationAttr>(FileLineColLoc::get(
|
||||
filename.identifier, line, column, &self.context));
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("column"))
|
||||
// Salient functions from Builder.
|
||||
.def("parse_type",
|
||||
[](PyContext &self, const std::string &asmText) {
|
||||
Type t = parseType(asmText, &self.context);
|
||||
if (!t) {
|
||||
std::string message = "Unable to parse MLIR type: ";
|
||||
message.append(asmText);
|
||||
throw py::raiseValueError(message);
|
||||
}
|
||||
return PyType(t);
|
||||
})
|
||||
.def(
|
||||
"integer_attr",
|
||||
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
|
||||
if (!type.type.isa<IntegerType>()) {
|
||||
throw py::raiseValueError("Expected IntegerType");
|
||||
}
|
||||
return IntegerAttr::get(type.type, value);
|
||||
},
|
||||
py::arg("type"), py::arg("value"))
|
||||
.def("float_attr",
|
||||
[](PyContext &self, PyType type, double value) -> PyAttribute {
|
||||
if (!type.type.isa<FloatType>()) {
|
||||
throw py::raiseValueError("Expected FloatType");
|
||||
}
|
||||
return FloatAttr::get(type.type, value);
|
||||
})
|
||||
.def("index_attr",
|
||||
[](PyContext &self, int64_t indexValue) -> PyAttribute {
|
||||
return IntegerAttr::get(IndexType::get(&self.context), indexValue);
|
||||
})
|
||||
.def("string_attr",
|
||||
[](PyContext &self, const std::string &s) -> PyAttribute {
|
||||
return StringAttr::get(s, &self.context);
|
||||
})
|
||||
.def("bytes_attr",
|
||||
[](PyContext &self, py::bytes bytes) -> PyAttribute {
|
||||
char *buffer;
|
||||
ssize_t length;
|
||||
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bytes.ptr(), &buffer,
|
||||
&length)) {
|
||||
throw py::raiseValueError("Cannot extract bytes");
|
||||
}
|
||||
return StringAttr::get(StringRef(buffer, length), &self.context);
|
||||
})
|
||||
.def("flat_symbol_ref_attr",
|
||||
[](PyContext &self, const std::string &s) -> PyAttribute {
|
||||
return FlatSymbolRefAttr::get(s, &self.context);
|
||||
})
|
||||
.def("dictionary_attr",
|
||||
[](PyContext &self, py::dict d) -> PyAttribute {
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
for (auto &it : d) {
|
||||
auto key = it.first.cast<std::string>();
|
||||
auto value = it.second.cast<PyAttribute>();
|
||||
auto keyIdent = Identifier::get(key, &self.context);
|
||||
attrs.emplace_back(keyIdent, value.attr);
|
||||
}
|
||||
return DictionaryAttr::get(attrs, &self.context);
|
||||
})
|
||||
.def("array_attr",
|
||||
[](PyContext &self, py::list l) -> PyAttribute {
|
||||
SmallVector<Attribute, 4> attrs;
|
||||
for (auto &it : l) {
|
||||
attrs.push_back(it.cast<PyAttribute>().attr);
|
||||
}
|
||||
return ArrayAttr::get(attrs, &self.context);
|
||||
})
|
||||
.def(
|
||||
"dense_elements_attr",
|
||||
[](PyContext &self, py::buffer array) -> PyAttribute {
|
||||
// Request a contiguous view.
|
||||
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
||||
Py_buffer *view = new Py_buffer();
|
||||
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
|
||||
delete view;
|
||||
throw py::error_already_set();
|
||||
}
|
||||
py::buffer_info array_info(view);
|
||||
return createDenseElementsAttrFromBuffer(&self.context, array_info);
|
||||
},
|
||||
py::arg("array"))
|
||||
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
|
||||
return UnitAttr::get(&self.context);
|
||||
});
|
||||
}
|
||||
|
||||
PyModuleOp PyContext::parseAsm(const std::string &asm_text) {
|
||||
// Arrange to get a view that includes a terminating null to avoid
|
||||
// additional copy.
|
||||
// TODO: Consider using the buffer protocol to access and avoid more copies.
|
||||
const char *asm_chars = asm_text.c_str();
|
||||
StringRef asm_sr(asm_chars, asm_text.size() + 1);
|
||||
|
||||
// TODO: Output non failure diagnostics (somewhere)
|
||||
DiagnosticCapture diag_capture(&context);
|
||||
auto module_ref = parseMLIRModuleFromString(asm_sr, &context);
|
||||
if (!module_ref) {
|
||||
throw py::raiseValueError(
|
||||
diag_capture.consumeDiagnosticsAsString("Error parsing ASM"));
|
||||
}
|
||||
return PyModuleOp{shared_from_this(), module_ref.release()};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyBaseOperation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyBaseOperation::~PyBaseOperation() = default;
|
||||
|
||||
void PyBaseOperation::bind(py::module m) {
|
||||
py::class_<PyBaseOperation>(m, "BaseOperation")
|
||||
.def_property_readonly(
|
||||
"name",
|
||||
[](PyBaseOperation &self) {
|
||||
return std::string(self.getOperation()->getName().getStringRef());
|
||||
})
|
||||
.def_property_readonly("is_registered",
|
||||
[](PyBaseOperation &self) {
|
||||
return self.getOperation()->isRegistered();
|
||||
})
|
||||
.def_property_readonly("num_regions",
|
||||
[](PyBaseOperation &self) {
|
||||
return self.getOperation()->getNumRegions();
|
||||
})
|
||||
.def_property_readonly("results",
|
||||
[](PyBaseOperation &self) {
|
||||
auto *op = self.getOperation();
|
||||
std::vector<PyValue> results(op->result_begin(),
|
||||
op->result_end());
|
||||
return results;
|
||||
})
|
||||
.def_property_readonly("result",
|
||||
[](PyBaseOperation &self) -> PyValue {
|
||||
auto *op = self.getOperation();
|
||||
if (op->getNumResults() != 1) {
|
||||
throw py::raiseValueError(
|
||||
"Operation does not have 1 result");
|
||||
}
|
||||
return op->getOpResult(0);
|
||||
})
|
||||
.def("region",
|
||||
[](PyBaseOperation &self, int index) {
|
||||
auto *op = self.getOperation();
|
||||
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
|
||||
throw py::raisePyError(PyExc_IndexError,
|
||||
"Region index out of bounds");
|
||||
}
|
||||
return PyRegionRef(op->getRegion(index));
|
||||
})
|
||||
.def_property_readonly("first_block", [](PyBaseOperation &self) {
|
||||
Operation *op = self.getOperation();
|
||||
assert(op);
|
||||
if (op->getNumRegions() == 0) {
|
||||
throw py::raiseValueError("Op has no regions");
|
||||
}
|
||||
auto ®ion = op->getRegion(0);
|
||||
if (region.empty()) {
|
||||
throw py::raiseValueError("Op has no blocks");
|
||||
}
|
||||
return PyBlockRef(region.front());
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyOperationRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyOperationRef::~PyOperationRef() = default;
|
||||
void PyOperationRef::bind(py::module m) {
|
||||
py::class_<PyOperationRef, PyBaseOperation>(m, "OperationRef");
|
||||
}
|
||||
|
||||
Operation *PyOperationRef::getOperation() { return operation; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyModuleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyModuleOp::~PyModuleOp() = default;
|
||||
void PyModuleOp::bind(py::module m) {
|
||||
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
|
||||
.def_property_readonly("context",
|
||||
[](PyModuleOp &self) { return self.context; })
|
||||
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
||||
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
|
||||
}
|
||||
|
||||
Operation *PyModuleOp::getOperation() { return moduleOp; }
|
||||
|
||||
std::string PyModuleOp::toAsm(bool enableDebugInfo, bool prettyForm,
|
||||
int64_t largeElementLimit) {
|
||||
// Print to asm.
|
||||
std::string asmOutput;
|
||||
llvm::raw_string_ostream sout(asmOutput);
|
||||
OpPrintingFlags printFlags;
|
||||
if (enableDebugInfo) {
|
||||
printFlags.enableDebugInfo(prettyForm);
|
||||
}
|
||||
if (largeElementLimit >= 0) {
|
||||
printFlags.elideLargeElementsAttrs(largeElementLimit);
|
||||
}
|
||||
moduleOp.print(sout, printFlags);
|
||||
return sout.str();
|
||||
}
|
||||
|
||||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||
MLIRContext *context) {
|
||||
std::unique_ptr<llvm::MemoryBuffer> contents_buffer;
|
||||
if (contents.back() == 0) {
|
||||
// If it has a nul terminator, just use as-is.
|
||||
contents_buffer = llvm::MemoryBuffer::getMemBuffer(contents.drop_back());
|
||||
} else {
|
||||
// Otherwise, make a copy.
|
||||
contents_buffer = llvm::MemoryBuffer::getMemBufferCopy(contents, "EMBED");
|
||||
}
|
||||
|
||||
llvm::SourceMgr source_mgr;
|
||||
source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
|
||||
OwningModuleRef mlir_module = parseSourceFile(source_mgr, context);
|
||||
return mlir_module;
|
||||
}
|
||||
|
||||
// Custom location printer that prints prettier, multi-line file output
|
||||
// suitable for human readable error messages. The standard printer just prints
|
||||
// a long nested expression not particularly human friendly). Note that there
|
||||
// is a location pretty printer in the MLIR AsmPrinter. It is private and
|
||||
// doesn't do any path shortening, which seems to make long Python stack traces
|
||||
// a bit easier to scan.
|
||||
// TODO: Upstream this.
|
||||
void printLocation(Location loc, raw_ostream &out) {
|
||||
TypeSwitch<Location>(loc)
|
||||
.Case<OpaqueLoc>(
|
||||
[&](OpaqueLoc loc) { printLocation(loc.getFallbackLocation(), out); })
|
||||
.Case<UnknownLoc>([&](Location) { out << " [unknown location]\n"; })
|
||||
.Case<FileLineColLoc>([&](FileLineColLoc line_col_loc) {
|
||||
StringRef this_filename = line_col_loc.getFilename();
|
||||
auto slash_pos = this_filename.find_last_of("/\\");
|
||||
// We print both the basename and extended names with a structure like
|
||||
// `foo.py:35:4`. Even though technically the line/col
|
||||
// information is redundant to include in both names, having it on both
|
||||
// makes it easier to paste the paths into an editor and jump to the
|
||||
// exact location.
|
||||
std::string line_col_suffix =
|
||||
":" + std::to_string(line_col_loc.getLine()) + ":" +
|
||||
std::to_string(line_col_loc.getColumn());
|
||||
bool has_basename = false;
|
||||
StringRef basename = this_filename;
|
||||
if (slash_pos != StringRef::npos) {
|
||||
has_basename = true;
|
||||
basename = this_filename.substr(slash_pos + 1);
|
||||
}
|
||||
out << " at: " << basename << line_col_suffix;
|
||||
if (has_basename) {
|
||||
StringRef extended_name = this_filename;
|
||||
// Print out two tabs, as basenames usually vary in length by more
|
||||
// than one tab width.
|
||||
out << "\t\t( " << extended_name << line_col_suffix << " )";
|
||||
}
|
||||
out << "\n";
|
||||
})
|
||||
.Case<NameLoc>([&](NameLoc nameLoc) {
|
||||
out << " @'" << nameLoc.getName() << "':\n";
|
||||
auto childLoc = nameLoc.getChildLoc();
|
||||
if (!childLoc.isa<UnknownLoc>()) {
|
||||
out << "(...\n";
|
||||
printLocation(childLoc, out);
|
||||
out << ")\n";
|
||||
}
|
||||
})
|
||||
.Case<CallSiteLoc>([&](CallSiteLoc callSite) {
|
||||
printLocation(callSite.getCaller(), out);
|
||||
printLocation(callSite.getCallee(), out);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PySymbolTable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PySymbolTable::bind(py::module m) {
|
||||
py::class_<PySymbolTable>(m, "SymbolTable")
|
||||
.def_property_readonly_static("symbol_attr_name",
|
||||
[](const py::object &) {
|
||||
auto sr =
|
||||
SymbolTable::getSymbolAttrName();
|
||||
return py::str(sr.data(), sr.size());
|
||||
})
|
||||
.def_property_readonly_static(
|
||||
"visibility_attr_name", [](const py::object &) {
|
||||
auto sr = SymbolTable::getVisibilityAttrName();
|
||||
return py::str(sr.data(), sr.size());
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DiagnosticCapture
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string
|
||||
DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
||||
std::string s;
|
||||
llvm::raw_string_ostream sout(s);
|
||||
bool first = true;
|
||||
if (error_message) {
|
||||
sout << error_message;
|
||||
first = false;
|
||||
}
|
||||
for (auto &d : diagnostics) {
|
||||
if (!first) {
|
||||
sout << "\n\n";
|
||||
} else {
|
||||
first = false;
|
||||
}
|
||||
|
||||
switch (d.getSeverity()) {
|
||||
case DiagnosticSeverity::Note:
|
||||
sout << "[NOTE]";
|
||||
break;
|
||||
case DiagnosticSeverity::Warning:
|
||||
sout << "[WARNING]";
|
||||
break;
|
||||
case DiagnosticSeverity::Error:
|
||||
sout << "[ERROR]";
|
||||
break;
|
||||
case DiagnosticSeverity::Remark:
|
||||
sout << "[REMARK]";
|
||||
break;
|
||||
}
|
||||
// Message.
|
||||
sout << ": " << d << "\n";
|
||||
printLocation(d.getLocation(), sout);
|
||||
}
|
||||
|
||||
diagnostics.clear();
|
||||
return sout.str();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyBlockRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyBlockRef::bind(py::module m) {
|
||||
py::class_<PyBlockRef>(m, "BlockRef")
|
||||
.def_property_readonly("operations",
|
||||
[](PyBlockRef &self) {
|
||||
return PyOperationList(
|
||||
self.block.getOperations());
|
||||
})
|
||||
.def_property_readonly("args", [](PyBlockRef &self) {
|
||||
return std::vector<PyValue>(self.block.args_begin(),
|
||||
self.block.args_end());
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyRegionRef
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyRegionRef::bind(py::module m) {
|
||||
py::class_<PyRegionRef>(m, "RegionRef")
|
||||
.def_property_readonly("blocks", [](PyRegionRef &self) {
|
||||
return PyBlockList(self.region.getBlocks());
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyType::bind(py::module m) {
|
||||
py::class_<PyType>(m, "Type").def("__repr__",
|
||||
[](PyType &self) -> std::string {
|
||||
if (!self.type)
|
||||
return "<undefined type>";
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
self.type.print(os);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyIdentifier
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyIdentifier::bind(py::module m) {
|
||||
py::class_<PyIdentifier>(m, "Identifier")
|
||||
.def("__str__", [](PyIdentifier &self) { return self.identifier.str(); })
|
||||
.def("__repr__", [](PyIdentifier &self) {
|
||||
std::string s("<Identifier \"");
|
||||
s.append(self.identifier.str());
|
||||
s.append("\">");
|
||||
return s;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyValue
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyValue::bind(py::module m) {
|
||||
py::class_<PyValue>(m, "Value")
|
||||
.def_property_readonly(
|
||||
"type", [](PyValue &self) -> PyType { return self.value.getType(); })
|
||||
.def("__repr__", [](PyValue &self) {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
os << self.value;
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyAttribute
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyAttribute::bind(py::module m) {
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
.def_property_readonly(
|
||||
"type",
|
||||
[](PyAttribute &self) -> PyType { return self.attr.getType(); })
|
||||
.def("__repr__", [](PyAttribute &self) {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
os << self.attr;
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpBuilder implementations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyBaseOpBuilder::~PyBaseOpBuilder() = default;
|
||||
PyOpBuilder::~PyOpBuilder() = default;
|
||||
|
||||
OpBuilder &PyOpBuilder::getBuilder(bool requirePosition) {
|
||||
if (requirePosition && !builder.getBlock()) {
|
||||
throw py::raisePyError(PyExc_IndexError, "Insertion point not set");
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
void PyBaseOpBuilder::bind(py::module m) {
|
||||
py::class_<PyBaseOpBuilder>(m, "BaseOpBuilder");
|
||||
}
|
||||
|
||||
void PyOpBuilder::bind(py::module m) {
|
||||
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
|
||||
.def(py::init<PyContext &>(), py::keep_alive<1, 2>())
|
||||
.def_property(
|
||||
"current_loc",
|
||||
[](PyOpBuilder &self) -> PyAttribute {
|
||||
return static_cast<Attribute>(self.getCurrentLoc());
|
||||
},
|
||||
[](PyOpBuilder &self, PyAttribute attr) {
|
||||
auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
|
||||
if (!loc_attr) {
|
||||
throw py::raiseValueError("Expected a LocationAttr");
|
||||
}
|
||||
self.setCurrentLoc(Location(loc_attr));
|
||||
})
|
||||
.def_property(
|
||||
"insertion_point",
|
||||
[](PyOpBuilder &self) {
|
||||
return self.getBuilder(true).saveInsertionPoint();
|
||||
},
|
||||
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
|
||||
self.getBuilder(false).restoreInsertionPoint(ip);
|
||||
})
|
||||
.def(
|
||||
"set_file_line_col",
|
||||
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
|
||||
unsigned column) {
|
||||
Location loc = FileLineColLoc::get(filename.identifier, line,
|
||||
column, self.getContext());
|
||||
self.setCurrentLoc(loc);
|
||||
},
|
||||
py::arg("filename"), py::arg("line"), py::arg("column"),
|
||||
"Shortcut to set a FileLineCol current location")
|
||||
.def("clear_insertion_point",
|
||||
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
|
||||
.def(
|
||||
"insert_op_before",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPoint(op);
|
||||
},
|
||||
"Sets the insertion point to just before the specified op.")
|
||||
.def(
|
||||
"insert_op_after",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPointAfter(op);
|
||||
},
|
||||
"Sets the insertion point to just after the specified op.")
|
||||
.def(
|
||||
"insert_block_start",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToStart(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the start of the block.")
|
||||
.def(
|
||||
"insert_block_end",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToEnd(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the end of the block.")
|
||||
.def(
|
||||
"insert_before_terminator",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
auto *terminator = block.block.getTerminator();
|
||||
if (!terminator) {
|
||||
throw py::raiseValueError("Block has no terminator");
|
||||
}
|
||||
self.builder.setInsertionPoint(terminator);
|
||||
},
|
||||
"Sets the insertion point to just before the block terminator.");
|
||||
}
|
||||
|
||||
} // namespace mlir
|
|
@ -1,82 +0,0 @@
|
|||
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Python/MlirPass.h"
|
||||
#include "npcomp/Python/MlirInit.h"
|
||||
#include "npcomp/Python/NpcompModule.h"
|
||||
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module initialization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void defineMlirPassModule(py::module m) {
|
||||
m.doc() = "Python bindings for mlir pass infra";
|
||||
|
||||
PyPassManager::bind(m);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassManager
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyPassManager::bind(py::module m) {
|
||||
py::class_<PyPassManager>(m, "PassManager")
|
||||
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
|
||||
py::arg("verifyModules") = true)
|
||||
.def(
|
||||
"enableCrashReproducerGeneration",
|
||||
[](PyPassManager &self, std::string outputFile,
|
||||
bool genLocalReproducer) {
|
||||
self.passManager.enableCrashReproducerGeneration(
|
||||
outputFile, genLocalReproducer);
|
||||
},
|
||||
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
|
||||
.def("__len__",
|
||||
[](PyPassManager &self) { return self.passManager.size(); })
|
||||
.def("__str__",
|
||||
[](PyPassManager &self) {
|
||||
std::string spec;
|
||||
llvm::raw_string_ostream stream(spec);
|
||||
self.passManager.printAsTextualPipeline(stream);
|
||||
return spec;
|
||||
})
|
||||
.def("run",
|
||||
[](PyPassManager &self, PyModuleOp &module) {
|
||||
if (module.context.get() != self.context.get()) {
|
||||
throw py::raiseValueError(
|
||||
"Expected a module with the same context "
|
||||
"as the PassManager");
|
||||
}
|
||||
if (failed(self.passManager.run(module.moduleOp))) {
|
||||
// TODO: Wrap propagate context diagnostics
|
||||
throw py::raisePyError(PyExc_RuntimeError,
|
||||
"Could not run passes");
|
||||
}
|
||||
})
|
||||
.def("addPassPipelines", [](PyPassManager &self, py::args passPipelines) {
|
||||
std::string error;
|
||||
llvm::raw_string_ostream error_stream(error);
|
||||
for (auto pyPassPipeline : passPipelines) {
|
||||
auto passPipeline = pyPassPipeline.cast<std::string>();
|
||||
if (failed(mlir::parsePassPipeline(passPipeline, self.passManager,
|
||||
error_stream))) {
|
||||
std::string message = "failed to parse pass pipeline '";
|
||||
message.append(passPipeline);
|
||||
message.append("': ");
|
||||
message.append(error);
|
||||
throw py::raiseValueError(message);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlir
|
|
@ -1,172 +0,0 @@
|
|||
//===- NpcompDialect.cpp - Custom dialect classes -------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Python/MlirIr.h"
|
||||
#include "npcomp/Python/NpcompModule.h"
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
||||
class BasicpyDialectHelper : public PyDialectHelper {
|
||||
public:
|
||||
using PyDialectHelper::PyDialectHelper;
|
||||
|
||||
static void bind(py::module m) {
|
||||
py::class_<BasicpyDialectHelper, PyDialectHelper>(m, "BasicpyDialectHelper")
|
||||
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
|
||||
py::keep_alive<1, 3>())
|
||||
// ---------------------------------------------------------------------
|
||||
// Basicpy dialect
|
||||
// ---------------------------------------------------------------------
|
||||
.def_property_readonly("basicpy_BoolType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::BoolType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def_property_readonly("basicpy_BytesType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::BytesType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def_property_readonly("basicpy_EllipsisType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::EllipsisType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def_property_readonly("basicpy_NoneType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::NoneType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def(
|
||||
"basicpy_SlotObject_type",
|
||||
[](BasicpyDialectHelper &self, std::string className,
|
||||
py::args pySlotTypes) -> PyType {
|
||||
SmallVector<Type, 4> slotTypes;
|
||||
for (auto pySlotType : pySlotTypes) {
|
||||
slotTypes.push_back(pySlotType.cast<PyType>());
|
||||
}
|
||||
auto classNameAttr =
|
||||
StringAttr::get(className, self.getContext());
|
||||
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||||
},
|
||||
py::arg("className"))
|
||||
.def_property_readonly("basicpy_StrType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::StrType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def_property_readonly("basicpy_UnknownType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::UnknownType::get(
|
||||
self.getContext());
|
||||
})
|
||||
.def("basicpy_exec_op",
|
||||
[](BasicpyDialectHelper &self) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
auto op = opBuilder.create<Basicpy::ExecOp>(loc);
|
||||
return py::make_tuple(PyOperationRef(op),
|
||||
op.getBodyBuilder().saveInsertionPoint());
|
||||
})
|
||||
.def("basicpy_exec_discard_op",
|
||||
[](BasicpyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
llvm::SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||
pyOperands.end());
|
||||
auto op =
|
||||
opBuilder.create<Basicpy::ExecDiscardOp>(loc, operands);
|
||||
return PyOperationRef(op);
|
||||
})
|
||||
.def("basicpy_slot_object_get_op",
|
||||
[](BasicpyDialectHelper &self, PyValue slotObject,
|
||||
unsigned index) -> PyOperationRef {
|
||||
auto slotObjectType = slotObject.value.getType()
|
||||
.dyn_cast<Basicpy::SlotObjectType>();
|
||||
if (!slotObjectType) {
|
||||
throw py::raiseValueError("Operand must be a SlotObject");
|
||||
}
|
||||
if (index >= slotObjectType.getSlotCount()) {
|
||||
throw py::raiseValueError("Out of range slot index");
|
||||
}
|
||||
auto resultType = slotObjectType.getSlotTypes()[index];
|
||||
auto indexAttr =
|
||||
IntegerAttr::get(IndexType::get(self.getContext()), index);
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
auto op = opBuilder.create<Basicpy::SlotObjectGetOp>(
|
||||
loc, resultType, slotObject, indexAttr);
|
||||
return op.getOperation();
|
||||
})
|
||||
// ---------------------------------------------------------------------
|
||||
// Numpy dialect
|
||||
// ---------------------------------------------------------------------
|
||||
.def("numpy_copy_to_tensor_op",
|
||||
[](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef {
|
||||
auto sourceType =
|
||||
source.value.getType().dyn_cast<Numpy::NdArrayType>();
|
||||
if (!sourceType) {
|
||||
source.value.dump();
|
||||
throw py::raiseValueError("expected ndarray type for "
|
||||
"numpy_copy_to_tensor_op");
|
||||
}
|
||||
auto dtype = sourceType.getDtype();
|
||||
auto optionalShape = sourceType.getOptionalShape();
|
||||
TensorType tensorType;
|
||||
if (optionalShape) {
|
||||
tensorType = RankedTensorType::get(*optionalShape, dtype);
|
||||
} else {
|
||||
tensorType = UnrankedTensorType::get(dtype);
|
||||
}
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
auto op = opBuilder.create<Numpy::CopyToTensorOp>(
|
||||
loc, tensorType, source.value);
|
||||
return op.getOperation();
|
||||
})
|
||||
.def("numpy_create_array_from_tensor_op",
|
||||
[](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef {
|
||||
auto sourceType = source.value.getType().dyn_cast<TensorType>();
|
||||
if (!sourceType) {
|
||||
throw py::raiseValueError("expected tensor type for "
|
||||
"numpy_create_array_from_tensor_op");
|
||||
}
|
||||
auto dtype = sourceType.getElementType();
|
||||
llvm::Optional<ArrayRef<int64_t>> optionalShape;
|
||||
if (auto rankedTensorType =
|
||||
sourceType.dyn_cast<RankedTensorType>()) {
|
||||
optionalShape = rankedTensorType.getShape();
|
||||
}
|
||||
auto ndarrayType = Numpy::NdArrayType::get(dtype, optionalShape);
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
auto op = opBuilder.create<Numpy::CreateArrayFromTensorOp>(
|
||||
loc, ndarrayType, source.value);
|
||||
return op.getOperation();
|
||||
})
|
||||
.def("numpy_NdArrayType",
|
||||
[](BasicpyDialectHelper &self, PyType dtype) -> PyType {
|
||||
return Numpy::NdArrayType::get(dtype.type);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
using namespace ::mlir::NPCOMP;
|
||||
|
||||
void mlir::npcomp::python::defineNpcompDialect(py::module m) {
|
||||
BasicpyDialectHelper::bind(m);
|
||||
}
|
|
@ -75,9 +75,11 @@ target_link_libraries(NPCOMPNativePyExt
|
|||
NPCOMP
|
||||
|
||||
NPCOMPInitAll
|
||||
NPCOMPPythonCommon
|
||||
${NPCOMP_PYEXT_LIBADD}
|
||||
)
|
||||
npcomp_python_target_compile_options(NPCOMPNativePyExt)
|
||||
|
||||
mlir_check_all_link_libraries(NPCOMPNativePyExt)
|
||||
|
||||
# Order dependent: Built artifacts add dependencies to the above targets.
|
||||
add_subdirectory(npcomp/dialects)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- native.cpp - MLIR Python bindings ----------------------------------===//
|
||||
//===- NpcompModule.cpp - MLIR Python bindings ----------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -9,85 +9,58 @@
|
|||
#include <cstddef>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "npcomp/Python/MlirInit.h"
|
||||
#include "npcomp/Python/NpcompModule.h"
|
||||
#include "npcomp/Python/PybindUtils.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/InitLLVM.h"
|
||||
#include "npcomp-c/Registration.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
#include "npcomp/Python/PybindUtils.h"
|
||||
|
||||
#ifdef NPCOMP_ENABLE_REFJIT
|
||||
#include "npcomp/Backend/RefJIT/PythonModule.h"
|
||||
#endif
|
||||
|
||||
namespace mlir {
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
namespace {
|
||||
|
||||
static void defineLLVMModule(pybind11::module m) {
|
||||
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
|
||||
m.def("add_option",
|
||||
[](std::string name, llvm::Optional<std::string> value) {
|
||||
auto options_map = llvm::cl::getRegisteredOptions();
|
||||
auto found_it = options_map.find(name);
|
||||
if (found_it == options_map.end()) {
|
||||
std::string message = "Unknown LLVM option: ";
|
||||
message.append(name);
|
||||
throw py::raiseValueError(message.c_str());
|
||||
}
|
||||
|
||||
std::string value_sr = value ? *value : "";
|
||||
found_it->getValue()->addOccurrence(1, name, value_sr);
|
||||
},
|
||||
py::arg("name"), py::arg("value") = llvm::Optional<std::string>());
|
||||
m.def("reset_option",
|
||||
[](std::string name) {
|
||||
auto options_map = llvm::cl::getRegisteredOptions();
|
||||
auto found_it = options_map.find(name);
|
||||
if (found_it == options_map.end()) {
|
||||
std::string message = "Unknown LLVM option: ";
|
||||
message.append(name);
|
||||
throw py::raiseValueError(message.c_str());
|
||||
}
|
||||
found_it->getValue()->setDefault();
|
||||
},
|
||||
py::arg("name"));
|
||||
MlirType shapedToNdArrayArrayType(MlirType shaped_type) {
|
||||
if (!mlirTypeIsAShaped(shaped_type)) {
|
||||
throw py::raiseValueError("type is not a shaped type");
|
||||
}
|
||||
return npcompNdArrayTypeGetFromShaped(shaped_type);
|
||||
}
|
||||
|
||||
static void defineGlobals(py::module &m) {
|
||||
m.def("register_dialects", [](MlirContext context) {
|
||||
npcompRegisterAllDialects(context);
|
||||
});
|
||||
MlirType ndarrayToTensorType(MlirType ndarray_type) {
|
||||
if (!npcompTypeIsANdArray(ndarray_type)) {
|
||||
throw py::raiseValueError("type is not an ndarray type");
|
||||
}
|
||||
return npcompNdArrayTypeToTensor(ndarray_type);
|
||||
}
|
||||
|
||||
MlirType slotObjectType(MlirContext context, const std::string &className,
|
||||
const std::vector<MlirType> &slotTypes) {
|
||||
MlirStringRef classNameSr{className.data(), className.size()};
|
||||
return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(),
|
||||
slotTypes.data());
|
||||
}
|
||||
|
||||
// TODO: Move this upstream.
|
||||
void emitError(MlirLocation loc, std::string message) {
|
||||
::mlirEmitError(loc, message.c_str());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(_npcomp, m) {
|
||||
// Guard the once init to happen once per process (vs module, which in
|
||||
// mondo builds can happen multiple times).
|
||||
static bool llvm_init_baton = ([]() { return npcompMlirInitialize(); })();
|
||||
(void)(llvm_init_baton);
|
||||
|
||||
m.doc() = "Npcomp native python bindings";
|
||||
|
||||
// TODO: Retire the llvm, mlir, passes, and dialect modules in favor of
|
||||
// upstream Python bindings.
|
||||
auto llvm_m = m.def_submodule("llvm", "LLVM interop");
|
||||
defineLLVMModule(llvm_m);
|
||||
|
||||
// "mlir" module.
|
||||
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
||||
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
||||
defineMlirIrModule(mlir_ir_m);
|
||||
// Note: not "pass" because it is a reserved word
|
||||
auto mlir_pass_m = mlir_m.def_submodule("passes");
|
||||
defineMlirPassModule(mlir_pass_m);
|
||||
auto mlir_dialect_m = mlir_m.def_submodule("dialect");
|
||||
defineMlirCoreDialects(mlir_dialect_m);
|
||||
|
||||
// Outer "_npcomp" module
|
||||
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
||||
defineNpcompDialect(npcomp_dialect);
|
||||
|
||||
// Globals.
|
||||
defineGlobals(m);
|
||||
m.def("register_all_dialects", ::npcompRegisterAllDialects);
|
||||
m.def("_register_all_passes", ::npcompRegisterAllPasses);
|
||||
m.def("_initialize_llvm_codegen", ::npcompInitializeLLVMCodegen);
|
||||
m.def("shaped_to_ndarray_type", shapedToNdArrayArrayType);
|
||||
m.def("ndarray_to_tensor_type", ndarrayToTensorType);
|
||||
m.def("slot_object_type", slotObjectType);
|
||||
m.def("emit_error", emitError);
|
||||
|
||||
// Optional backend modules.
|
||||
auto backend_m = m.def_submodule("backend", "Backend support");
|
||||
|
@ -99,7 +72,3 @@ PYBIND11_MODULE(_npcomp, m) {
|
|||
::npcomp::python::defineBackendRefJitModule(refjit_m);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1,3 +1,25 @@
|
|||
def _load_extension():
|
||||
# TODO: Remote the RTLD_GLOBAL hack once local, cross module imports
|
||||
# resolve symbols properly. Something is keeping the dynamic loader on
|
||||
# Linux from treating the following vague symbols as the same across
|
||||
# _mlir and _npcomp:
|
||||
# mlir::detail::TypeIDExported::get<mlir::FuncOp>()::instance
|
||||
import sys
|
||||
import ctypes
|
||||
flags = sys.getdlopenflags()
|
||||
sys.setdlopenflags(flags | ctypes.RTLD_GLOBAL)
|
||||
import _npcomp
|
||||
sys.setdlopenflags(flags)
|
||||
|
||||
import mlir
|
||||
mlir._cext.globals.append_dialect_search_prefix("npcomp.dialects")
|
||||
return _npcomp
|
||||
|
||||
|
||||
_cext = _load_extension()
|
||||
_cext._register_all_passes()
|
||||
_cext._initialize_llvm_codegen()
|
||||
|
||||
# Top-level symbols.
|
||||
from .exporter import *
|
||||
from .types import *
|
||||
|
|
|
@ -7,8 +7,6 @@ import subprocess
|
|||
|
||||
from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
from _npcomp import register_dialects
|
||||
from _npcomp import mlir as legacy_mlir
|
||||
from npcomp.compiler.generic.backend import iree as iree_backend
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
|
@ -56,7 +54,7 @@ class CompilerBackend:
|
|||
self._ireert = _get_iree()
|
||||
self._debug = logging.debug_enabled()
|
||||
|
||||
def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp):
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module.
|
||||
|
||||
Args:
|
||||
|
@ -67,10 +65,12 @@ class CompilerBackend:
|
|||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
# TODO: Once transitioned to new Python API, don't reparse the module.
|
||||
with Context() as context:
|
||||
register_dialects(context)
|
||||
imported_module = Module.parse(legacy_imported_ir_module.to_asm())
|
||||
with imported_module.context:
|
||||
# Frontend.
|
||||
if self._debug:
|
||||
logging.debug("Input IR:\n{}", imported_module)
|
||||
assert (
|
||||
imported_module.operation.verify()), "Imported module does not verify"
|
||||
# Frontend.
|
||||
pm = PassManager.parse(",".join(FRONTEND_PASSES))
|
||||
pm.run(imported_module)
|
||||
|
|
|
@ -6,8 +6,6 @@ import os
|
|||
|
||||
from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
from _npcomp import register_dialects
|
||||
from _npcomp import mlir as legacy_mlir
|
||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
|
@ -37,7 +35,7 @@ class CompilerBackend:
|
|||
self._refjit = refjit_backend.get_refjit()
|
||||
self._debug = logging.debug_enabled()
|
||||
|
||||
def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp):
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module.
|
||||
|
||||
Args:
|
||||
|
@ -49,15 +47,16 @@ class CompilerBackend:
|
|||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
# TODO: Once transitioned to new Python API, don't reparse the module.
|
||||
with Context() as context:
|
||||
register_dialects(context)
|
||||
imported_module = Module.parse(legacy_imported_ir_module.to_asm())
|
||||
with imported_module.context as context:
|
||||
# Frontend.
|
||||
if self._debug:
|
||||
logging.debug("Input IR:\n{}", imported_module)
|
||||
assert (
|
||||
imported_module.operation.verify()), "Imported module does not verify"
|
||||
pm = PassManager.parse(",".join(FRONTEND_PASSES))
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Frontend IR:{}", imported_module)
|
||||
logging.debug("Frontend IR:\n{}", imported_module)
|
||||
|
||||
# Backend.
|
||||
# Note that this is a separate pass manager purely to aid in debugging.
|
||||
|
@ -65,7 +64,7 @@ class CompilerBackend:
|
|||
self._refjit.build_backend_compilation_pipeline(pm)
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Backend IR:{}", imported_module)
|
||||
logging.debug("Backend IR:\n{}", imported_module)
|
||||
|
||||
jit_module = self._refjit.JITModule.from_compiled_module(
|
||||
imported_module, refjit_backend.get_runtime_libs())
|
||||
|
|
|
@ -8,7 +8,10 @@ from typing import Callable, Iterator, Sequence, Tuple
|
|||
import functools
|
||||
import numpy as np
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from mlir import ir as _ir
|
||||
|
||||
from npcomp import _cext
|
||||
from npcomp.dialects import numpy as numpy_ops
|
||||
|
||||
from ....utils import logging
|
||||
from ...interfaces import *
|
||||
|
@ -71,7 +74,7 @@ class BuiltinUfuncLiveValueRef(LiveValueRef):
|
|||
self._qualified_name = qualified_name
|
||||
self._ufunc = ufunc
|
||||
|
||||
def resolve_call(self, env: Environment, args: Sequence[ir.Value],
|
||||
def resolve_call(self, env: Environment, args: Sequence[_ir.Value],
|
||||
keywords: Sequence[str]) -> PartialEvalResult:
|
||||
if keywords:
|
||||
return PartialEvalResult.error_message(
|
||||
|
@ -80,14 +83,27 @@ class BuiltinUfuncLiveValueRef(LiveValueRef):
|
|||
return PartialEvalResult.error_message(
|
||||
"ufunc {} expected {} inputs but got {}".format(
|
||||
self._qualified_name, self._ufunc.nin, len(args)))
|
||||
ir_h = env.ir_h
|
||||
ic = env.ic
|
||||
|
||||
# Because a ufunc call is defined in terms of tensors and, at this stage,
|
||||
# all "public" values are ndarray, do appropriate conversions.
|
||||
tensor_args = [ir_h.numpy_copy_to_tensor_op(arg).result for arg in args]
|
||||
result_type = ir_h.numpy_unknown_tensor_type
|
||||
tensor_result = ir_h.numpy_builtin_ufunc_call_op(
|
||||
*tensor_args,
|
||||
qualified_name=self._qualified_name,
|
||||
result_type=result_type).result
|
||||
array_result = ir_h.numpy_create_array_from_tensor_op(tensor_result).result
|
||||
def copy_to_tensor(value):
|
||||
tensor_type = _cext.ndarray_to_tensor_type(value.type)
|
||||
return numpy_ops.CopyToTensorOp(tensor_type, value, loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
|
||||
tensor_args = [copy_to_tensor(arg) for arg in args]
|
||||
result_type = ic.unknown_tensor_type
|
||||
tensor_result = numpy_ops.BuiltinUfuncCallOp(result_type,
|
||||
_ir.StringAttr.get(
|
||||
self._qualified_name,
|
||||
context=ic.context),
|
||||
tensor_args,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
array_result = numpy_ops.CreateArrayFromTensorOp(
|
||||
_cext.shaped_to_ndarray_type(tensor_result.type),
|
||||
tensor_result,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
return PartialEvalResult.yields_ir_value(array_result)
|
||||
|
|
|
@ -6,7 +6,11 @@
|
|||
import numpy as np
|
||||
from typing import Union
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from mlir import ir as _ir
|
||||
from mlir.dialects import std as std_ops
|
||||
|
||||
from npcomp import _cext
|
||||
from npcomp.dialects import numpy as numpy_ops
|
||||
|
||||
from ....utils import logging
|
||||
from ...interfaces import *
|
||||
|
@ -23,15 +27,22 @@ class NdArrayValueCoder(ValueCoder):
|
|||
__slots__ = []
|
||||
|
||||
def code_py_value_as_const(self, env: Environment,
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
py_value) -> Union[_NotImplementedType, _ir.Value]:
|
||||
# TODO: Query for ndarray compat (for duck typed and such)
|
||||
# TODO: Have a higher level name resolution signal which indicates const
|
||||
ir_h = env.ir_h
|
||||
ic = env.ic
|
||||
if isinstance(py_value, np.ndarray):
|
||||
dense_attr = ir_h.context.dense_elements_attr(py_value)
|
||||
dense_attr = _ir.DenseElementsAttr.get(py_value, context=ic.context)
|
||||
tensor_type = dense_attr.type
|
||||
tensor_value = ir_h.constant_op(tensor_type, dense_attr).result
|
||||
return ir_h.numpy_create_array_from_tensor_op(tensor_value).result
|
||||
tensor_value = std_ops.ConstantOp(tensor_type,
|
||||
dense_attr,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
ndarray_type = _cext.shaped_to_ndarray_type(tensor_type)
|
||||
return numpy_ops.CreateArrayFromTensorOp(ndarray_type,
|
||||
tensor_value,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
return NotImplemented
|
||||
|
||||
|
||||
|
|
|
@ -10,16 +10,14 @@ import inspect
|
|||
import textwrap
|
||||
from typing import Optional
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from _npcomp.mlir.dialect import ScfDialectHelper
|
||||
from npcomp.dialect import Numpy
|
||||
|
||||
from mlir import ir as _ir
|
||||
from ..utils import logging
|
||||
from .importer import *
|
||||
from .interfaces import *
|
||||
from .name_resolver_base import *
|
||||
from .value_coder_base import *
|
||||
from .target import *
|
||||
from ..utils.mlir_utils import *
|
||||
|
||||
__all__ = [
|
||||
"ImportFrontend",
|
||||
|
@ -29,34 +27,27 @@ __all__ = [
|
|||
class ImportFrontend:
|
||||
"""Frontend for importing various entities into a Module."""
|
||||
__slots__ = [
|
||||
"_ir_context",
|
||||
"_ir_module",
|
||||
"_ir_h",
|
||||
"_config",
|
||||
"_ic",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
config: Configuration,
|
||||
ir_context: ir.MLIRContext = None):
|
||||
ir_context: Optional[_ir.Context] = None):
|
||||
super().__init__()
|
||||
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
||||
self._ir_module = self._ir_context.new_module()
|
||||
self._ir_h = AllDialectHelper(self._ir_context,
|
||||
ir.OpBuilder(self._ir_context))
|
||||
ic = self._ic = ImportContext(ir_context)
|
||||
self._ic.module = _ir.Module.create(loc=ic.loc)
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def ir_context(self):
|
||||
return self._ir_context
|
||||
def ir_context(self) -> _ir.Context:
|
||||
return self._ic.context
|
||||
|
||||
@property
|
||||
def ir_module(self):
|
||||
return self._ir_module
|
||||
|
||||
@property
|
||||
def ir_h(self):
|
||||
return self._ir_h
|
||||
def ir_module(self) -> _ir.Module:
|
||||
return self._ic.module
|
||||
|
||||
def import_global_function(self, f):
|
||||
"""Imports a global function.
|
||||
|
@ -70,10 +61,8 @@ class ImportFrontend:
|
|||
Args:
|
||||
f: The python callable.
|
||||
"""
|
||||
h = self.ir_h
|
||||
ir_c = self.ir_context
|
||||
ir_m = self.ir_module
|
||||
target = self._config.target_factory(h)
|
||||
ic = self._ic
|
||||
target = self._config.target_factory(ic)
|
||||
filename = inspect.getsourcefile(f)
|
||||
source_lines, start_lineno = inspect.getsourcelines(f)
|
||||
source = "".join(source_lines)
|
||||
|
@ -81,7 +70,6 @@ class ImportFrontend:
|
|||
ast_root = ast.parse(source, filename=filename)
|
||||
ast.increment_lineno(ast_root, start_lineno - 1)
|
||||
ast_fd = ast_root.body[0]
|
||||
filename_ident = ir_c.identifier(filename)
|
||||
|
||||
# Define the function.
|
||||
# TODO: Much more needs to be done here (arg/result mapping, etc)
|
||||
|
@ -98,27 +86,21 @@ class ImportFrontend:
|
|||
]
|
||||
f_return_type = self._resolve_signature_annotation(
|
||||
target, f_signature.return_annotation)
|
||||
ir_f_type = h.function_type(f_input_types, [f_return_type])
|
||||
ir_f_type = _ir.FunctionType.get(f_input_types, [f_return_type],
|
||||
context=ic.context)
|
||||
|
||||
h.builder.set_file_line_col(filename_ident, ast_fd.lineno,
|
||||
ast_fd.col_offset)
|
||||
h.builder.insert_before_terminator(ir_m.first_block)
|
||||
# TODO: Do not hardcode this IREE attribute.
|
||||
attrs = ir_c.dictionary_attr({"iree.module.export": ir_c.unit_attr})
|
||||
ir_f = h.func_op(ast_fd.name,
|
||||
ir_f_type,
|
||||
create_entry_block=True,
|
||||
attrs=attrs)
|
||||
ic.set_file_line_col(filename, ast_fd.lineno, ast_fd.col_offset)
|
||||
ic.insert_before_terminator(ic.module.body)
|
||||
ir_f, entry_block = ic.FuncOp(ast_fd.name,
|
||||
ir_f_type,
|
||||
create_entry_block=True)
|
||||
ic.insert_end_of_block(entry_block)
|
||||
env = self._create_const_global_env(f,
|
||||
parameter_bindings=zip(
|
||||
f_params.keys(),
|
||||
ir_f.first_block.args),
|
||||
entry_block.arguments),
|
||||
target=target)
|
||||
fctx = FunctionContext(ir_c=ir_c,
|
||||
ir_f=ir_f,
|
||||
ir_h=h,
|
||||
filename_ident=filename_ident,
|
||||
environment=env)
|
||||
fctx = FunctionContext(ic=ic, ir_f=ir_f, filename=filename, environment=env)
|
||||
|
||||
fdimport = FunctionDefImporter(fctx, ast_fd)
|
||||
fdimport.import_body()
|
||||
|
@ -131,7 +113,7 @@ class ImportFrontend:
|
|||
for advanced cases, including mutable global state, closures, etc.
|
||||
Globals from the module are considered immutable.
|
||||
"""
|
||||
ir_h = self._ir_h
|
||||
ic = self._ic
|
||||
try:
|
||||
code = f.__code__
|
||||
globals_dict = f.__globals__
|
||||
|
@ -149,7 +131,7 @@ class ImportFrontend:
|
|||
ConstModuleNameResolver(globals_dict, as_dict=True),
|
||||
ConstModuleNameResolver(builtins_module),
|
||||
)
|
||||
env = Environment(config=self._config, ir_h=ir_h, name_resolvers=resolvers)
|
||||
env = Environment(config=self._config, ic=ic, name_resolvers=resolvers)
|
||||
|
||||
# Bind parameters.
|
||||
for name, value in parameter_bindings:
|
||||
|
@ -158,9 +140,9 @@ class ImportFrontend:
|
|||
return env
|
||||
|
||||
def _resolve_signature_annotation(self, target: Target, annot):
|
||||
ir_h = self._ir_h
|
||||
ic = self._ic
|
||||
if annot is inspect.Signature.empty:
|
||||
return ir_h.basicpy_UnknownType
|
||||
return ic.unknown_type
|
||||
|
||||
# TODO: Do something real here once we need more than the primitive types.
|
||||
if annot is int:
|
||||
|
@ -168,23 +150,8 @@ class ImportFrontend:
|
|||
elif annot is float:
|
||||
return target.impl_float_type
|
||||
elif annot is bool:
|
||||
return ir_h.basicpy_BoolType
|
||||
return ic.bool_type
|
||||
elif annot is str:
|
||||
return ir_h.basicpy_StrType
|
||||
return ic.str_type
|
||||
else:
|
||||
return ir_h.basicpy_UnknownType
|
||||
|
||||
|
||||
################################################################################
|
||||
# Support
|
||||
################################################################################
|
||||
|
||||
|
||||
# TODO: Remove this hack in favor of a helper function that combines
|
||||
# multiple dialect helpers so that we don't need to deal with the sharp
|
||||
# edge of initializing multiple native base classes.
|
||||
class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
Numpy.DialectHelper.__init__(self, *args, **kwargs)
|
||||
ScfDialectHelper.__init__(self, *args, **kwargs)
|
||||
return ic.unknown_type
|
||||
|
|
|
@ -8,7 +8,11 @@ import ast
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from mlir import ir as _ir
|
||||
from mlir.dialects import std as std_ops
|
||||
|
||||
from npcomp import _cext
|
||||
from npcomp.dialects import basicpy as basicpy_ops
|
||||
|
||||
from ..utils import logging
|
||||
from .interfaces import *
|
||||
|
@ -23,24 +27,23 @@ __all__ = [
|
|||
class FunctionContext:
|
||||
"""Accounting information for importing a function."""
|
||||
__slots__ = [
|
||||
"ir_c",
|
||||
"ic",
|
||||
"ir_f",
|
||||
"ir_h",
|
||||
"filename_ident",
|
||||
"filename",
|
||||
"environment",
|
||||
]
|
||||
|
||||
def __init__(self, ir_c, ir_f, ir_h, filename_ident, environment):
|
||||
self.ir_c = ir_c
|
||||
def __init__(self, *, ic: ImportContext, ir_f: _ir.Operation, filename: str,
|
||||
environment: Environment):
|
||||
self.ic = ic
|
||||
self.ir_f = ir_f
|
||||
self.ir_h = ir_h
|
||||
self.filename_ident = filename_ident
|
||||
self.filename = filename
|
||||
self.environment = environment
|
||||
|
||||
def abort(self, message):
|
||||
"""Emits an error diagnostic and raises an exception to abort."""
|
||||
loc = self.current_loc
|
||||
ir.emit_error(loc, message)
|
||||
_cext.emit_error(loc, message)
|
||||
raise EmittedError(loc, message)
|
||||
|
||||
def check_partial_evaluated(self, result: PartialEvalResult):
|
||||
|
@ -53,18 +56,20 @@ class FunctionContext:
|
|||
else:
|
||||
message = ("Error while evaluating value from environment:\n" +
|
||||
"".join(traceback.format_exception(exc_type, exc_value, tb)))
|
||||
ir.emit_error(loc, message)
|
||||
|
||||
# TODO: Add this to the python API.
|
||||
_cext.emit_error(loc, message)
|
||||
raise EmittedError(loc, message)
|
||||
if result.type == PartialEvalType.NOT_EVALUATED:
|
||||
self.abort("Unable to evaluate expression")
|
||||
|
||||
@property
|
||||
def current_loc(self):
|
||||
return self.ir_h.builder.current_loc
|
||||
return self.ic.loc
|
||||
|
||||
def update_loc(self, ast_node):
|
||||
self.ir_h.builder.set_file_line_col(self.filename_ident, ast_node.lineno,
|
||||
ast_node.col_offset)
|
||||
self.ic.set_file_line_col(self.filename, ast_node.lineno,
|
||||
ast_node.col_offset)
|
||||
|
||||
def lookup_name(self, name) -> NameReference:
|
||||
"""Lookup a name in the environment, requiring it to have evaluated."""
|
||||
|
@ -74,7 +79,7 @@ class FunctionContext:
|
|||
logging.debug("Map name({}) -> {}", name, ref)
|
||||
return ref
|
||||
|
||||
def emit_const_value(self, py_value) -> ir.Value:
|
||||
def emit_const_value(self, py_value) -> _ir.Value:
|
||||
"""Codes a value as a constant, returning an ir Value."""
|
||||
env = self.environment
|
||||
result = env.code_py_value_as_const(py_value)
|
||||
|
@ -83,7 +88,7 @@ class FunctionContext:
|
|||
return result
|
||||
|
||||
def emit_partial_eval_result(self,
|
||||
partial_result: PartialEvalResult) -> ir.Value:
|
||||
partial_result: PartialEvalResult) -> _ir.Value:
|
||||
"""Emits a partial eval result either as a direct IR value or a constant."""
|
||||
self.check_partial_evaluated(partial_result)
|
||||
if partial_result.type == PartialEvalType.YIELDS_IR_VALUE:
|
||||
|
@ -134,17 +139,20 @@ class FunctionDefImporter(BaseNodeVisitor):
|
|||
self._last_was_return = False
|
||||
|
||||
def import_body(self):
|
||||
ir_h = self.fctx.ir_h
|
||||
ic = self.fctx.ic
|
||||
for ast_stmt in self.ast_fd.body:
|
||||
self._last_was_return = False
|
||||
logging.debug("STMT: {}", ast.dump(ast_stmt, include_attributes=True))
|
||||
self.visit(ast_stmt)
|
||||
if not self._last_was_return:
|
||||
# Add a default terminator.
|
||||
none_value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
none_value).result
|
||||
ir_h.return_op([none_cast])
|
||||
none_value = basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
none_cast = basicpy_ops.UnknownCastOp(ic.unknown_type,
|
||||
none_value,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
std_ops.ReturnOp([none_cast], loc=ic.loc, ip=ic.ip)
|
||||
|
||||
def visit_Assign(self, ast_node):
|
||||
expr = ExpressionImporter(self.fctx)
|
||||
|
@ -164,27 +172,27 @@ class FunctionDefImporter(BaseNodeVisitor):
|
|||
"Cannot assign to '{}': Store not supported".format(name_ref))
|
||||
|
||||
def visit_Expr(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
_, ip = ir_h.basicpy_exec_op()
|
||||
ic = self.fctx.ic
|
||||
exec_ip = ic.basicpy_ExecOp()
|
||||
|
||||
# Evaluate the expression in the exec body.
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
ir_h.builder.insertion_point = ip
|
||||
ic.push_ip(exec_ip)
|
||||
expr = ExpressionImporter(self.fctx)
|
||||
expr.visit(ast_node.value)
|
||||
ir_h.basicpy_exec_discard_op([expr.value])
|
||||
ir_h.builder.insertion_point = orig_ip
|
||||
basicpy_ops.ExecDiscardOp([expr.value], loc=ic.loc, ip=ic.ip)
|
||||
ic.pop_ip()
|
||||
|
||||
def visit_Pass(self, ast_node):
|
||||
pass
|
||||
|
||||
def visit_Return(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
expr = ExpressionImporter(self.fctx)
|
||||
expr.visit(ast_node.value)
|
||||
casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
expr.value).result
|
||||
ir_h.return_op([casted])
|
||||
self._last_was_return = True
|
||||
ic = self.fctx.ic
|
||||
with ic.loc, ic.ip:
|
||||
expr = ExpressionImporter(self.fctx)
|
||||
expr.visit(ast_node.value)
|
||||
casted = basicpy_ops.UnknownCastOp(ic.unknown_type, expr.value).result
|
||||
std_ops.ReturnOp([casted])
|
||||
self._last_was_return = True
|
||||
|
||||
|
||||
class ExpressionImporter(BaseNodeVisitor):
|
||||
|
@ -233,15 +241,20 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
ast.dump(ast_node)))
|
||||
|
||||
def visit_BinOp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
ic = self.fctx.ic
|
||||
left = self.sub_evaluate(ast_node.left)
|
||||
right = self.sub_evaluate(ast_node.right)
|
||||
self.value = ir_h.basicpy_binary_expr_op(
|
||||
ir_h.basicpy_UnknownType, left, right,
|
||||
ast_node.op.__class__.__name__).result
|
||||
self.value = basicpy_ops.BinaryExprOp(ic.unknown_type,
|
||||
left,
|
||||
right,
|
||||
_ir.StringAttr.get(
|
||||
ast_node.op.__class__.__name__,
|
||||
context=ic.context),
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
|
||||
def visit_BoolOp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
ic = self.fctx.ic
|
||||
if isinstance(ast_node.op, ast.And):
|
||||
return_first_true = False
|
||||
elif isinstance(ast_node.op, ast.Or):
|
||||
|
@ -255,25 +268,30 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
next_value = self.sub_evaluate(next_node)
|
||||
if not next_nodes:
|
||||
return next_value
|
||||
condition_value = ir_h.basicpy_as_i1_op(next_value).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||
condition_value, True)
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
condition_value = basicpy_ops.AsI1Op(ic.i1_type, next_value,
|
||||
ip=ic.ip).result
|
||||
if_op, then_ip, else_ip = ic.scf_IfOp([ic.unknown_type], condition_value,
|
||||
True)
|
||||
# Short-circuit return case.
|
||||
ir_h.builder.insertion_point = then_ip if return_first_true else else_ip
|
||||
next_value_casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
next_value).result
|
||||
ir_h.scf_yield_op([next_value_casted])
|
||||
ic.push_ip(then_ip if return_first_true else else_ip)
|
||||
next_value_casted = basicpy_ops.UnknownCastOp(ic.unknown_type,
|
||||
next_value,
|
||||
ip=ic.ip).result
|
||||
ic.scf_YieldOp([next_value_casted])
|
||||
ic.pop_ip()
|
||||
|
||||
# Nested evaluate next case.
|
||||
ir_h.builder.insertion_point = else_ip if return_first_true else then_ip
|
||||
ic.push_ip(else_ip if return_first_true else then_ip)
|
||||
nested_value = emit_next(next_nodes)
|
||||
nested_value_casted = next_value_casted = ir_h.basicpy_unknown_cast_op(
|
||||
ir_h.basicpy_UnknownType, nested_value).result
|
||||
ir_h.scf_yield_op([nested_value_casted])
|
||||
ir_h.builder.insertion_point = orig_ip
|
||||
nested_value_casted = next_value_casted = basicpy_ops.UnknownCastOp(
|
||||
ic.unknown_type, nested_value, ip=ic.ip).result
|
||||
ic.scf_YieldOp([nested_value_casted])
|
||||
ic.pop_ip()
|
||||
|
||||
return if_op.result
|
||||
|
||||
self.value = emit_next(ast_node.values)
|
||||
with ic.loc:
|
||||
self.value = emit_next(ast_node.values)
|
||||
|
||||
def visit_Call(self, ast_node):
|
||||
# Evaluate positional args.
|
||||
|
@ -311,15 +329,23 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
def visit_Compare(self, ast_node):
|
||||
# Short-circuit comparison (degenerates to binary comparison when just
|
||||
# two operands).
|
||||
ir_h = self.fctx.ir_h
|
||||
false_value = ir_h.basicpy_bool_constant_op(False).result
|
||||
ic = self.fctx.ic
|
||||
false_value = basicpy_ops.BoolConstantOp(ic.bool_type,
|
||||
ic.i1_false,
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
|
||||
def emit_next(left_value, comparisons):
|
||||
operation, right_node = comparisons[0]
|
||||
comparisons = comparisons[1:]
|
||||
right_value = self.sub_evaluate(right_node)
|
||||
compare_result = ir_h.basicpy_binary_compare_op(
|
||||
left_value, right_value, operation.__class__.__name__).result
|
||||
compare_result = basicpy_ops.BinaryCompareOp(
|
||||
ic.bool_type,
|
||||
left_value,
|
||||
right_value,
|
||||
_ir.StringAttr.get(operation.__class__.__name__),
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
# Terminate by yielding the final compare result.
|
||||
if not comparisons:
|
||||
return compare_result
|
||||
|
@ -327,47 +353,56 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
# Emit 'if' op and recurse. The if op takes an i1 (core dialect
|
||||
# requirement) and returns a basicpy.BoolType. Since this is an 'and',
|
||||
# all else clauses yield a false value.
|
||||
compare_result_i1 = ir_h.basicpy_bool_cast_op(ir_h.i1_type,
|
||||
compare_result).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_BoolType],
|
||||
compare_result_i1, True)
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
compare_result_i1 = basicpy_ops.BoolCastOp(ic.i1_type,
|
||||
compare_result,
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
if_op, then_ip, else_ip = ic.scf_IfOp([ic.bool_type], compare_result_i1,
|
||||
True)
|
||||
# Build the else clause.
|
||||
ir_h.builder.insertion_point = else_ip
|
||||
ir_h.scf_yield_op([false_value])
|
||||
ic.push_ip(else_ip)
|
||||
ic.scf_YieldOp([false_value])
|
||||
ic.pop_ip()
|
||||
|
||||
# Build the then clause.
|
||||
ir_h.builder.insertion_point = then_ip
|
||||
ic.push_ip(then_ip)
|
||||
nested_result = emit_next(right_value, comparisons)
|
||||
ir_h.scf_yield_op([nested_result])
|
||||
ir_h.builder.insertion_point = orig_ip
|
||||
ic.scf_YieldOp([nested_result])
|
||||
ic.pop_ip()
|
||||
|
||||
return if_op.result
|
||||
|
||||
self.value = emit_next(self.sub_evaluate(ast_node.left),
|
||||
list(zip(ast_node.ops, ast_node.comparators)))
|
||||
|
||||
def visit_IfExp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
test_result = ir_h.basicpy_as_i1_op(self.sub_evaluate(ast_node.test)).result
|
||||
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
|
||||
test_result, True)
|
||||
|
||||
orig_ip = ir_h.builder.insertion_point
|
||||
ic = self.fctx.ic
|
||||
test_result = basicpy_ops.AsI1Op(ic.i1_type,
|
||||
self.sub_evaluate(ast_node.test),
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
if_op, then_ip, else_ip = ic.scf_IfOp([ic.unknown_type], test_result, True)
|
||||
# Build the then clause
|
||||
ir_h.builder.insertion_point = then_ip
|
||||
ic.push_ip(then_ip)
|
||||
then_result = self.sub_evaluate(ast_node.body)
|
||||
ir_h.scf_yield_op([
|
||||
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
then_result).result
|
||||
ic.scf_YieldOp([
|
||||
basicpy_ops.UnknownCastOp(ic.unknown_type,
|
||||
then_result,
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
])
|
||||
# Build the then clause.
|
||||
ir_h.builder.insertion_point = else_ip
|
||||
orelse_result = self.sub_evaluate(ast_node.orelse)
|
||||
ir_h.scf_yield_op([
|
||||
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
|
||||
orelse_result).result
|
||||
])
|
||||
ir_h.builder.insertion_point = orig_ip
|
||||
ic.pop_ip()
|
||||
|
||||
# Build the then clause.
|
||||
ic.push_ip(else_ip)
|
||||
orelse_result = self.sub_evaluate(ast_node.orelse)
|
||||
ic.scf_YieldOp([
|
||||
basicpy_ops.UnknownCastOp(ic.unknown_type,
|
||||
orelse_result,
|
||||
ip=ic.ip,
|
||||
loc=ic.loc).result
|
||||
])
|
||||
ic.pop_ip()
|
||||
self.value = if_op.result
|
||||
|
||||
def visit_Name(self, ast_node):
|
||||
|
@ -380,18 +415,20 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
self.value = self.fctx.emit_partial_eval_result(pe_result)
|
||||
|
||||
def visit_UnaryOp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
op = ast_node.op
|
||||
operand_value = self.sub_evaluate(ast_node.operand)
|
||||
if isinstance(op, ast.Not):
|
||||
# Special handling for logical-not.
|
||||
condition_value = ir_h.basicpy_as_i1_op(operand_value).result
|
||||
true_value = ir_h.basicpy_bool_constant_op(True).result
|
||||
false_value = ir_h.basicpy_bool_constant_op(False).result
|
||||
self.value = ir_h.select_op(condition_value, false_value,
|
||||
true_value).result
|
||||
else:
|
||||
self.fctx.abort("Unknown unary op %r", (ast.dump(op)))
|
||||
ic = self.fctx.ic
|
||||
with ic.ip, ic.loc:
|
||||
op = ast_node.op
|
||||
operand_value = self.sub_evaluate(ast_node.operand)
|
||||
if isinstance(op, ast.Not):
|
||||
# Special handling for logical-not.
|
||||
condition_value = basicpy_ops.AsI1Op(ic.i1_type, operand_value).result
|
||||
true_value = basicpy_ops.BoolConstantOp(ic.bool_type, ic.i1_true).result
|
||||
false_value = basicpy_ops.BoolConstantOp(ic.bool_type,
|
||||
ic.i1_false).result
|
||||
self.value = std_ops.SelectOp(ic.bool_type, condition_value,
|
||||
false_value, true_value).result
|
||||
else:
|
||||
self.fctx.abort("Unknown unary op %r", (ast.dump(op)))
|
||||
|
||||
if sys.version_info < (3, 8, 0):
|
||||
# <3.8 breaks these out into separate AST classes.
|
||||
|
|
|
@ -6,15 +6,18 @@
|
|||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
import sys
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from mlir import ir as _ir
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from .target import *
|
||||
from ..utils.mlir_utils import *
|
||||
|
||||
__all__ = [
|
||||
"Configuration",
|
||||
"EmittedError",
|
||||
"Environment",
|
||||
"ImportContext",
|
||||
"NameReference",
|
||||
"NameResolver",
|
||||
"PartialEvalHook",
|
||||
|
@ -82,21 +85,18 @@ class NameReference:
|
|||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def load(self, env: "Environment",
|
||||
ir_h: ir.DialectHelper) -> "PartialEvalResult":
|
||||
def load(self, env: "Environment") -> "PartialEvalResult":
|
||||
"""Loads the IR Value associated with the name.
|
||||
|
||||
The load may either be direct, returning an existing value or
|
||||
side-effecting, causing a read from an external context.
|
||||
|
||||
Args:
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Returns:
|
||||
A partial evaluation result.
|
||||
"""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper):
|
||||
def store(self, env: "Environment", value: _ir.Value):
|
||||
"""Stores a new value into the name.
|
||||
|
||||
A subsequent call to 'load' should yield the same value, subject to
|
||||
|
@ -104,7 +104,6 @@ class NameReference:
|
|||
|
||||
Args:
|
||||
value: The new value to store into the name.
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Raises:
|
||||
NotImplementedError if store is not supported for this name.
|
||||
"""
|
||||
|
@ -143,7 +142,7 @@ class ValueCoder:
|
|||
__slots__ = []
|
||||
|
||||
def code_py_value_as_const(self, env: "Environment",
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
py_value) -> Union[_NotImplementedType, _ir.Value]:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
|
@ -158,7 +157,7 @@ class ValueCoderChain(ValueCoder):
|
|||
return "ValueCoderChain({})".format(self._sub_coders)
|
||||
|
||||
def code_py_value_as_const(self, env: "Environment",
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
py_value) -> Union[_NotImplementedType, _ir.Value]:
|
||||
for sc in self._sub_coders:
|
||||
result = sc.code_py_value_as_const(env, py_value)
|
||||
if result is not NotImplemented:
|
||||
|
@ -207,8 +206,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
|
|||
return PartialEvalResult(PartialEvalType.YIELDS_LIVE_VALUE, live_value)
|
||||
|
||||
@staticmethod
|
||||
def yields_ir_value(ir_value: ir.Value) -> "PartialEvalResult":
|
||||
assert isinstance(ir_value, ir.Value)
|
||||
def yields_ir_value(ir_value: _ir.Value) -> "PartialEvalResult":
|
||||
assert isinstance(ir_value, _ir.Value)
|
||||
return PartialEvalResult(PartialEvalType.YIELDS_IR_VALUE, ir_value)
|
||||
|
||||
@staticmethod
|
||||
|
@ -247,7 +246,7 @@ class LiveValueRef:
|
|||
"""Gets a named attribute from the live value."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def resolve_call(self, env: "Environment", args: Sequence[ir.Value],
|
||||
def resolve_call(self, env: "Environment", args: Sequence[_ir.Value],
|
||||
keywords: Sequence[str]) -> PartialEvalResult:
|
||||
"""Resolves a function call given 'args' and 'keywords'."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
@ -302,7 +301,6 @@ class Environment:
|
|||
"""Instantiated configuration for emitting code in a specific context.
|
||||
|
||||
This brings together:
|
||||
- The code generation context (ir_h)
|
||||
- An instantiated target
|
||||
- Delegating interfaces for other configuration objects.
|
||||
|
||||
|
@ -313,7 +311,7 @@ class Environment:
|
|||
"""
|
||||
__slots__ = [
|
||||
"config",
|
||||
"ir_h",
|
||||
"ic",
|
||||
"_name_resolvers",
|
||||
"target",
|
||||
]
|
||||
|
@ -321,12 +319,12 @@ class Environment:
|
|||
def __init__(self,
|
||||
*,
|
||||
config: Configuration,
|
||||
ir_h: ir.DialectHelper,
|
||||
ic: ImportContext,
|
||||
name_resolvers: Sequence[NameResolver] = ()):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.ir_h = ir_h
|
||||
self.target = config.target_factory(self.ir_h)
|
||||
self.ic = ic
|
||||
self.target = config.target_factory(ic)
|
||||
self._name_resolvers = (tuple(name_resolvers) +
|
||||
self.config.base_name_resolvers)
|
||||
|
||||
|
@ -341,5 +339,5 @@ class Environment:
|
|||
return self.config.partial_eval_hook.partial_evaluate(py_value)
|
||||
|
||||
def code_py_value_as_const(self,
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
py_value) -> Union[_NotImplementedType, _ir.Value]:
|
||||
return self.config.value_coder.code_py_value_as_const(self, py_value)
|
||||
|
|
|
@ -5,8 +5,7 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
from mlir import ir as _ir
|
||||
from .interfaces import *
|
||||
|
||||
__all__ = [
|
||||
|
@ -36,7 +35,7 @@ class LocalNameReference(NameReference):
|
|||
"Attempt to access local '{}' before assignment".format(self.name))
|
||||
return PartialEvalResult.yields_ir_value(self._current_value)
|
||||
|
||||
def store(self, env: Environment, value: ir.Value):
|
||||
def store(self, env: Environment, value: _ir.Value):
|
||||
self._current_value = value
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -52,9 +52,9 @@ class TemplateCallLiveValueRef(LiveValueRef):
|
|||
kw_arg_names.append(kw_name)
|
||||
linear_args.append(kw_value)
|
||||
|
||||
ir_h = env.ir_h
|
||||
result_ir_value = ir_h.basicpy_func_template_call_op(
|
||||
result_type=ir_h.basicpy_UnknownType,
|
||||
ic = env.ic
|
||||
result_ir_value = ic.basicpy_FuncTemplateCallOp(
|
||||
result_type=ic.unknown_type,
|
||||
callee_symbol=self.callee_name,
|
||||
args=linear_args,
|
||||
arg_names=kw_arg_names).result
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import *
|
||||
from _npcomp.mlir import ir
|
||||
from mlir import ir as _ir
|
||||
|
||||
from ..utils.mlir_utils import *
|
||||
|
||||
__all__ = [
|
||||
"GenericTarget32",
|
||||
|
@ -19,32 +21,23 @@ class Target:
|
|||
target.
|
||||
"""
|
||||
__slots__ = [
|
||||
"_mlir_helper",
|
||||
"ic",
|
||||
]
|
||||
|
||||
def __init__(self, mlir_helper: ir.DialectHelper):
|
||||
super().__init__()
|
||||
self._mlir_helper = mlir_helper
|
||||
|
||||
@property
|
||||
def mlir_helper(self):
|
||||
return self._mlir_helper
|
||||
|
||||
@property
|
||||
def mlir_context(self):
|
||||
return self._mlir_helper.context
|
||||
def __init__(self, ic):
|
||||
self.ic = ic
|
||||
|
||||
@property
|
||||
def target_name(self) -> str:
|
||||
return NotImplementedError()
|
||||
|
||||
@property
|
||||
def impl_int_type(self) -> ir.Type:
|
||||
def impl_int_type(self) -> _ir.Type:
|
||||
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def impl_float_type(self) -> ir.Type:
|
||||
def impl_float_type(self) -> _ir.Type:
|
||||
"""Gets the implementation's type for the python 'float' type."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -57,14 +50,14 @@ class GenericTarget64(Target):
|
|||
return "generic64"
|
||||
|
||||
@property
|
||||
def impl_int_type(self) -> ir.Type:
|
||||
def impl_int_type(self) -> _ir.Type:
|
||||
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||
return self.mlir_helper.i64_type
|
||||
return _ir.IntegerType.get_signless(64, context=self.ic.context)
|
||||
|
||||
@property
|
||||
def impl_float_type(self) -> ir.Type:
|
||||
def impl_float_type(self) -> _ir.Type:
|
||||
"""Gets the implementation's type for the python 'float' type."""
|
||||
return self.mlir_helper.f64_type
|
||||
return _ir.F64Type.get(context=self.ic.context)
|
||||
|
||||
|
||||
class GenericTarget32(Target):
|
||||
|
@ -75,15 +68,15 @@ class GenericTarget32(Target):
|
|||
return "generic32"
|
||||
|
||||
@property
|
||||
def impl_int_type(self) -> ir.Type:
|
||||
def impl_int_type(self) -> _ir.Type:
|
||||
"""Gets the default int type for the backend for the Python 'int' type."""
|
||||
return self.mlir_helper.i32_type
|
||||
return _ir.IntegerType.get_signless(32, context=self.ic.context)
|
||||
|
||||
@property
|
||||
def impl_float_type(self) -> ir.Type:
|
||||
def impl_float_type(self) -> _ir.Type:
|
||||
"""Gets the implementation's type for the python 'float' type."""
|
||||
return self.mlir_helper.f32_type
|
||||
return _ir.F32Type.get(context=self.ic.context)
|
||||
|
||||
|
||||
# Factory for producing a target (matches the Target constructor).
|
||||
TargetFactory = Callable[[ir.DialectHelper], Target]
|
||||
TargetFactory = Callable[[ImportContext], Target]
|
||||
|
|
|
@ -25,7 +25,7 @@ def create_import_dump_decorator(*,
|
|||
fe = ImportFrontend(config=config)
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
print(fe.ir_module.operation.get_asm())
|
||||
return f
|
||||
|
||||
def decorator(*args, expect_error=None):
|
||||
|
|
|
@ -5,7 +5,9 @@
|
|||
|
||||
from typing import Union
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from mlir import ir as _ir
|
||||
from mlir.dialects import std as std_ops
|
||||
from npcomp.dialects import basicpy as basicpy_ops
|
||||
|
||||
from .interfaces import *
|
||||
|
||||
|
@ -21,27 +23,29 @@ class BuiltinsValueCoder(ValueCoder):
|
|||
__slots__ = []
|
||||
|
||||
def code_py_value_as_const(self, env: Environment,
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
ir_h = env.ir_h
|
||||
ir_c = ir_h.context
|
||||
if py_value is True:
|
||||
return ir_h.basicpy_bool_constant_op(True).result
|
||||
elif py_value is False:
|
||||
return ir_h.basicpy_bool_constant_op(False).result
|
||||
elif py_value is None:
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
elif isinstance(py_value, int):
|
||||
ir_type = env.target.impl_int_type
|
||||
ir_attr = ir_c.integer_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, float):
|
||||
ir_type = env.target.impl_float_type
|
||||
ir_attr = ir_c.float_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, str):
|
||||
return ir_h.basicpy_str_constant_op(py_value).result
|
||||
elif isinstance(py_value, bytes):
|
||||
return ir_h.basicpy_bytes_constant_op(py_value).result
|
||||
elif isinstance(py_value, type(...)):
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
||||
return NotImplemented
|
||||
py_value) -> Union[_NotImplementedType, _ir.Value]:
|
||||
ic = env.ic
|
||||
with ic.loc, ic.ip:
|
||||
if py_value is True:
|
||||
return basicpy_ops.BoolConstantOp(ic.bool_type, ic.i1_true).result
|
||||
elif py_value is False:
|
||||
return basicpy_ops.BoolConstantOp(ic.bool_type, ic.i1_false).result
|
||||
elif py_value is None:
|
||||
return basicpy_ops.SingletonOp(ic.none_type).result
|
||||
elif isinstance(py_value, int):
|
||||
ir_type = env.target.impl_int_type
|
||||
ir_attr = _ir.IntegerAttr.get(ir_type, py_value)
|
||||
return std_ops.ConstantOp(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, float):
|
||||
ir_type = env.target.impl_float_type
|
||||
ir_attr = _ir.FloatAttr.get(ir_type, py_value)
|
||||
return std_ops.ConstantOp(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, str):
|
||||
return basicpy_ops.StrConstantOp(ic.str_type,
|
||||
_ir.StringAttr.get(py_value)).result
|
||||
elif isinstance(py_value, bytes):
|
||||
return basicpy_ops.BytesConstantOp(ic.bytes_type,
|
||||
_ir.StringAttr.get(py_value)).result
|
||||
elif isinstance(py_value, type(...)):
|
||||
return basicpy_ops.SingletonOp(ic.ellipsis_type).result
|
||||
return NotImplemented
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""General utilities for working with MLIR."""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from mlir import ir as _ir
|
||||
from npcomp import _cext
|
||||
|
||||
__all__ = [
|
||||
"ImportContext",
|
||||
]
|
||||
|
||||
|
||||
class ImportContext:
|
||||
"""Simple container for things that we update while importing.
|
||||
|
||||
This is also where we stash various helpers to work around awkward/missing
|
||||
MLIR Python API features.
|
||||
"""
|
||||
__slots__ = [
|
||||
"context",
|
||||
"loc",
|
||||
"module",
|
||||
"_ip_stack",
|
||||
|
||||
# Cached types.
|
||||
"unknown_type",
|
||||
"bool_type",
|
||||
"bytes_type",
|
||||
"ellipsis_type",
|
||||
"i1_type",
|
||||
"index_type",
|
||||
"none_type",
|
||||
"str_type",
|
||||
"unknown_array_type",
|
||||
"unknown_tensor_type",
|
||||
|
||||
# Cached attributes.
|
||||
"i1_true",
|
||||
"i1_false",
|
||||
]
|
||||
|
||||
def __init__(self, context: Optional[_ir.Context]):
|
||||
self.context = _ir.Context() if not context else context
|
||||
_cext.register_all_dialects(self.context)
|
||||
|
||||
self.loc = _ir.Location.unknown(context=self.context) # type: _ir.Location
|
||||
self.module = None # type: Optional[_ir.Module]
|
||||
self._ip_stack = []
|
||||
|
||||
# Cache some types and attributes.
|
||||
with self.context:
|
||||
# Types.
|
||||
# TODO: Consolidate numpy.any_dtype and basicpy.UnknownType.
|
||||
self.unknown_type = _ir.Type.parse("!basicpy.UnknownType")
|
||||
self.bool_type = _ir.Type.parse("!basicpy.BoolType")
|
||||
self.bytes_type = _ir.Type.parse("!basicpy.BytesType")
|
||||
self.ellipsis_type = _ir.Type.parse("!basicpy.EllipsisType")
|
||||
self.none_type = _ir.Type.parse("!basicpy.NoneType")
|
||||
self.str_type = _ir.Type.parse("!basicpy.StrType")
|
||||
self.i1_type = _ir.IntegerType.get_signless(1)
|
||||
self.index_type = _ir.IndexType.get()
|
||||
self.unknown_tensor_type = _ir.UnrankedTensorType.get(self.unknown_type,
|
||||
loc=self.loc)
|
||||
self.unknown_array_type = _cext.shaped_to_ndarray_type(
|
||||
self.unknown_tensor_type)
|
||||
|
||||
# Attributes.
|
||||
self.i1_true = _ir.IntegerAttr.get(self.i1_type, 1)
|
||||
self.i1_false = _ir.IntegerAttr.get(self.i1_type, 0)
|
||||
|
||||
def set_file_line_col(self, file: str, line: int, col: int):
|
||||
self.loc = _ir.Location.file(file, line, col, context=self.context)
|
||||
|
||||
def push_ip(self, new_ip: _ir.InsertionPoint):
|
||||
self._ip_stack.append(new_ip)
|
||||
|
||||
def pop_ip(self):
|
||||
assert self._ip_stack, "Mismatched push_ip/pop_ip: stack is empty on pop"
|
||||
del self._ip_stack[-1]
|
||||
|
||||
@property
|
||||
def ip(self):
|
||||
assert self._ip_stack, "InsertionPoint requested but stack is empty"
|
||||
return self._ip_stack[-1]
|
||||
|
||||
def insert_before_terminator(self, block: _ir.Block):
|
||||
self.push_ip(_ir.InsertionPoint.at_block_terminator(block))
|
||||
|
||||
def insert_end_of_block(self, block: _ir.Block):
|
||||
self.push_ip(_ir.InsertionPoint(block))
|
||||
|
||||
def FuncOp(self, name: str, func_type: _ir.Type,
|
||||
create_entry_block: bool) -> Tuple[_ir.Operation, _ir.Block]:
|
||||
"""Creates a |func| op.
|
||||
|
||||
TODO: This should really be in the MLIR API.
|
||||
Returns:
|
||||
(operation, entry_block)
|
||||
"""
|
||||
with self.context, self.loc:
|
||||
attrs = {
|
||||
"type": _ir.TypeAttr.get(func_type),
|
||||
"sym_name": _ir.StringAttr.get(name),
|
||||
}
|
||||
op = _ir.Operation.create("func", regions=1, attributes=attrs, ip=self.ip)
|
||||
body_region = op.regions[0]
|
||||
entry_block = body_region.blocks.append(*func_type.inputs)
|
||||
return op, entry_block
|
||||
|
||||
def basicpy_ExecOp(self):
|
||||
"""Creates a basicpy.exec op.
|
||||
|
||||
Returns:
|
||||
Insertion point to the body.
|
||||
"""
|
||||
op = _ir.Operation.create("basicpy.exec",
|
||||
regions=1,
|
||||
ip=self.ip,
|
||||
loc=self.loc)
|
||||
b = op.regions[0].blocks.append()
|
||||
return _ir.InsertionPoint(b)
|
||||
|
||||
def basicpy_FuncTemplateCallOp(self, result_type, callee_symbol, args,
|
||||
arg_names):
|
||||
with self.loc, self.ip:
|
||||
attributes = {
|
||||
"callee":
|
||||
_ir.FlatSymbolRefAttr.get(callee_symbol),
|
||||
"arg_names":
|
||||
_ir.ArrayAttr.get([_ir.StringAttr.get(n) for n in arg_names]),
|
||||
}
|
||||
op = _ir.Operation.create("basicpy.func_template_call",
|
||||
results=[result_type],
|
||||
operands=args,
|
||||
attributes=attributes,
|
||||
ip=self.ip)
|
||||
return op
|
||||
|
||||
def scf_IfOp(self, results, condition: _ir.Value, with_else_region: bool):
|
||||
"""Creates an SCF if op.
|
||||
|
||||
Returns:
|
||||
(if_op, then_ip, else_ip) if with_else_region, otherwise (if_op, then_ip)
|
||||
"""
|
||||
op = _ir.Operation.create("scf.if",
|
||||
results=results,
|
||||
operands=[condition],
|
||||
regions=2 if with_else_region else 1,
|
||||
loc=self.loc,
|
||||
ip=self.ip)
|
||||
then_region = op.regions[0]
|
||||
then_block = then_region.blocks.append()
|
||||
if with_else_region:
|
||||
else_region = op.regions[1]
|
||||
else_block = else_region.blocks.append()
|
||||
return op, _ir.InsertionPoint(then_block), _ir.InsertionPoint(else_block)
|
||||
else:
|
||||
return op, _ir.InsertionPoint(then_block)
|
||||
|
||||
def scf_YieldOp(self, operands):
|
||||
return _ir.Operation.create("scf.yield",
|
||||
operands=operands,
|
||||
loc=self.loc,
|
||||
ip=self.ip)
|
|
@ -1,112 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from _npcomp.dialect import BasicpyDialectHelper as _BaseDialectHelper
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
__all__ = [
|
||||
"DialectHelper",
|
||||
]
|
||||
|
||||
|
||||
class DialectHelper(_BaseDialectHelper):
|
||||
r"""Dialect helper for the Basicpy dialect.
|
||||
|
||||
>>> c = ir.MLIRContext()
|
||||
>>> h = DialectHelper(c, ir.OpBuilder(c))
|
||||
|
||||
Dialect Types:
|
||||
>>> h.basicpy_NoneType
|
||||
!basicpy.NoneType
|
||||
>>> h.basicpy_EllipsisType
|
||||
!basicpy.EllipsisType
|
||||
>>> h.basicpy_SlotObject_type(
|
||||
... "foobar", h.basicpy_NoneType, h.basicpy_NoneType)
|
||||
!basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||
|
||||
singleton op:
|
||||
>>> m = c.new_module()
|
||||
>>> h.builder.insert_block_start(m.first_block)
|
||||
>>> _ = h.basicpy_singleton_op(h.basicpy_NoneType)
|
||||
>>> m.to_asm().strip()
|
||||
'module {\n %0 = basicpy.singleton : !basicpy.NoneType\n}'
|
||||
|
||||
slot_object ops:
|
||||
>>> m = c.new_module()
|
||||
>>> h.builder.insert_block_start(m.first_block)
|
||||
>>> v0 = h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||
>>> slot_object = h.basicpy_slot_object_make_op("foobar", v0, v0).result
|
||||
>>> _ = h.basicpy_slot_object_get_op(slot_object, 0)
|
||||
>>> print(m.to_asm().strip())
|
||||
module {
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
%1 = basicpy.slot_object_make(%0, %0) -> !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||
%2 = basicpy.slot_object_get %1[0] : !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
def basicpy_binary_expr_op(self, result_type, lhs, rhs, operation_name):
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({"operation": c.string_attr(operation_name)})
|
||||
return self.op("basicpy.binary_expr", [result_type], [lhs, rhs], attrs)
|
||||
|
||||
def basicpy_bool_cast_op(self, result_type, value):
|
||||
return self.op("basicpy.bool_cast", [result_type], [value])
|
||||
|
||||
def basicpy_bool_constant_op(self, value):
|
||||
c = self.context
|
||||
ival = 1 if value else 0
|
||||
attrs = c.dictionary_attr({"value": c.integer_attr(self.i1_type, ival)})
|
||||
return self.op("basicpy.bool_constant", [self.basicpy_BoolType], [], attrs)
|
||||
|
||||
def basicpy_bytes_constant_op(self, value):
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({"value": c.string_attr(value)})
|
||||
return self.op("basicpy.bytes_constant", [self.basicpy_BytesType], [],
|
||||
attrs)
|
||||
|
||||
def basicpy_binary_compare_op(self, lhs, rhs, operation_name):
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({"operation": c.string_attr(operation_name)})
|
||||
return self.op("basicpy.binary_compare", [self.basicpy_BoolType],
|
||||
[lhs, rhs], attrs)
|
||||
|
||||
def basicpy_singleton_op(self, singleton_type):
|
||||
return self.op("basicpy.singleton", [singleton_type], [])
|
||||
|
||||
def basicpy_slot_object_make_op(self, class_name, *slot_values):
|
||||
c = self.context
|
||||
class_name_attr = c.string_attr(class_name)
|
||||
object_type = self.basicpy_SlotObject_type(class_name,
|
||||
*[v.type for v in slot_values])
|
||||
attrs = c.dictionary_attr({"className": class_name_attr})
|
||||
return self.op("basicpy.slot_object_make", [object_type], slot_values,
|
||||
attrs)
|
||||
|
||||
def basicpy_str_constant_op(self, value):
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
|
||||
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
|
||||
|
||||
def basicpy_as_i1_op(self, value):
|
||||
return self.op("basicpy.as_i1", [self.i1_type], [value])
|
||||
|
||||
def basicpy_unknown_cast_op(self, result_type, operand):
|
||||
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
||||
|
||||
def basicpy_func_template_call_op(self, result_type, callee_symbol, args,
|
||||
arg_names):
|
||||
"""Creates a basicpy.func_template_call op."""
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({
|
||||
"callee": c.flat_symbol_ref_attr(callee_symbol),
|
||||
"arg_names": c.array_attr([c.string_attr(n) for n in arg_names]),
|
||||
})
|
||||
return self.op("basicpy.func_template_call", [result_type], args, attrs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -1,80 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
from npcomp.dialect import Basicpy
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
__all__ = [
|
||||
"load_builtin_module",
|
||||
"DialectHelper",
|
||||
]
|
||||
|
||||
|
||||
class DialectHelper(Basicpy.DialectHelper):
|
||||
r"""Dialect helper.
|
||||
|
||||
>>> c = ir.MLIRContext()
|
||||
>>> h = DialectHelper(c, ir.OpBuilder(c))
|
||||
|
||||
DenseElementsAttrs:
|
||||
>>> c.dense_elements_attr(np.asarray([1, 2, 3, 4], dtype=np.int32))
|
||||
dense<[1, 2, 3, 4]> : tensor<4xsi32>
|
||||
>>> c.dense_elements_attr(np.asarray([[1, 2], [3, 4]], dtype=np.int32))
|
||||
dense<[[1, 2], [3, 4]]> : tensor<2x2xsi32>
|
||||
>>> c.dense_elements_attr(np.asarray([[1., 2.], [3., 4.]]))
|
||||
dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf64>
|
||||
>>> c.dense_elements_attr(np.asarray([[1., 2.], [3., 4.]], dtype=np.float32))
|
||||
dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>
|
||||
|
||||
Types:
|
||||
>>> c = ir.MLIRContext()
|
||||
>>> t = DialectHelper(c, ir.OpBuilder(c))
|
||||
>>> t.numpy_any_dtype
|
||||
!basicpy.UnknownType
|
||||
>>> t.tensor_type(t.numpy_any_dtype, [1, 2, 3])
|
||||
tensor<1x2x3x!basicpy.UnknownType>
|
||||
>>> t.tensor_type(t.numpy_any_dtype)
|
||||
tensor<*x!basicpy.UnknownType>
|
||||
>>> t.tensor_type(t.numpy_any_dtype, [-1, 2])
|
||||
tensor<?x2x!basicpy.UnknownType>
|
||||
>>> t.tensor_type(t.f32_type)
|
||||
tensor<*xf32>
|
||||
>>> t.function_type([t.i32_type], [t.f32_type])
|
||||
(i32) -> f32
|
||||
>>> t.numpy_unknown_tensor_type
|
||||
tensor<*x!basicpy.UnknownType>
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def numpy_any_dtype(self):
|
||||
return self.basicpy_UnknownType
|
||||
|
||||
@property
|
||||
def numpy_unknown_tensor_type(self):
|
||||
return self.tensor_type(self.basicpy_UnknownType)
|
||||
|
||||
@property
|
||||
def unknown_array_type(self):
|
||||
return self.numpy_NdArrayType(self.basicpy_UnknownType)
|
||||
|
||||
def numpy_builtin_ufunc_call_op(self, *args, qualified_name, result_type):
|
||||
"""Creates a numpy.builtin_ufunc_call op."""
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({"qualified_name": c.string_attr(qualified_name)})
|
||||
return self.op("numpy.builtin_ufunc_call", [result_type], args, attrs)
|
||||
|
||||
def numpy_narrow_op(self, result_type, operand):
|
||||
"""Creates a numpy.narrow op."""
|
||||
return self.op("numpy.narrow", [result_type], [operand])
|
||||
|
||||
def numpy_get_slice_op(self, result_type, array, *slice_elements):
|
||||
return self.op("numpy.get_slice", [result_type],
|
||||
[array] + list(slice_elements))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -0,0 +1,15 @@
|
|||
//===-- ATenBind.td - Aten dialect bind --------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_DIALECTS_ATEN_BIND
|
||||
#define NPCOMP_PYTHON_DIALECTS_ATEN_BIND
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "npcomp/Dialect/ATen/IR/ATenOps.td"
|
||||
|
||||
#endif
|
|
@ -0,0 +1,15 @@
|
|||
//===-- BasicpyBind.td - Basicpy dialect bind --------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_DIALECTS_BASICPY_BIND
|
||||
#define NPCOMP_PYTHON_DIALECTS_BASICPY_BIND
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "npcomp/Dialect/Basicpy/IR/BasicpyOps.td"
|
||||
|
||||
#endif
|
|
@ -0,0 +1,12 @@
|
|||
function(_add_dialect target td_file bind_name)
|
||||
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||
mlir_tablegen("${bind_name}.py" -gen-python-op-bindings -bind-dialect=${bind_name})
|
||||
add_public_tablegen_target(${target})
|
||||
add_dependencies(NPCOMPNativePyExt ${target})
|
||||
endfunction()
|
||||
|
||||
_add_dialect(NPCOMPPyDialectATen ATenBind.td "aten")
|
||||
_add_dialect(NPCOMPPyDialectBasicpy BasicpyBind.td "basicpy")
|
||||
_add_dialect(NPCOMPPyDialectNumpy NumpyBind.td "numpy")
|
||||
_add_dialect(NPCOMPPyDialectTCF TCFBind.td "tcf")
|
||||
_add_dialect(NPCOMPPyDialectTorch TorchBind.td "torch")
|
|
@ -0,0 +1,15 @@
|
|||
//===-- NumpyOps.td - Numpy dialect bind -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_DIALECTS_NUMPY_BIND
|
||||
#define NPCOMP_PYTHON_DIALECTS_NUMPY_BIND
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "npcomp/Dialect/Numpy/IR/NumpyOps.td"
|
||||
|
||||
#endif
|
|
@ -0,0 +1,15 @@
|
|||
//===-- TCFBind.td - TCF dialect bind ----------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_DIALECTS_TCF_BIND
|
||||
#define NPCOMP_PYTHON_DIALECTS_TCF_BIND
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "npcomp/Dialect/TCF/IR/TCFOps.td"
|
||||
|
||||
#endif
|
|
@ -0,0 +1,15 @@
|
|||
//===-- TorchBind.td - Torch dialect bind ------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_DIALECTS_TORCH_BIND
|
||||
#define NPCOMP_PYTHON_DIALECTS_TORCH_BIND
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "npcomp/Dialect/Torch/IR/TorchOps.td"
|
||||
|
||||
#endif
|
|
@ -0,0 +1,7 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Generated tablegen dialects expect to be able to find some symbols from
|
||||
# the mlir.dialects package.
|
||||
from mlir.dialects import _cext, _segmented_accessor, _equally_sized_accessor, _get_default_loc_context
|
|
@ -7,6 +7,10 @@ import numpy as np
|
|||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
|
||||
from mlir import ir as _ir
|
||||
|
||||
from npcomp.dialects import numpy as numpy_ops
|
||||
|
||||
|
||||
class Protocol(Enum):
|
||||
UFUNC = 1
|
||||
|
@ -39,8 +43,7 @@ TraceInvocation.__new__.__defaults__ = (Protocol.ARRAY_FUNC, "__call__")
|
|||
|
||||
|
||||
class EmissionRequest(
|
||||
namedtuple("EmissionRequest",
|
||||
["input_ssa_values", "dialect_helper", "extra"])):
|
||||
namedtuple("EmissionRequest", ["input_ssa_values", "ic", "extra"])):
|
||||
"""Represents the result of processing inputs from an invocation.
|
||||
|
||||
The `input_ssa_values` are mlir.ir.Value instances corresponding to
|
||||
|
@ -173,11 +176,14 @@ class GenericCallUfuncEmitter(FuncEmitter):
|
|||
return py_results[0]
|
||||
|
||||
def emit(self, request: EmissionRequest):
|
||||
h = request.dialect_helper
|
||||
op_result_type = h.tensor_type(h.numpy_any_dtype)
|
||||
call_op = h.numpy_builtin_ufunc_call_op(*request.input_ssa_values,
|
||||
qualified_name=self._ufunc_name,
|
||||
result_type=op_result_type)
|
||||
ic = request.ic
|
||||
name_attr = _ir.StringAttr.get(self._ufunc_name)
|
||||
result_type = ic.unknown_tensor_type
|
||||
call_op = numpy_ops.BuiltinUfuncCallOp(result_type,
|
||||
qualified_name=name_attr,
|
||||
inputs=request.input_ssa_values,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip)
|
||||
return call_op.results
|
||||
|
||||
|
||||
|
@ -219,9 +225,13 @@ class GenericArrayFuncEmitter(FuncEmitter):
|
|||
return tuple(py_results)
|
||||
|
||||
def emit(self, request: EmissionRequest):
|
||||
h = request.dialect_helper
|
||||
op_result_types = [h.tensor_type(h.numpy_any_dtype)] * self._nresults
|
||||
op = h.op(self._op_name, op_result_types, request.input_ssa_values)
|
||||
ic = request.ic
|
||||
op_result_types = [ic.unknown_tensor_type] * self._nresults
|
||||
op = _ir.Operation.create(self._op_name,
|
||||
results=op_result_types,
|
||||
operands=request.input_ssa_values,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip)
|
||||
return op.results
|
||||
|
||||
|
||||
|
|
|
@ -3,28 +3,44 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import re
|
||||
from typing import Iterable
|
||||
from typing import Iterable, Optional
|
||||
import numpy as np
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from mlir import ir as _ir
|
||||
from mlir.dialects import std as std_ops
|
||||
|
||||
from npcomp.dialect import Numpy
|
||||
from npcomp.exporter import *
|
||||
from npcomp.types import *
|
||||
from npcomp.tracing.context import *
|
||||
from npcomp.tracing.emitters import *
|
||||
from npcomp import _cext
|
||||
from npcomp.dialects import basicpy as basicpy_ops
|
||||
from npcomp.dialects import numpy as numpy_ops
|
||||
|
||||
from ..exporter import *
|
||||
from ..types import *
|
||||
from ..compiler.utils.mlir_utils import *
|
||||
|
||||
from .context import *
|
||||
from .emitters import *
|
||||
|
||||
|
||||
class ModuleBuilder:
|
||||
"""Builds an MLIR module by tracing functions."""
|
||||
|
||||
def __init__(self, mlir_context=None, emitter_registry=None):
|
||||
self.context = mlir_context if mlir_context else ir.MLIRContext()
|
||||
self.module = self.context.new_module()
|
||||
self.helper = Numpy.DialectHelper(self.context, ir.OpBuilder(self.context))
|
||||
__slots__ = [
|
||||
"emitters",
|
||||
"ic",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
mlir_context: Optional[_ir.Context] = None,
|
||||
emitter_registry=None):
|
||||
ic = self.ic = ImportContext(mlir_context)
|
||||
ic.module = _ir.Module.create(loc=ic.loc)
|
||||
self.emitters = (emitter_registry
|
||||
if emitter_registry else EmitterRegistry.create_default())
|
||||
|
||||
@property
|
||||
def module(self):
|
||||
return self.ic.module
|
||||
|
||||
def trace(self, *export_py_funcs: ExportPyFunction):
|
||||
"""Traces exported py functions."""
|
||||
for export_py_func in export_py_funcs:
|
||||
|
@ -43,9 +59,7 @@ class FunctionTracer(TraceContext):
|
|||
"_args_array_params",
|
||||
"_f",
|
||||
"_f_types",
|
||||
"_helper",
|
||||
"_mlir_m",
|
||||
"_mlir_c",
|
||||
"_ic",
|
||||
"_python_args",
|
||||
"_result_array_params",
|
||||
"_traced_arrays",
|
||||
|
@ -61,35 +75,39 @@ class FunctionTracer(TraceContext):
|
|||
self._validate()
|
||||
|
||||
# Alias some parent members for convenience.
|
||||
self._mlir_m = module_builder.module
|
||||
self._mlir_c = module_builder.context
|
||||
self._helper = module_builder.helper
|
||||
self._ic = module_builder.ic
|
||||
with self._ic.context:
|
||||
# Extract ArrayParams for all args and results.
|
||||
self._args_array_params = [
|
||||
ArrayParams.from_constraints(arg.constraints)
|
||||
for arg in self.epf.sig.args
|
||||
]
|
||||
self._python_args = [None] * len(self._args_array_params)
|
||||
self._result_array_params = ArrayParams.from_constraints(
|
||||
self.epf.sig.result.constraints)
|
||||
|
||||
# Extract ArrayParams for all args and results.
|
||||
self._args_array_params = [
|
||||
ArrayParams.from_constraints(arg.constraints)
|
||||
for arg in self.epf.sig.args
|
||||
]
|
||||
self._python_args = [None] * len(self._args_array_params)
|
||||
self._result_array_params = ArrayParams.from_constraints(
|
||||
self.epf.sig.result.constraints)
|
||||
# Create the MLIR function.
|
||||
self._f, self._f_types = self._create_mlir_function()
|
||||
self._create_trace_roots()
|
||||
|
||||
# Create the MLIR function.
|
||||
self._f, self._f_types = self._create_mlir_function()
|
||||
self._create_trace_roots()
|
||||
@property
|
||||
def entry_block(self) -> _ir.Block:
|
||||
return self._f.regions[0].blocks[0]
|
||||
|
||||
def trace(self):
|
||||
# Invoke the python function with placeholders.
|
||||
# TODO: More sophisticated signature merging
|
||||
# TODO: Multiple results
|
||||
# TODO: Error reporting
|
||||
h = self._helper
|
||||
py_results = (self.epf.pyfunc(*self._python_args),)
|
||||
if len(py_results) != len(self._f_types):
|
||||
raise TracingError("Traced function returned != %d results: %r" % (
|
||||
len(self._f_types),
|
||||
py_results,
|
||||
))
|
||||
ic = self._ic
|
||||
ic.insert_end_of_block(self.entry_block)
|
||||
with ic.context:
|
||||
py_results = (self.epf.pyfunc(*self._python_args),)
|
||||
if len(py_results) != len(self._f_types):
|
||||
raise TracingError("Traced function returned != %d results: %r" % (
|
||||
len(self._f_types),
|
||||
py_results,
|
||||
))
|
||||
|
||||
# Narrow all results to the declared return types.
|
||||
return_operands = []
|
||||
|
@ -97,8 +115,12 @@ class FunctionTracer(TraceContext):
|
|||
mlir_result = self.get_traced_array_value(py_result)
|
||||
# narrow to declared result type.
|
||||
return_operands.extend(
|
||||
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
|
||||
h.return_op(return_operands)
|
||||
numpy_ops.NarrowOp(mlir_result_type,
|
||||
mlir_result,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).results)
|
||||
std_ops.ReturnOp(return_operands, loc=ic.loc, ip=ic.ip)
|
||||
ic.pop_ip()
|
||||
|
||||
def set_traced_array(self, traced_array, value):
|
||||
"""Sets the current SSA value for a traced_array."""
|
||||
|
@ -117,15 +139,18 @@ class FunctionTracer(TraceContext):
|
|||
return traced_value
|
||||
|
||||
def _get_external_array_value(self, external_array):
|
||||
h = self._helper
|
||||
ic = self._ic
|
||||
if not isinstance(external_array, np.ndarray):
|
||||
raise TracingError("Expected ndarray but got: %r" % (external_array,))
|
||||
found_it = self._external_arrays.get(id(external_array))
|
||||
if found_it:
|
||||
return found_it[1]
|
||||
# Import it.
|
||||
dense_attr = h.context.dense_elements_attr(external_array)
|
||||
const_value = h.constant_op(dense_attr.type, dense_attr).result
|
||||
dense_attr = _ir.DenseElementsAttr.get(external_array, context=ic.context)
|
||||
const_value = std_ops.ConstantOp(dense_attr.type,
|
||||
dense_attr,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
self._external_arrays[id(external_array)] = (external_array, const_value)
|
||||
return const_value
|
||||
|
||||
|
@ -138,28 +163,24 @@ class FunctionTracer(TraceContext):
|
|||
(self.epf.sig.result,))
|
||||
|
||||
def _create_mlir_function(self):
|
||||
mlir_c = self._mlir_c
|
||||
mlir_m = self._mlir_m
|
||||
h = self._helper
|
||||
ic = self._ic
|
||||
epf = self.epf
|
||||
f_args = [
|
||||
mlir_c.parse_type(ap.mlir_tensor_type_asm)
|
||||
_ir.Type.parse(ap.mlir_tensor_type_asm)
|
||||
for ap in self._args_array_params
|
||||
]
|
||||
f_types = [
|
||||
mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
|
||||
]
|
||||
h.builder.insert_before_terminator(mlir_m.first_block)
|
||||
f_type = h.function_type(f_args, f_types)
|
||||
f = h.func_op(epf.__name__, f_type, create_entry_block=True)
|
||||
f_types = [_ir.Type.parse(self._result_array_params.mlir_tensor_type_asm)]
|
||||
ic.insert_before_terminator(ic.module.body)
|
||||
f_type = _ir.FunctionType.get(f_args, f_types)
|
||||
f, _ = ic.FuncOp(epf.__name__, f_type, create_entry_block=True)
|
||||
return f, f_types
|
||||
|
||||
def _create_trace_roots(self):
|
||||
entry_block = self._f.first_block
|
||||
entry_block = self.entry_block
|
||||
for index, ap in enumerate(self._args_array_params):
|
||||
if ap is not None:
|
||||
ta = TracedArray(self)
|
||||
self.set_traced_array(ta, entry_block.args[index])
|
||||
self.set_traced_array(ta, entry_block.arguments[index])
|
||||
self._python_args[index] = ta
|
||||
|
||||
def _resolve_input_ssa_values(self, trace_values: Iterable[TraceValue]):
|
||||
|
@ -190,9 +211,7 @@ class FunctionTracer(TraceContext):
|
|||
def _emit_invocation(self, emitter: FuncEmitter, invocation: TraceInvocation):
|
||||
tv_map = emitter.map_invocation(invocation)
|
||||
input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
|
||||
request = EmissionRequest(input_ssa_values,
|
||||
dialect_helper=self._helper,
|
||||
extra=tv_map.extra)
|
||||
request = EmissionRequest(input_ssa_values, ic=self._ic, extra=tv_map.extra)
|
||||
result_ssa_values = emitter.emit(request)
|
||||
py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,
|
||||
result_ssa_values)
|
||||
|
@ -213,14 +232,18 @@ class FunctionTracer(TraceContext):
|
|||
return self._emit_invocation(emitter, invocation)
|
||||
|
||||
def _emit_slice_value(self, slice_element):
|
||||
h = self._helper
|
||||
ic = self._ic
|
||||
if slice_element == None:
|
||||
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||
return basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc, ip=ic.ip).result
|
||||
elif slice_element == Ellipsis:
|
||||
return h.basicpy_singleton_op(h.basicpy_EllipsisType).result
|
||||
return basicpy_ops.SingletonOp(ic.ellipsis_type, loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
elif isinstance(slice_element, int):
|
||||
return h.constant_op(h.index_type,
|
||||
h.context.index_attr(slice_element)).result
|
||||
return std_ops.ConstantOp(ic.index_type,
|
||||
_ir.IntegerAttr.get(ic.index_type,
|
||||
slice_element),
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
elif isinstance(slice_element, slice):
|
||||
return self._emit_slice_object(slice_element)
|
||||
else:
|
||||
|
@ -229,29 +252,40 @@ class FunctionTracer(TraceContext):
|
|||
"TODO: Slicing with generic arrays not yet implemented")
|
||||
|
||||
def _emit_slice_object(self, slice_object: slice):
|
||||
h = self._helper
|
||||
ic = self._ic
|
||||
|
||||
def emit_index(index):
|
||||
if index is None:
|
||||
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||
return basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
else:
|
||||
return h.constant_op(h.index_type,
|
||||
h.context.index_attr(int(index))).result
|
||||
return std_ops.ConstantOp(ic.index_type,
|
||||
_ir.IntegerAttr.get(ic.index_type,
|
||||
int(index)),
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
|
||||
start = emit_index(slice_object.start)
|
||||
stop = emit_index(slice_object.stop)
|
||||
step = emit_index(slice_object.step)
|
||||
return h.basicpy_slot_object_make_op("slice", start, stop, step).result
|
||||
result_type = _cext.slot_object_type(ic.context, "slice",
|
||||
[start.type, stop.type, step.type])
|
||||
return basicpy_ops.SlotObjectMakeOp(result_type, [start, stop, step],
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
|
||||
def _handle_array_getitem(self, array, key):
|
||||
h = self._helper
|
||||
ic = self._ic
|
||||
array_value = self.get_traced_array_value(array)
|
||||
# Array slicing is always based on a tuple.
|
||||
slice_tuple = key if isinstance(key, tuple) else (key,)
|
||||
# Resolve and emit each slice element.
|
||||
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
||||
result_value = h.numpy_get_slice_op(h.unknown_array_type, array_value,
|
||||
*slice_values).result
|
||||
result_value = numpy_ops.GetSliceOp(ic.unknown_array_type,
|
||||
array_value,
|
||||
slice_values,
|
||||
loc=ic.loc,
|
||||
ip=ic.ip).result
|
||||
result_array = TracedArray(self)
|
||||
self.set_traced_array(result_array, result_value)
|
||||
return result_array
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
|
||||
bias = np.random.uniform(size=(4,)).astype(np.float32)
|
||||
|
||||
|
||||
def constants(a: np.ndarray) -> np.ndarray:
|
||||
return np.dot(a, weights) + bias
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.constants = constants
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.constants)
|
||||
print(mb.module.to_asm())
|
|
@ -1,35 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
return a * b + a + b
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.simple_mul = simple_mul
|
||||
exp.simple_mul.sig.args["a"] += Shape(1, 4)
|
||||
exp.simple_mul.sig.args["a"] += DynamicDim(0)
|
||||
exp.simple_mul.sig.args["a"] += DType(np.float32)
|
||||
exp.simple_mul.sig.args["b"] += Shape(1)
|
||||
exp.simple_mul.sig.args["b"] += DType(np.float32)
|
||||
exp.simple_mul.sig.result += Shape(1, 4)
|
||||
exp.simple_mul.sig.result += DynamicDim(0)
|
||||
exp.simple_mul.sig.result += DType(np.float32)
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.simple_mul)
|
||||
# CHECK: func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
|
||||
# CHECK: %0 = numpy.ufunc_call @numpy.multiply(%arg0, %arg1) : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: %1 = numpy.ufunc_call @numpy.add(%0, %arg0) : (tensor<*x!numpy.any_dtype>, tensor<?x4xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: %2 = numpy.ufunc_call @numpy.add(%1, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: %3 = numpy.narrow %2 : (tensor<*x!numpy.any_dtype>) -> tensor<?x4xf32>
|
||||
# CHECK: return %3 : tensor<?x4xf32>
|
||||
# CHECK: }
|
||||
print(mb.module.to_asm())
|
|
@ -1,20 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def slice_array1(a: np.ndarray) -> np.ndarray:
|
||||
return a[1, 2:10:2, 3:4, ..., :, 0]
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.slice_array1 = slice_array1
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.slice_array1)
|
||||
print(mb.module.to_asm())
|
|
@ -1,25 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def transpose_attribute(a: np.ndarray) -> np.ndarray:
|
||||
return a.T
|
||||
|
||||
|
||||
def transpose(a: np.ndarray) -> np.ndarray:
|
||||
return np.transpose(a)
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.transpose_attribute = transpose_attribute
|
||||
exp.transpose = transpose
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.transpose_attribute, exp.transpose)
|
||||
print(mb.module.to_asm())
|
|
@ -8,9 +8,6 @@ from npcomp.compiler.numpy import test_config
|
|||
from npcomp.compiler.numpy.target import *
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
# TODO: This should all exist in a high level API somewhere.
|
||||
from _npcomp import mlir
|
||||
|
||||
logging.enable()
|
||||
|
||||
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Test for the MLIR IR Python bindings.
|
||||
|
||||
TODO: These tests were just for bootstrapping and are not authoritative at this
|
||||
point.
|
||||
"""
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
c = ir.MLIRContext()
|
||||
|
||||
# CHECK-LABEL: module @parseSuccess
|
||||
m = c.parse_asm(r"""
|
||||
module @parseSuccess {
|
||||
func @f() {
|
||||
return
|
||||
}
|
||||
}
|
||||
""")
|
||||
# CHECK: func @f
|
||||
print(m.to_asm())
|
||||
# CHECK: OP NAME: module
|
||||
print("OP NAME:", m.name)
|
||||
# CHECK: NUM_REGIONS: 1
|
||||
print("NUM_REGIONS:", m.num_regions)
|
||||
region = m.region(0)
|
||||
# CHECK: CONTAINED OP: func
|
||||
# CHECK: CONTAINED OP: module_terminator
|
||||
for block in region.blocks:
|
||||
for op in block.operations:
|
||||
print("CONTAINED OP:", op.name)
|
||||
|
||||
# CHECK-LABEL: PARSE_FAILURE
|
||||
print("PARSE_FAILURE")
|
||||
try:
|
||||
m = c.parse_asm("{{ILLEGAL SYNTAX}}")
|
||||
except ValueError as e:
|
||||
# CHECK: [ERROR]: expected operation name in quotes
|
||||
print(e)
|
|
@ -1,40 +0,0 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Test for the MLIR Pass Python bindings"""
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from _npcomp.mlir import passes
|
||||
|
||||
c = ir.MLIRContext()
|
||||
|
||||
pm = passes.PassManager(c)
|
||||
|
||||
# CHECK-LABEL: module @parseSuccess
|
||||
m = c.parse_asm(r"""
|
||||
module @parseSuccess {
|
||||
func @notUsed() attributes { sym_visibility = "private" }
|
||||
func @f() {
|
||||
return
|
||||
}
|
||||
}
|
||||
""")
|
||||
# CHECK: func private @notUsed
|
||||
# CHECK: func @f
|
||||
print(m.to_asm())
|
||||
|
||||
# CHECK: PASS COUNT: 0
|
||||
print("PASS COUNT:", len(pm))
|
||||
|
||||
pm.addPassPipelines("canonicalize", "symbol-dce")
|
||||
# Note: not checking the actual count since these may expand to more than
|
||||
# two passes.
|
||||
# CHECK: PASS COUNT:
|
||||
print("PASS COUNT:", len(pm))
|
||||
# CHECK: PASSES: canonicalize, symbol-dce
|
||||
print("PASSES:", str(pm))
|
||||
pm.run(m)
|
||||
print(m.to_asm())
|
||||
# CHECK-NOT: func @notUsed
|
|
@ -0,0 +1,34 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
|
||||
bias = np.random.uniform(size=(4,)).astype(np.float32)
|
||||
|
||||
|
||||
def constants(a: np.ndarray) -> np.ndarray:
|
||||
return np.dot(a, weights) + bias
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.constants = constants
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.constants)
|
||||
# CHECK-LABEL: func @constants(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
||||
# CHECK: %[[VAL_1:.*]] = constant dense<{{.*}}> : tensor<16x4xf32>
|
||||
# CHECK: %[[VAL_2:.*]] = numpy.dot %[[VAL_0]], %[[VAL_1]] : (tensor<*x!numpy.any_dtype>, tensor<16x4xf32>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: %[[VAL_3:.*]] = constant dense<{{.*}}> : tensor<4xf32>
|
||||
# CHECK: %[[VAL_4:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[VAL_2]], %[[VAL_3]]) : (tensor<*x!basicpy.UnknownType>, tensor<4xf32>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: %[[VAL_5:.*]] = numpy.narrow %[[VAL_4]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: return %[[VAL_5]] : tensor<*x!numpy.any_dtype>
|
||||
# CHECK: }
|
||||
print(mb.module)
|
|
@ -1,3 +1,5 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
@ -25,4 +27,12 @@ exp.dot2d.sig.result += DType(np.float32)
|
|||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.dot2d)
|
||||
print(mb.module.to_asm())
|
||||
|
||||
# CHECK-LABEL: func @dot2d(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<?x16xf32>,
|
||||
# CHECK-SAME: %[[VAL_1:.*]]: tensor<16x32xf32>) -> tensor<?x32xf32> {
|
||||
# CHECK: %[[VAL_2:.*]] = numpy.dot %[[VAL_0]], %[[VAL_1]] : (tensor<?x16xf32>, tensor<16x32xf32>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: %[[VAL_3:.*]] = numpy.narrow %[[VAL_2]] : (tensor<*x!basicpy.UnknownType>) -> tensor<?x32xf32>
|
||||
# CHECK: return %[[VAL_3]] : tensor<?x32xf32>
|
||||
# CHECK: }
|
||||
print(mb.module)
|
|
@ -30,6 +30,7 @@ exp.simple_mul.sig.result += DType(np.float32)
|
|||
|
||||
mb = ModuleBuilder()
|
||||
mb.trace(exp.simple_mul)
|
||||
# This test exercises the full tracing path and incidentally checks the ops.
|
||||
# CHECK: func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
|
||||
# CHECK: %0 = numpy.builtin_ufunc_call<"numpy.multiply"> (%arg0, %arg1) : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: %1 = numpy.builtin_ufunc_call<"numpy.add"> (%0, %arg0) : (tensor<*x!basicpy.UnknownType>, tensor<?x4xf32>) -> tensor<*x!basicpy.UnknownType>
|
||||
|
@ -37,4 +38,4 @@ mb.trace(exp.simple_mul)
|
|||
# CHECK: %3 = numpy.narrow %2 : (tensor<*x!basicpy.UnknownType>) -> tensor<?x4xf32>
|
||||
# CHECK: return %3 : tensor<?x4xf32>
|
||||
# CHECK: }
|
||||
print(mb.module.to_asm())
|
||||
print(str(mb.module))
|
|
@ -0,0 +1,48 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def slice_array1(a: np.ndarray) -> np.ndarray:
|
||||
return a[1, 2:10:2, 3:4, ..., :, 0]
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.slice_array1 = slice_array1
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.slice_array1)
|
||||
|
||||
# TODO: The numpy.get_slice op emission should be analyzed: it probably
|
||||
# needs to both accept and produce either arrays or tensors and the following
|
||||
# narrow should do likewise.
|
||||
# CHECK-LABEL: func @slice_array1(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
||||
# CHECK: %[[VAL_1:.*]] = constant 1 : index
|
||||
# CHECK: %[[VAL_2:.*]] = constant 2 : index
|
||||
# CHECK: %[[VAL_3:.*]] = constant 10 : index
|
||||
# CHECK: %[[VAL_4:.*]] = constant 2 : index
|
||||
# CHECK: %[[VAL_5:.*]] = basicpy.slot_object_make(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) -> !basicpy.SlotObject<slice, index, index, index>
|
||||
# CHECK: %[[VAL_6:.*]] = constant 3 : index
|
||||
# CHECK: %[[VAL_7:.*]] = constant 4 : index
|
||||
# CHECK: %[[VAL_8:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[VAL_9:.*]] = basicpy.slot_object_make(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) -> !basicpy.SlotObject<slice, index, index, !basicpy.NoneType>
|
||||
# CHECK: %[[VAL_10:.*]] = basicpy.singleton : !basicpy.EllipsisType
|
||||
# CHECK: %[[VAL_11:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[VAL_12:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[VAL_13:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[VAL_14:.*]] = basicpy.slot_object_make(%[[VAL_11]], %[[VAL_12]], %[[VAL_13]]) -> !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>
|
||||
# CHECK: %[[VAL_15:.*]] = constant 0 : index
|
||||
# CHECK: %[[VAL_16:.*]] = numpy.get_slice %[[VAL_0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_9]], %[[VAL_10]], %[[VAL_14]], %[[VAL_15]] : (tensor<*x!numpy.any_dtype>, index, !basicpy.SlotObject<slice, index, index, index>, !basicpy.SlotObject<slice, index, index, !basicpy.NoneType>, !basicpy.EllipsisType, !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>, index) -> !numpy.ndarray<*:?>
|
||||
# CHECK: %[[VAL_17:.*]] = numpy.narrow %[[VAL_16]] : (!numpy.ndarray<*:?>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: return %[[VAL_17]] : tensor<*x!numpy.any_dtype>
|
||||
# CHECK: }
|
||||
|
||||
print(mb.module)
|
|
@ -0,0 +1,42 @@
|
|||
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
|
||||
def transpose_attribute(a: np.ndarray) -> np.ndarray:
|
||||
return a.T
|
||||
|
||||
|
||||
def transpose(a: np.ndarray) -> np.ndarray:
|
||||
return np.transpose(a)
|
||||
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.transpose_attribute = transpose_attribute
|
||||
exp.transpose = transpose
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.transpose_attribute, exp.transpose)
|
||||
|
||||
# TODO: Consolidate any_dtype -> UnknownType.
|
||||
# CHECK-LABEL: func @transpose_attribute(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
||||
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
|
||||
# CHECK: }
|
||||
|
||||
# CHECK-LABEL: func @transpose(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
|
||||
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
|
||||
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
|
||||
# CHECK: }
|
||||
print(mb.module)
|
|
@ -24,8 +24,6 @@ def run_doctest(mod):
|
|||
|
||||
TEST_MODULES = (
|
||||
"npcomp.compiler.numpy.py_value_utils",
|
||||
"npcomp.dialect.Basicpy",
|
||||
"npcomp.dialect.Numpy",
|
||||
"npcomp.tracing.context",
|
||||
"npcomp.tracing.emitters",
|
||||
"npcomp.tracing.mlir_trace",
|
||||
|
|
Loading…
Reference in New Issue