NFC: Delete npcomp python API and switch to upstream.

* Most updates are mechanical except:
  * python/npcomp/__init__.py and python/NpcompModule.cpp: New init/registration bits to replace some things that were done automatically in the old bindings. Also includes an annoying linkage hack that I'll need to triage next.
  * NpcompModule.cpp: New python helpers for custom types and other hard-to-reach items (for the new bindings); a rough sketch of one such helper follows this list.
  * PybindUtils.h: Extended type casting so that the local extension can directly exchange Mlir* C types.
  * python/npcomp/dialects/*: Build support and ODS bindings for local dialects.
  * mlir_utils.py: Defines an ImportContext to replace the old/bad "Helper" class, which tracked locations and insertion points. Several of its methods are good candidates for finding better ways to do them upstream.
* Also hoisted a few stand-alone samples into dedicated unit tests, since they covered important functionality.
* More cleanup can be done, but keeping this patch as mechanical as possible to stay in NFC land (this is big enough).
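As a rough illustration of the kind of helper referred to above, a pybind11 binding for the new npcompSlotObjectTypeGet C API might look like the sketch below. This is not the actual contents of NpcompModule.cpp; the function and parameter names (bindTypeHelpers, slot_object_type) are hypothetical, and it assumes the MlirContext/MlirType type casters from PybindUtils.h are visible so the lambda can exchange the Mlir* C types directly.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <vector>

#include "PybindUtils.h"     // Local casters for MlirContext/MlirType (assumed visible).
#include "mlir-c/Support.h"  // mlirStringRefCreate.
#include "npcomp-c/Types.h"  // npcompSlotObjectTypeGet.

// Hypothetical helper: binds a Python-callable that builds a SlotObject type.
static void bindTypeHelpers(pybind11::module &m) {
  m.def(
      "slot_object_type",
      [](MlirContext context, const std::string &className,
         const std::vector<MlirType> &slotTypes) -> MlirType {
        MlirStringRef name =
            mlirStringRefCreate(className.data(), className.size());
        // slotTypes arrives as a Python list; each element goes through the
        // MlirType caster, so mlir.ir.Type objects can be passed directly.
        return npcompSlotObjectTypeGet(context, name,
                                       static_cast<intptr_t>(slotTypes.size()),
                                       slotTypes.data());
      },
      pybind11::arg("context"), pybind11::arg("class_name"),
      pybind11::arg("slot_types"));
}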
Stella Laurenzo 2020-12-29 13:22:18 -08:00
parent 951d7ff42c
commit 3f706473fd
61 changed files with 1021 additions and 2546 deletions


@ -21,6 +21,9 @@ extern "C" {
*/
void npcompRegisterAllDialects(MlirContext context);
/** Registers all NPComp passes for symbolic access with the global registry. */
void npcompRegisterAllPasses();
#ifdef __cplusplus
}
#endif


@ -75,6 +75,9 @@ MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
/// Helper that gets an equivalent NdArrayType from a ShapedType.
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
/// Helper that converts an NdArrayType to a TensorType.
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType);
/*============================================================================*/
/* None type. */
/*============================================================================*/
@ -85,6 +88,14 @@ int npcompTypeIsANone(MlirType t);
/** Gets the type of the singleton 'None'. */
MlirType npcompNoneTypeGet(MlirContext context);
/*============================================================================*/
/* SlotObject type. */
/*============================================================================*/
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
intptr_t slotTypeCount,
const MlirType *slotTypes);
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/


@ -1,4 +1,4 @@
//===- BasicPyOps.td - Basic Python ops --------------------*- tablegen -*-===//
//===- BasicpyOps.td - Basic Python ops --------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -9,7 +9,7 @@
#ifndef NPCOMP_DIALECT_BASICPY_IR_BASICPY_OPS
#define NPCOMP_DIALECT_BASICPY_IR_BASICPY_OPS
include "BasicpyDialect.td"
include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -510,7 +510,6 @@ def Basicpy_SlotObjectMakeOp : Basicpy_Op<"slot_object_make", [
!basicpy.SlotObject<slice, !basicpy.NoneType>
}];
let arguments = (ins
StrAttr:$className,
// TODO: Tighter constraints on allowable types.
Variadic<AnyType>:$slots
);


@ -9,7 +9,7 @@
#ifndef NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_OPS
include "NumpyDialect.td"
include "npcomp/Dialect/Numpy/IR/NumpyDialect.td"
include "npcomp/Typing/Analysis/CPA/Interfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
@ -167,7 +167,7 @@ def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
// See: https://docs.scipy.org/doc/numpy/user/basics.indexing.html
//----------------------------------------------------------------------------//
def Numpy_GetSlice : Numpy_Op<"get_slice", []> {
def Numpy_GetSliceOp : Numpy_Op<"get_slice", []> {
let summary = "Gets a slice of an array";
let description = [{
This op encapsulates all forms of indexing into an array by taking a


@ -1,38 +0,0 @@
//===- MlirInit.h - MLIR config and init ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_MLIRINIT_H
#define NPCOMP_PYTHON_MLIRINIT_H
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
// Note that the CPP module is compiled without RTTI or exceptions, unlike
// the rest of the pybind code. Therefore, we also stash some trampolines
// here for parts of the code that are not RTTI-compatible.
namespace mlir {
class MLIRContext;
namespace npcomp {
namespace python {
// One time initialization.
bool npcompMlirInitialize();
// Loads globally registered dialects into the MLIRContext.
// This is temporary until there is an upstream story for handling dialect
// registration in python-based systems.
void loadGlobalDialectsIntoContext(MLIRContext *context);
} // namespace python
} // namespace npcomp
} // namespace mlir
#endif // NPCOMP_PYTHON_MLIRINIT_H


@ -1,202 +0,0 @@
//===- MlirIr.h - MLIR IR Bindings ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_MLIR_IR_H
#define NPCOMP_PYTHON_MLIR_IR_H
#include "PybindUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
struct PyContext;
//===----------------------------------------------------------------------===//
// Utility types
//===----------------------------------------------------------------------===//
template <typename ListTy, typename ItemWrapperTy> class PyIpListWrapper {
public:
using ThisTy = PyIpListWrapper<ListTy, ItemWrapperTy>;
static void bind(py::module m, const char *className);
PyIpListWrapper(ListTy &list) : list(list) {}
private:
ListTy &list;
};
//===----------------------------------------------------------------------===//
// Wrapper types
//===----------------------------------------------------------------------===//
/// Wrapper around an Operation*.
struct PyBaseOperation {
virtual ~PyBaseOperation();
static void bind(py::module m);
virtual Operation *getOperation() = 0;
};
/// Wrapper around Module, capturing a PyContext reference.
struct PyModuleOp : PyBaseOperation {
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
: context(context), moduleOp(moduleOp) {
assert(moduleOp);
}
~PyModuleOp();
static void bind(py::module m);
Operation *getOperation() override;
std::string toAsm(bool enableDebugInfo, bool prettyForm,
int64_t largeElementLimit);
std::shared_ptr<PyContext> context;
ModuleOp moduleOp;
};
/// Wrapper around an Operation*.
struct PyOperationRef : PyBaseOperation {
PyOperationRef(Operation *operation) : operation(operation) {
assert(operation);
}
PyOperationRef(Operation &operation) : operation(&operation) {}
~PyOperationRef();
static void bind(py::module m);
Operation *getOperation() override;
Operation *operation;
};
/// Wrapper around SymbolTable.
struct PySymbolTable {
PySymbolTable(SymbolTable &symbolTable) : symbolTable(symbolTable) {}
static void bind(py::module m);
SymbolTable &symbolTable;
};
/// Wrapper around Value.
struct PyValue {
PyValue(Value value) : value(value) { assert(value); }
static void bind(py::module m);
operator Value() { return value; }
Value value;
};
/// Wrapper around Identifier.
struct PyIdentifier {
PyIdentifier(Identifier identifier) : identifier(identifier) {}
static void bind(py::module m);
Identifier identifier;
};
/// Wrapper around Attribute.
struct PyAttribute {
PyAttribute(Attribute attr) : attr(attr) { assert(attr); }
static void bind(py::module m);
Attribute attr;
};
/// Wrapper around MLIRContext.
struct PyContext : std::enable_shared_from_this<PyContext> {
PyContext();
static void bind(py::module m);
PyModuleOp parseAsm(const std::string &asm_text);
MLIRContext context;
};
/// Wrapper around a Block&.
struct PyBlockRef {
PyBlockRef(Block &block) : block(block) {}
static void bind(py::module m);
Block &block;
};
/// Wrapper around a Region&.
struct PyRegionRef {
PyRegionRef(Region &region) : region(region) {}
static void bind(py::module m);
Region &region;
};
struct PyType {
PyType() = default;
PyType(Type type) : type(type) {}
static void bind(py::module m);
operator Type() { return type; }
Type type;
};
/// Wrapper around an OpBuilder reference.
/// This class is inherently dangerous because it does not track ownership
/// of IR objects that it may be operating on and incorrect usage can cause
/// memory access errors, just as it can in C++. It is intended for use by
/// higher level constructs that are specifically coded to satisfy object
/// lifetime needs.
class PyBaseOpBuilder {
public:
virtual ~PyBaseOpBuilder();
static void bind(py::module m);
virtual OpBuilder &getBuilder(bool requirePosition = false) = 0;
MLIRContext *getContext() { return getBuilder(false).getContext(); }
// For convenience, we track the current location at the builder level
// to avoid lots of parameter passing.
void setCurrentLoc(Location loc) { currentLoc = loc; }
Location getCurrentLoc() {
if (currentLoc) {
return Location(currentLoc);
} else {
return UnknownLoc::get(getBuilder(false).getContext());
}
}
private:
LocationAttr currentLoc;
};
/// Wrapper around an instance of an OpBuilder.
class PyOpBuilder : public PyBaseOpBuilder {
public:
PyOpBuilder(PyContext &context) : builder(&context.context) {}
~PyOpBuilder() override;
static void bind(py::module m);
OpBuilder &getBuilder(bool requirePosition = false) override;
private:
OpBuilder builder;
};
//===----------------------------------------------------------------------===//
// Custom types
//===----------------------------------------------------------------------===//
/// Helper for creating (possibly dialect specific) IR objects. This class
/// is intended to be subclassed on the Python side (possibly with multiple
/// inheritance) to provide Python level APIs for custom dialects. The base
/// class contains helpers for std types and ops.
class PyDialectHelper {
public:
PyDialectHelper(PyContext &context, PyOpBuilder &builder)
: context(context), pyOpBuilder(builder) {}
static void bind(py::module m);
MLIRContext *getContext() { return pyOpBuilder.getContext(); }
protected:
PyContext &context;
PyOpBuilder &pyOpBuilder;
};
} // namespace mlir
#endif // NPCOMP_PYTHON_MLIR_IR_H


@ -1,34 +0,0 @@
//===- MlirPass.h - MLIR Pass Bindings ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_MLIR_PASS_H
#define NPCOMP_PYTHON_MLIR_PASS_H
#include "MlirIr.h"
#include "PybindUtils.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
struct PyPassManager {
PyPassManager(std::shared_ptr<PyContext> context, bool verifyModules)
: passManager(&context->context, OpPassManager::Nesting::Implicit),
context(std::move(context)) {
passManager.enableVerifier(verifyModules);
}
static void bind(py::module m);
PassManager passManager;
private:
std::shared_ptr<PyContext> context;
};
} // namespace mlir
#endif // NPCOMP_PYTHON_MLIR_PASS_H


@ -1,28 +0,0 @@
//===- NpcompModule.h - Module registrations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_NPCOMP_MODULE_H
#define NPCOMP_PYTHON_NPCOMP_MODULE_H
#include "PybindUtils.h"
namespace mlir {
void defineMlirIrModule(py::module m);
void defineMlirPassModule(py::module m);
void defineMlirCoreDialects(py::module m);
namespace npcomp {
namespace python {
void defineNpcompDialect(py::module m);
} // namespace python
} // namespace npcomp
} // namespace mlir
#endif // NPCOMP_PYTHON_NPCOMP_MODULE_H


@ -28,11 +28,51 @@ namespace detail {
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
/// Helper to convert a presumed MLIR API object to a capsule, accepting either
/// an explicit Capsule (which can happen when two C APIs are communicating
/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
/// attribute (through which supported MLIR Python API objects export their
/// contained API pointer as a capsule). This is intended to be used from
/// type casters, which are invoked with a raw handle (unowned). The returned
/// object's lifetime may not extend beyond the apiObject handle without
/// explicitly having its refcount increased (i.e. on return).
static py::object mlirApiObjectToCapsule(py::handle apiObject) {
if (PyCapsule_CheckExact(apiObject.ptr()))
return py::reinterpret_borrow<py::object>(apiObject);
return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
}
// Note: Currently all of the following support cast from py::object to the
// Mlir* C-API type, but only a few light-weight, context-bound ones
// implicitly cast the other way because the use case has not yet emerged and
// ownership is unclear.
/// Casts object -> MlirAttribute.
template <> struct type_caster<MlirAttribute> {
PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
bool load(handle src, bool) {
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToAttribute(capsule.ptr());
if (mlirAttributeIsNull(value)) {
return false;
}
return true;
}
static handle cast(MlirAttribute v, return_value_policy, handle) {
auto capsule =
py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
return py::module::import("mlir.ir")
.attr("Attribute")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
}
};
/// Casts object -> MlirContext.
template <> struct type_caster<MlirContext> {
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
bool load(handle src, bool) {
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToContext(capsule.ptr());
if (mlirContextIsNull(value)) {
return false;
@ -41,11 +81,32 @@ template <> struct type_caster<MlirContext> {
}
};
/// Casts object -> MlirLocation.
template <> struct type_caster<MlirLocation> {
PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
bool load(handle src, bool) {
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToLocation(capsule.ptr());
if (mlirLocationIsNull(value)) {
return false;
}
return true;
}
static handle cast(MlirLocation v, return_value_policy, handle) {
auto capsule =
py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
return py::module::import("mlir.ir")
.attr("Location")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
}
};
/// Casts object -> MlirModule.
template <> struct type_caster<MlirModule> {
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
bool load(handle src, bool) {
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToModule(capsule.ptr());
if (mlirModuleIsNull(value)) {
return false;
@ -54,11 +115,24 @@ template <> struct type_caster<MlirModule> {
}
};
/// Casts object -> MlirOperation.
template <> struct type_caster<MlirOperation> {
PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
bool load(handle src, bool) {
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToOperation(capsule.ptr());
if (mlirOperationIsNull(value)) {
return false;
}
return true;
}
};
/// Casts object -> MlirPassManager.
template <> struct type_caster<MlirPassManager> {
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
bool load(handle src, bool) {
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToPassManager(capsule.ptr());
if (mlirPassManagerIsNull(value)) {
return false;
@ -67,6 +141,27 @@ template <> struct type_caster<MlirPassManager> {
}
};
/// Casts object -> MlirType.
template <> struct type_caster<MlirType> {
PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
bool load(handle src, bool) {
auto capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToType(capsule.ptr());
if (mlirTypeIsNull(value)) {
return false;
}
return true;
}
static handle cast(MlirType t, return_value_policy, handle) {
auto capsule =
py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
return py::module::import("mlir.ir")
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
}
};
} // namespace detail
} // namespace pybind11
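With these casters in place, a function in the local extension can accept and return the Mlir* C types directly while still interoperating with mlir.ir objects on the Python side. For example, a minimal binding of the npcompNdArrayTypeToTensor helper added in this patch could be written as below; this sketch is not part of the patch, and the names (bindNdArrayHelpers, ndarray_to_tensor) are illustrative only.

#include <pybind11/pybind11.h>

#include "PybindUtils.h"     // Provides the MlirType caster above.
#include "npcomp-c/Types.h"  // Declares npcompNdArrayTypeToTensor.

// Illustrative only: both the argument and the result go through the
// type_caster<MlirType>, so Python passes in and gets back mlir.ir.Type objects.
static void bindNdArrayHelpers(pybind11::module &m) {
  m.def("ndarray_to_tensor", [](MlirType ndarrayType) -> MlirType {
    return npcompNdArrayTypeToTensor(ndarrayType);
  });
}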


@ -16,7 +16,6 @@ target_link_libraries(NPCOMPBackendRefJITPythonModule
MLIRExecutionEngine
MLIRTargetLLVMIR
NPCOMPPythonCommon
NPCOMPRefBackendJITHelpers
)


@ -12,8 +12,6 @@
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "npcomp/Python/MlirIr.h"
#include "npcomp/Python/MlirPass.h"
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
using llvm::SmallVector;
@ -21,8 +19,6 @@ using llvm::StringRef;
using llvm::Twine;
// Make namespaces consistent.
using mlir::PyModuleOp;
using mlir::PyPassManager;
using refback::JITModule;
using refbackrt::Ref;
using refbackrt::Tensor;


@ -9,6 +9,8 @@
#include "npcomp-c/Registration.h"
#include "mlir/CAPI/IR.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/InitAll.h"
void npcompRegisterAllDialects(MlirContext context) {
@ -16,3 +18,11 @@ void npcompRegisterAllDialects(MlirContext context) {
// TODO: Don't eagerly load once D88162 is in and clients can do this.
unwrap(context)->getDialectRegistry().loadAll(unwrap(context));
}
void npcompRegisterAllPasses() {
::mlir::NPCOMP::registerAllPasses();
// Upstream passes we depend on.
::mlir::registerCanonicalizerPass();
::mlir::registerSCFToStandardPass();
}


@ -9,6 +9,7 @@
#include "npcomp-c/Types.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
@ -85,6 +86,10 @@ MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
unwrap(shapedType).cast<ShapedType>()));
}
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType) {
return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
}
/*============================================================================*/
/* None type. */
/*============================================================================*/
@ -97,6 +102,23 @@ MlirType npcompNoneTypeGet(MlirContext context) {
return wrap(Basicpy::NoneType::get(unwrap(context)));
}
/*============================================================================*/
/* SlotObject type. */
/*============================================================================*/
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
intptr_t slotTypeCount,
const MlirType *slotTypes) {
MLIRContext *cppContext = unwrap(context);
auto classNameAttr = StringAttr::get(unwrap(className), cppContext);
SmallVector<Type> slotTypesCpp;
slotTypesCpp.resize(slotTypeCount);
for (intptr_t i = 0; i < slotTypeCount; ++i) {
slotTypesCpp[i] = unwrap(slotTypes[i]);
}
return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
}
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/


@ -2,7 +2,6 @@ add_subdirectory(CAPI)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(RefBackend)
add_subdirectory(Python)
add_subdirectory(Typing)
if(NPCOMP_ENABLE_REFJIT)


@ -1,54 +0,0 @@
################################################################################
# NPCOMPPythonCommon
################################################################################
include(AddLLVM)
include(NpcompPython)
# TODO: This should not be wired in at such a low/unconditional level.
# It is done here to be kept with the other LLVM initialization until a better
# place can be found for it.
# set(ExtraInit_LIBADD)
# if(NPCOMP_ENABLE_REFJIT)
# llvm_map_components_to_libnames(refjit_llvm_libs
# nativecodegen
# )
# message(STATUS "Including LLVM libs for RefJit: ${refjit_llvm_libs}")
# list(APPEND ExtraInit_LIBADD
# ${refjit_llvm_libs})
# endif()
# include_directories(
# ${PYTHON_INCLUDE_DIRS}
# )
set(PYBIND_SOURCES
MlirInit.cpp
MlirIr.cpp
MlirPass.cpp
NpcompDialect.cpp
CoreDialects.cpp
)
add_library(NPCOMPPythonCommon
${PYBIND_SOURCES}
)
target_link_libraries(NPCOMPPythonCommon
pybind11::module
NPCOMPInitAll
NPCOMPCAPI
# Core dialects
MLIRSCF
# Upstream depends
MLIRDialect
MLIREDSC
MLIRIR
MLIRLLVMIR
MLIRPass
MLIRTransforms
)
npcomp_python_target_compile_options(NPCOMPPythonCommon)


@ -1,64 +0,0 @@
//===- NpcompDialect.cpp - Custom dialect classes -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Python/MlirIr.h"
#include "npcomp/Python/NpcompModule.h"
#include "mlir/Dialect/SCF/SCF.h"
namespace mlir {
namespace {
class ScfDialectHelper : public PyDialectHelper {
public:
using PyDialectHelper::PyDialectHelper;
static void bind(py::module m) {
py::class_<ScfDialectHelper, PyDialectHelper>(m, "ScfDialectHelper")
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>())
.def("scf_yield_op",
[](ScfDialectHelper &self,
std::vector<PyValue> pyYields) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Value, 4> yields(pyYields.begin(),
pyYields.end());
auto op = opBuilder.create<scf::YieldOp>(loc, yields);
return op.getOperation();
})
.def(
"scf_if_op",
[](ScfDialectHelper &self, std::vector<PyType> pyResultTypes,
PyValue cond, bool withElseRegion) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Type, 4> resultTypes(pyResultTypes.begin(),
pyResultTypes.end());
auto op = opBuilder.create<scf::IfOp>(loc, resultTypes, cond,
withElseRegion);
if (withElseRegion) {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint(),
op.getElseBodyBuilder().saveInsertionPoint());
} else {
return py::make_tuple(
PyOperationRef(op),
op.getThenBodyBuilder().saveInsertionPoint());
}
},
py::arg("result_types"), py::arg("cond"),
py::arg("with_else_region") = false);
}
};
} // namespace
} // namespace mlir
void mlir::defineMlirCoreDialects(py::module m) { ScfDialectHelper::bind(m); }


@ -1,72 +0,0 @@
//===- MlirInit.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Python/MlirInit.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/InitAll.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/TargetSelect.h"
bool mlir::npcomp::python::npcompMlirInitialize() {
// Enable LLVM's signal handler to get nice stack traces.
llvm::sys::SetOneShotPipeSignalFunction(
llvm::sys::DefaultOneShotPipeSignalHandler);
llvm::sys::PrintStackTraceOnErrorSignal("npcomp");
// Register any pass manager command line options.
mlir::registerPassManagerCLOptions();
mlir::registerMLIRContextCLOptions();
std::string program_name = "npcomp";
std::vector<const char *> default_options = {program_name.c_str(), nullptr};
llvm::cl::ParseCommandLineOptions(1, default_options.data());
// Pass registration.
::mlir::registerAllPasses();
::mlir::NPCOMP::registerAllPasses();
// Initialize code generation.
npcompInitializeLLVMCodegen();
return true;
}
void mlir::npcomp::python::loadGlobalDialectsIntoContext(MLIRContext *context) {
static mlir::DialectRegistry registry = ([]() {
mlir::DialectRegistry registry;
::mlir::registerAllDialects(registry);
::mlir::NPCOMP::registerAllDialects(registry);
return registry;
})();
registry.loadAll(context);
}
namespace mlir {
namespace npcomp {
namespace python {
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm,
raw_ostream &errorStream = llvm::errs()) {
return ::mlir::parsePassPipeline(pipeline, pm, errorStream);
}
} // namespace python
} // namespace npcomp
} // namespace mlir


@ -1,987 +0,0 @@
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Python/MlirIr.h"
#include "npcomp/Python/MlirInit.h"
#include "npcomp/Python/NpcompModule.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/Parser.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// Forward declarations
//===----------------------------------------------------------------------===//
struct PyContext;
/// Parses an MLIR module from a string.
/// For maximum efficiency, the `contents` should be zero terminated.
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
MLIRContext *context);
//===----------------------------------------------------------------------===//
// Direct type bindings
//===----------------------------------------------------------------------===//
static void bindInsertPoint(py::module m) {
py::class_<OpBuilder::InsertPoint>(m, "InsertPoint");
}
//===----------------------------------------------------------------------===//
// Internal only template definitions
// Since it is only legal to use explicit instantiations of templates in
// mlir_ir.h, implementations are kept in this module to keep things scoped
// well for the compiler.
//===----------------------------------------------------------------------===//
template <typename ListTy, typename ItemWrapperTy>
void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
const char *className) {
struct PyItemIterator : public llvm::iterator_adaptor_base<
PyItemIterator, typename ListTy::iterator,
typename std::iterator_traits<
typename ListTy::iterator>::iterator_category,
typename ListTy::value_type> {
PyItemIterator() = default;
PyItemIterator(typename ListTy::iterator &&other)
: PyItemIterator::iterator_adaptor_base(std::move(other)) {}
ItemWrapperTy operator*() const { return ItemWrapperTy(*this->I); }
};
py::class_<ThisTy>(m, className)
.def_property_readonly(
"front",
[](ThisTy &self) { return ItemWrapperTy(self.list.front()); })
.def("__len__", [](ThisTy &self) { return self.list.size(); })
.def(
"__iter__",
[](ThisTy &self) {
PyItemIterator begin(self.list.begin());
PyItemIterator end(self.list.end());
return py::make_iterator(begin, end);
},
py::keep_alive<0, 1>());
}
//===----------------------------------------------------------------------===//
// Explicit template instantiations
//===----------------------------------------------------------------------===//
template class PyIpListWrapper<Region::BlockListType, PyBlockRef>;
using PyBlockList = PyIpListWrapper<Region::BlockListType, PyBlockRef>;
template class PyIpListWrapper<Block::OpListType, PyOperationRef>;
using PyOperationList = PyIpListWrapper<Block::OpListType, PyOperationRef>;
//===----------------------------------------------------------------------===//
// Conversions
//===----------------------------------------------------------------------===//
Type mapBufferFormatToType(MLIRContext *context, const std::string &format,
py::ssize_t itemSize) {
// Floating point formats.
if (format == "f")
return FloatType::getF32(context);
if (format == "d")
return FloatType::getF64(context);
if (format == "D")
return ComplexType::get(FloatType::getF64(context));
// Signed integer formats.
if (format == "b" || format == "h" || format == "i" || format == "l" ||
format == "L") {
unsigned width = itemSize * 8;
return IntegerType::get(context, width,
IntegerType::SignednessSemantics::Signed);
}
// Unsigned integer format.
if (format == "B" || format == "H" || format == "I" || format == "k" ||
format == "K") {
unsigned width = itemSize * 8;
return IntegerType::get(context, width,
IntegerType::SignednessSemantics::Unsigned);
}
return Type();
}
/// Creates a DenseElementsAttr from a python buffer which must have been
/// requested to be C-Contiguous.
Attribute createDenseElementsAttrFromBuffer(MLIRContext *context,
py::buffer_info &array) {
Type elementType =
mapBufferFormatToType(context, array.format, array.itemsize);
if (!elementType) {
throw py::raiseValueError(
"Unsupported buffer/array type for conversion to DenseElementsAttr");
}
SmallVector<int64_t, 4> shape(array.shape.begin(),
array.shape.begin() + array.ndim);
RankedTensorType type = RankedTensorType::get(shape, elementType);
const char *rawBufferPtr = reinterpret_cast<const char *>(array.ptr);
ArrayRef<char> rawBuffer(rawBufferPtr, array.size * array.itemsize);
return DenseElementsAttr::getFromRawBuffer(type, rawBuffer, false);
}
//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//
/// RAII class to capture diagnostics for later reporting back to the python
/// layer.
class DiagnosticCapture {
public:
DiagnosticCapture(mlir::MLIRContext *mlir_context)
: mlir_context(mlir_context) {
handler_id = mlir_context->getDiagEngine().registerHandler(
[&](Diagnostic &d) -> LogicalResult {
diagnostics.push_back(std::move(d));
return success();
});
}
~DiagnosticCapture() {
if (mlir_context) {
mlir_context->getDiagEngine().eraseHandler(handler_id);
}
}
DiagnosticCapture(DiagnosticCapture &&other) {
mlir_context = other.mlir_context;
diagnostics.swap(other.diagnostics);
handler_id = other.handler_id;
other.mlir_context = nullptr;
}
std::vector<mlir::Diagnostic> &getDiagnostics() { return diagnostics; }
// Consumes/clears diagnostics.
std::string consumeDiagnosticsAsString(const char *error_message);
void clearDiagnostics() { diagnostics.clear(); }
private:
MLIRContext *mlir_context;
std::vector<mlir::Diagnostic> diagnostics;
mlir::DiagnosticEngine::HandlerID handler_id;
};
//===----------------------------------------------------------------------===//
// PyDialectHelper
//===----------------------------------------------------------------------===//
void PyDialectHelper::bind(py::module m) {
py::class_<PyDialectHelper>(m, "DialectHelper")
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>())
.def_property_readonly("builder",
[](PyDialectHelper &self) -> PyBaseOpBuilder & {
return self.pyOpBuilder;
})
.def_property_readonly(
"context",
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
return self.context.shared_from_this();
})
.def(
"op",
[](PyDialectHelper &self, const std::string &opNameStr,
std::vector<PyType> pyResultTypes, std::vector<PyValue> pyOperands,
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
Location loc = self.pyOpBuilder.getCurrentLoc();
OperationName opName(opNameStr, opBuilder.getContext());
SmallVector<Type, 4> types(pyResultTypes.begin(),
pyResultTypes.end());
SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end());
DictionaryAttr attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = dictAttrs;
} else {
attrList = DictionaryAttr::get({}, self.getContext());
}
Operation *op =
Operation::create(loc, opName, types, operands, attrList);
opBuilder.insert(op);
return op;
},
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
py::arg("attrs") = llvm::Optional<PyAttribute>())
.def(
"func_op",
[](PyDialectHelper &self, const std::string &name, PyType type,
bool createEntryBlock, llvm::Optional<PyAttribute> attrs) {
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
if (!functionType) {
throw py::raiseValueError("Illegal function type");
}
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
// TODO: Dedup attr creation from op().
DictionaryAttr attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = dictAttrs;
} else {
attrList = DictionaryAttr::get({}, self.getContext());
}
FuncOp op =
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
/*attrs=*/attrList.getValue());
if (createEntryBlock) {
Block *entryBlock = new Block();
entryBlock->addArguments(functionType.getInputs());
op.getBody().push_back(entryBlock);
opBuilder.setInsertionPointToStart(entryBlock);
}
return PyOperationRef(op);
},
py::arg("name"), py::arg("type"),
py::arg("create_entry_block") = false,
py::arg("attrs") = llvm::Optional<PyAttribute>(),
R"(Creates a new `func` op, optionally creating an entry block.
If an entry block is created, the builder will be positioned
to its start.)")
.def(
"select_op",
[](PyDialectHelper &self, PyValue conditionValue, PyValue trueValue,
PyValue falseValue) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
return PyOperationRef(opBuilder.create<SelectOp>(
loc, conditionValue, trueValue, falseValue));
},
py::arg("condition"), py::arg("true_value"), py::arg("false_value"))
.def("return_op",
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end());
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
})
.def("constant_op",
[](PyDialectHelper &self, PyType type, PyAttribute value) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
return PyOperationRef(
opBuilder.create<ConstantOp>(loc, type.type, value.attr));
})
// Types.
.def_property_readonly("index_type",
[](PyDialectHelper &self) -> PyType {
return IndexType::get(self.getContext());
})
.def(
"integer_type",
[](PyDialectHelper &self, unsigned width) -> PyType {
return IntegerType::get(self.getContext(), width);
},
py::arg("width") = 32)
.def_property_readonly("i1_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(self.getContext(), 1);
})
.def_property_readonly("i16_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(self.getContext(), 32);
})
.def_property_readonly("i32_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(self.getContext(), 32);
})
.def_property_readonly("i64_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(self.getContext(), 64);
})
.def_property_readonly("f32_type",
[](PyDialectHelper &self) -> PyType {
return FloatType::getF32(self.getContext());
})
.def_property_readonly("f64_type",
[](PyDialectHelper &self) -> PyType {
return FloatType::getF64(self.getContext());
})
.def(
"tensor_type",
[](PyDialectHelper &self, PyType elementType,
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
if (!elementType.type) {
throw py::raiseValueError("Null element type");
}
if (shape) {
return RankedTensorType::get(*shape, elementType.type);
} else {
return UnrankedTensorType::get(elementType.type);
}
},
py::arg("element_type"),
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
.def("function_type",
[](PyDialectHelper &self, std::vector<PyType> inputs,
std::vector<PyType> results) -> PyType {
llvm::SmallVector<Type, 4> inputTypes;
llvm::SmallVector<Type, 1> resultTypes;
for (auto input : inputs) {
inputTypes.push_back(input.type);
}
for (auto result : results) {
resultTypes.push_back(result.type);
}
return FunctionType::get(self.getContext(), inputTypes,
resultTypes);
});
}
//===----------------------------------------------------------------------===//
// Module initialization
//===----------------------------------------------------------------------===//
static void emitDiagnostic(DiagnosticSeverity severity, PyAttribute loc,
std::string &message) {
auto locAttr = loc.attr.dyn_cast_or_null<LocationAttr>();
if (!locAttr) {
throw py::raiseValueError("Expected a LocationAttr");
}
auto &diagEngine = locAttr.getContext()->getDiagEngine();
diagEngine.emit(Location(locAttr), severity) << message;
}
void defineMlirIrModule(py::module m) {
m.doc() = "Python bindings for constructs in the mlir/IR library";
// Globals.
m.def(
"emit_error",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Error, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def(
"emit_warning",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Warning, loc, message);
},
py::arg("loc"), py::arg("message"));
m.def(
"emit_remark",
[](PyAttribute loc, std::string message) {
emitDiagnostic(DiagnosticSeverity::Remark, loc, message);
},
py::arg("loc"), py::arg("message"));
// Python only types.
PyDialectHelper::bind(m);
// Utility types.
PyBlockList::bind(m, "BlockList");
PyOperationList::bind(m, "OperationList");
// Wrapper types.
PyAttribute::bind(m);
PyBaseOperation::bind(m);
PyBaseOpBuilder::bind(m);
PyBlockRef::bind(m);
PyContext::bind(m);
PyIdentifier::bind(m);
PyModuleOp::bind(m);
PyOperationRef::bind(m);
PyOpBuilder::bind(m);
PyRegionRef::bind(m);
PySymbolTable::bind(m);
PyType::bind(m);
PyValue::bind(m);
// Direct wrappings.
bindInsertPoint(m);
}
//===----------------------------------------------------------------------===//
// PyContext
//===----------------------------------------------------------------------===//
PyContext::PyContext() {
mlir::npcomp::python::loadGlobalDialectsIntoContext(&context);
}
void PyContext::bind(py::module m) {
py::class_<PyContext, std::shared_ptr<PyContext>>(m, "MLIRContext")
.def(py::init<>([]() {
// Need explicit make_shared to avoid UB with enable_shared_from_this.
return std::make_shared<PyContext>();
}))
.def("new_module",
[&](PyContext &self) -> PyModuleOp {
Location loc = UnknownLoc::get(&self.context);
auto m = ModuleOp::create(loc);
return PyModuleOp(self.shared_from_this(), m);
})
.def("parse_asm", &PyContext::parseAsm)
.def(
"new_builder",
[](PyContext &self) {
// Note: we collapse the Builder and OpBuilder into one because
// there is little reason to expose the inheritance hierarchy to
// Python.
return PyOpBuilder(self);
},
py::keep_alive<0, 1>())
.def("identifier",
[](PyContext &self, std::string s) -> PyIdentifier {
return Identifier::get(s, &self.context);
})
.def(
"file_line_col_loc_attr",
[](PyContext &self, PyIdentifier filename, unsigned line,
unsigned column) -> PyAttribute {
return static_cast<LocationAttr>(FileLineColLoc::get(
filename.identifier, line, column, &self.context));
},
py::arg("filename"), py::arg("line"), py::arg("column"))
// Salient functions from Builder.
.def("parse_type",
[](PyContext &self, const std::string &asmText) {
Type t = parseType(asmText, &self.context);
if (!t) {
std::string message = "Unable to parse MLIR type: ";
message.append(asmText);
throw py::raiseValueError(message);
}
return PyType(t);
})
.def(
"integer_attr",
[](PyContext &self, PyType type, int64_t value) -> PyAttribute {
if (!type.type.isa<IntegerType>()) {
throw py::raiseValueError("Expected IntegerType");
}
return IntegerAttr::get(type.type, value);
},
py::arg("type"), py::arg("value"))
.def("float_attr",
[](PyContext &self, PyType type, double value) -> PyAttribute {
if (!type.type.isa<FloatType>()) {
throw py::raiseValueError("Expected FloatType");
}
return FloatAttr::get(type.type, value);
})
.def("index_attr",
[](PyContext &self, int64_t indexValue) -> PyAttribute {
return IntegerAttr::get(IndexType::get(&self.context), indexValue);
})
.def("string_attr",
[](PyContext &self, const std::string &s) -> PyAttribute {
return StringAttr::get(s, &self.context);
})
.def("bytes_attr",
[](PyContext &self, py::bytes bytes) -> PyAttribute {
char *buffer;
ssize_t length;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bytes.ptr(), &buffer,
&length)) {
throw py::raiseValueError("Cannot extract bytes");
}
return StringAttr::get(StringRef(buffer, length), &self.context);
})
.def("flat_symbol_ref_attr",
[](PyContext &self, const std::string &s) -> PyAttribute {
return FlatSymbolRefAttr::get(s, &self.context);
})
.def("dictionary_attr",
[](PyContext &self, py::dict d) -> PyAttribute {
SmallVector<NamedAttribute, 4> attrs;
for (auto &it : d) {
auto key = it.first.cast<std::string>();
auto value = it.second.cast<PyAttribute>();
auto keyIdent = Identifier::get(key, &self.context);
attrs.emplace_back(keyIdent, value.attr);
}
return DictionaryAttr::get(attrs, &self.context);
})
.def("array_attr",
[](PyContext &self, py::list l) -> PyAttribute {
SmallVector<Attribute, 4> attrs;
for (auto &it : l) {
attrs.push_back(it.cast<PyAttribute>().attr);
}
return ArrayAttr::get(attrs, &self.context);
})
.def(
"dense_elements_attr",
[](PyContext &self, py::buffer array) -> PyAttribute {
// Request a contiguous view.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
Py_buffer *view = new Py_buffer();
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
delete view;
throw py::error_already_set();
}
py::buffer_info array_info(view);
return createDenseElementsAttrFromBuffer(&self.context, array_info);
},
py::arg("array"))
.def_property_readonly("unit_attr", [](PyContext &self) -> PyAttribute {
return UnitAttr::get(&self.context);
});
}
PyModuleOp PyContext::parseAsm(const std::string &asm_text) {
// Arrange to get a view that includes a terminating null to avoid
// additional copy.
// TODO: Consider using the buffer protocol to access and avoid more copies.
const char *asm_chars = asm_text.c_str();
StringRef asm_sr(asm_chars, asm_text.size() + 1);
// TODO: Output non failure diagnostics (somewhere)
DiagnosticCapture diag_capture(&context);
auto module_ref = parseMLIRModuleFromString(asm_sr, &context);
if (!module_ref) {
throw py::raiseValueError(
diag_capture.consumeDiagnosticsAsString("Error parsing ASM"));
}
return PyModuleOp{shared_from_this(), module_ref.release()};
}
//===----------------------------------------------------------------------===//
// PyBaseOperation
//===----------------------------------------------------------------------===//
PyBaseOperation::~PyBaseOperation() = default;
void PyBaseOperation::bind(py::module m) {
py::class_<PyBaseOperation>(m, "BaseOperation")
.def_property_readonly(
"name",
[](PyBaseOperation &self) {
return std::string(self.getOperation()->getName().getStringRef());
})
.def_property_readonly("is_registered",
[](PyBaseOperation &self) {
return self.getOperation()->isRegistered();
})
.def_property_readonly("num_regions",
[](PyBaseOperation &self) {
return self.getOperation()->getNumRegions();
})
.def_property_readonly("results",
[](PyBaseOperation &self) {
auto *op = self.getOperation();
std::vector<PyValue> results(op->result_begin(),
op->result_end());
return results;
})
.def_property_readonly("result",
[](PyBaseOperation &self) -> PyValue {
auto *op = self.getOperation();
if (op->getNumResults() != 1) {
throw py::raiseValueError(
"Operation does not have 1 result");
}
return op->getOpResult(0);
})
.def("region",
[](PyBaseOperation &self, int index) {
auto *op = self.getOperation();
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
throw py::raisePyError(PyExc_IndexError,
"Region index out of bounds");
}
return PyRegionRef(op->getRegion(index));
})
.def_property_readonly("first_block", [](PyBaseOperation &self) {
Operation *op = self.getOperation();
assert(op);
if (op->getNumRegions() == 0) {
throw py::raiseValueError("Op has no regions");
}
auto &region = op->getRegion(0);
if (region.empty()) {
throw py::raiseValueError("Op has no blocks");
}
return PyBlockRef(region.front());
});
}
//===----------------------------------------------------------------------===//
// PyOperationRef
//===----------------------------------------------------------------------===//
PyOperationRef::~PyOperationRef() = default;
void PyOperationRef::bind(py::module m) {
py::class_<PyOperationRef, PyBaseOperation>(m, "OperationRef");
}
Operation *PyOperationRef::getOperation() { return operation; }
//===----------------------------------------------------------------------===//
// PyModuleOp
//===----------------------------------------------------------------------===//
PyModuleOp::~PyModuleOp() = default;
void PyModuleOp::bind(py::module m) {
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
.def_property_readonly("context",
[](PyModuleOp &self) { return self.context; })
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
}
Operation *PyModuleOp::getOperation() { return moduleOp; }
std::string PyModuleOp::toAsm(bool enableDebugInfo, bool prettyForm,
int64_t largeElementLimit) {
// Print to asm.
std::string asmOutput;
llvm::raw_string_ostream sout(asmOutput);
OpPrintingFlags printFlags;
if (enableDebugInfo) {
printFlags.enableDebugInfo(prettyForm);
}
if (largeElementLimit >= 0) {
printFlags.elideLargeElementsAttrs(largeElementLimit);
}
moduleOp.print(sout, printFlags);
return sout.str();
}
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
MLIRContext *context) {
std::unique_ptr<llvm::MemoryBuffer> contents_buffer;
if (contents.back() == 0) {
// If it has a nul terminator, just use as-is.
contents_buffer = llvm::MemoryBuffer::getMemBuffer(contents.drop_back());
} else {
// Otherwise, make a copy.
contents_buffer = llvm::MemoryBuffer::getMemBufferCopy(contents, "EMBED");
}
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
OwningModuleRef mlir_module = parseSourceFile(source_mgr, context);
return mlir_module;
}
// Custom location printer that prints prettier, multi-line file output
// suitable for human readable error messages. The standard printer just prints
// a long nested expression not particularly human friendly). Note that there
// is a location pretty printer in the MLIR AsmPrinter. It is private and
// doesn't do any path shortening, which seems to make long Python stack traces
// a bit easier to scan.
// TODO: Upstream this.
void printLocation(Location loc, raw_ostream &out) {
TypeSwitch<Location>(loc)
.Case<OpaqueLoc>(
[&](OpaqueLoc loc) { printLocation(loc.getFallbackLocation(), out); })
.Case<UnknownLoc>([&](Location) { out << " [unknown location]\n"; })
.Case<FileLineColLoc>([&](FileLineColLoc line_col_loc) {
StringRef this_filename = line_col_loc.getFilename();
auto slash_pos = this_filename.find_last_of("/\\");
// We print both the basename and extended names with a structure like
// `foo.py:35:4`. Even though technically the line/col
// information is redundant to include in both names, having it on both
// makes it easier to paste the paths into an editor and jump to the
// exact location.
std::string line_col_suffix =
":" + std::to_string(line_col_loc.getLine()) + ":" +
std::to_string(line_col_loc.getColumn());
bool has_basename = false;
StringRef basename = this_filename;
if (slash_pos != StringRef::npos) {
has_basename = true;
basename = this_filename.substr(slash_pos + 1);
}
out << " at: " << basename << line_col_suffix;
if (has_basename) {
StringRef extended_name = this_filename;
// Print out two tabs, as basenames usually vary in length by more
// than one tab width.
out << "\t\t( " << extended_name << line_col_suffix << " )";
}
out << "\n";
})
.Case<NameLoc>([&](NameLoc nameLoc) {
out << " @'" << nameLoc.getName() << "':\n";
auto childLoc = nameLoc.getChildLoc();
if (!childLoc.isa<UnknownLoc>()) {
out << "(...\n";
printLocation(childLoc, out);
out << ")\n";
}
})
.Case<CallSiteLoc>([&](CallSiteLoc callSite) {
printLocation(callSite.getCaller(), out);
printLocation(callSite.getCallee(), out);
});
}
//===----------------------------------------------------------------------===//
// PySymbolTable
//===----------------------------------------------------------------------===//
void PySymbolTable::bind(py::module m) {
py::class_<PySymbolTable>(m, "SymbolTable")
.def_property_readonly_static("symbol_attr_name",
[](const py::object &) {
auto sr =
SymbolTable::getSymbolAttrName();
return py::str(sr.data(), sr.size());
})
.def_property_readonly_static(
"visibility_attr_name", [](const py::object &) {
auto sr = SymbolTable::getVisibilityAttrName();
return py::str(sr.data(), sr.size());
});
}
//===----------------------------------------------------------------------===//
// DiagnosticCapture
//===----------------------------------------------------------------------===//
std::string
DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
std::string s;
llvm::raw_string_ostream sout(s);
bool first = true;
if (error_message) {
sout << error_message;
first = false;
}
for (auto &d : diagnostics) {
if (!first) {
sout << "\n\n";
} else {
first = false;
}
switch (d.getSeverity()) {
case DiagnosticSeverity::Note:
sout << "[NOTE]";
break;
case DiagnosticSeverity::Warning:
sout << "[WARNING]";
break;
case DiagnosticSeverity::Error:
sout << "[ERROR]";
break;
case DiagnosticSeverity::Remark:
sout << "[REMARK]";
break;
}
// Message.
sout << ": " << d << "\n";
printLocation(d.getLocation(), sout);
}
diagnostics.clear();
return sout.str();
}
//===----------------------------------------------------------------------===//
// PyBlockRef
//===----------------------------------------------------------------------===//
void PyBlockRef::bind(py::module m) {
py::class_<PyBlockRef>(m, "BlockRef")
.def_property_readonly("operations",
[](PyBlockRef &self) {
return PyOperationList(
self.block.getOperations());
})
.def_property_readonly("args", [](PyBlockRef &self) {
return std::vector<PyValue>(self.block.args_begin(),
self.block.args_end());
});
}
//===----------------------------------------------------------------------===//
// PyRegionRef
//===----------------------------------------------------------------------===//
void PyRegionRef::bind(py::module m) {
py::class_<PyRegionRef>(m, "RegionRef")
.def_property_readonly("blocks", [](PyRegionRef &self) {
return PyBlockList(self.region.getBlocks());
});
}
//===----------------------------------------------------------------------===//
// PyType
//===----------------------------------------------------------------------===//
void PyType::bind(py::module m) {
py::class_<PyType>(m, "Type").def("__repr__",
[](PyType &self) -> std::string {
if (!self.type)
return "<undefined type>";
std::string res;
llvm::raw_string_ostream os(res);
self.type.print(os);
return res;
});
}
//===----------------------------------------------------------------------===//
// PyIdentifier
//===----------------------------------------------------------------------===//
void PyIdentifier::bind(py::module m) {
py::class_<PyIdentifier>(m, "Identifier")
.def("__str__", [](PyIdentifier &self) { return self.identifier.str(); })
.def("__repr__", [](PyIdentifier &self) {
std::string s("<Identifier \"");
s.append(self.identifier.str());
s.append("\">");
return s;
});
}
//===----------------------------------------------------------------------===//
// PyValue
//===----------------------------------------------------------------------===//
void PyValue::bind(py::module m) {
py::class_<PyValue>(m, "Value")
.def_property_readonly(
"type", [](PyValue &self) -> PyType { return self.value.getType(); })
.def("__repr__", [](PyValue &self) {
std::string res;
llvm::raw_string_ostream os(res);
os << self.value;
return res;
});
}
//===----------------------------------------------------------------------===//
// PyAttribute
//===----------------------------------------------------------------------===//
void PyAttribute::bind(py::module m) {
py::class_<PyAttribute>(m, "Attribute")
.def_property_readonly(
"type",
[](PyAttribute &self) -> PyType { return self.attr.getType(); })
.def("__repr__", [](PyAttribute &self) {
std::string res;
llvm::raw_string_ostream os(res);
os << self.attr;
return res;
});
}
//===----------------------------------------------------------------------===//
// OpBuilder implementations
//===----------------------------------------------------------------------===//
PyBaseOpBuilder::~PyBaseOpBuilder() = default;
PyOpBuilder::~PyOpBuilder() = default;
OpBuilder &PyOpBuilder::getBuilder(bool requirePosition) {
if (requirePosition && !builder.getBlock()) {
throw py::raisePyError(PyExc_IndexError, "Insertion point not set");
}
return builder;
}
void PyBaseOpBuilder::bind(py::module m) {
py::class_<PyBaseOpBuilder>(m, "BaseOpBuilder");
}
void PyOpBuilder::bind(py::module m) {
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
.def(py::init<PyContext &>(), py::keep_alive<1, 2>())
.def_property(
"current_loc",
[](PyOpBuilder &self) -> PyAttribute {
return static_cast<Attribute>(self.getCurrentLoc());
},
[](PyOpBuilder &self, PyAttribute attr) {
auto loc_attr = attr.attr.dyn_cast_or_null<LocationAttr>();
if (!loc_attr) {
throw py::raiseValueError("Expected a LocationAttr");
}
self.setCurrentLoc(Location(loc_attr));
})
.def_property(
"insertion_point",
[](PyOpBuilder &self) {
return self.getBuilder(true).saveInsertionPoint();
},
[](PyOpBuilder &self, OpBuilder::InsertPoint ip) {
self.getBuilder(false).restoreInsertionPoint(ip);
})
.def(
"set_file_line_col",
[](PyOpBuilder &self, PyIdentifier filename, unsigned line,
unsigned column) {
Location loc = FileLineColLoc::get(filename.identifier, line,
column, self.getContext());
self.setCurrentLoc(loc);
},
py::arg("filename"), py::arg("line"), py::arg("column"),
"Shortcut to set a FileLineCol current location")
.def("clear_insertion_point",
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
.def(
"insert_op_before",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPoint(op);
},
"Sets the insertion point to just before the specified op.")
.def(
"insert_op_after",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPointAfter(op);
},
"Sets the insertion point to just after the specified op.")
.def(
"insert_block_start",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToStart(&block.block);
},
"Sets the insertion point to the start of the block.")
.def(
"insert_block_end",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToEnd(&block.block);
},
"Sets the insertion point to the end of the block.")
.def(
"insert_before_terminator",
[](PyOpBuilder &self, PyBlockRef block) {
auto *terminator = block.block.getTerminator();
if (!terminator) {
throw py::raiseValueError("Block has no terminator");
}
self.builder.setInsertionPoint(terminator);
},
"Sets the insertion point to just before the block terminator.");
}
} // namespace mlir

View File

@ -1,82 +0,0 @@
//===- MlirIr.cpp - MLIR IR Bindings --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Python/MlirPass.h"
#include "npcomp/Python/MlirInit.h"
#include "npcomp/Python/NpcompModule.h"
#include "mlir/Pass/PassRegistry.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// Module initialization
//===----------------------------------------------------------------------===//
void defineMlirPassModule(py::module m) {
m.doc() = "Python bindings for mlir pass infra";
PyPassManager::bind(m);
}
//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//
void PyPassManager::bind(py::module m) {
py::class_<PyPassManager>(m, "PassManager")
.def(py::init<std::shared_ptr<PyContext>, bool>(), py::arg("context"),
py::arg("verifyModules") = true)
.def(
"enableCrashReproducerGeneration",
[](PyPassManager &self, std::string outputFile,
bool genLocalReproducer) {
self.passManager.enableCrashReproducerGeneration(
outputFile, genLocalReproducer);
},
py::arg("outputFile"), py::arg("genLocalReproducer") = false)
.def("__len__",
[](PyPassManager &self) { return self.passManager.size(); })
.def("__str__",
[](PyPassManager &self) {
std::string spec;
llvm::raw_string_ostream stream(spec);
self.passManager.printAsTextualPipeline(stream);
return spec;
})
.def("run",
[](PyPassManager &self, PyModuleOp &module) {
if (module.context.get() != self.context.get()) {
throw py::raiseValueError(
"Expected a module with the same context "
"as the PassManager");
}
if (failed(self.passManager.run(module.moduleOp))) {
// TODO: Wrap propagate context diagnostics
throw py::raisePyError(PyExc_RuntimeError,
"Could not run passes");
}
})
.def("addPassPipelines", [](PyPassManager &self, py::args passPipelines) {
std::string error;
llvm::raw_string_ostream error_stream(error);
for (auto pyPassPipeline : passPipelines) {
auto passPipeline = pyPassPipeline.cast<std::string>();
if (failed(mlir::parsePassPipeline(passPipeline, self.passManager,
error_stream))) {
std::string message = "failed to parse pass pipeline '";
message.append(passPipeline);
message.append("': ");
message.append(error);
throw py::raiseValueError(message);
}
}
});
}
} // namespace mlir

View File

@ -1,172 +0,0 @@
//===- NpcompDialect.cpp - Custom dialect classes -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Python/MlirIr.h"
#include "npcomp/Python/NpcompModule.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
namespace mlir {
namespace NPCOMP {
class BasicpyDialectHelper : public PyDialectHelper {
public:
using PyDialectHelper::PyDialectHelper;
static void bind(py::module m) {
py::class_<BasicpyDialectHelper, PyDialectHelper>(m, "BasicpyDialectHelper")
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
py::keep_alive<1, 3>())
// ---------------------------------------------------------------------
// Basicpy dialect
// ---------------------------------------------------------------------
.def_property_readonly("basicpy_BoolType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::BoolType::get(
self.getContext());
})
.def_property_readonly("basicpy_BytesType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::BytesType::get(
self.getContext());
})
.def_property_readonly("basicpy_EllipsisType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::EllipsisType::get(
self.getContext());
})
.def_property_readonly("basicpy_NoneType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::NoneType::get(
self.getContext());
})
.def(
"basicpy_SlotObject_type",
[](BasicpyDialectHelper &self, std::string className,
py::args pySlotTypes) -> PyType {
SmallVector<Type, 4> slotTypes;
for (auto pySlotType : pySlotTypes) {
slotTypes.push_back(pySlotType.cast<PyType>());
}
auto classNameAttr =
StringAttr::get(className, self.getContext());
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
},
py::arg("className"))
.def_property_readonly("basicpy_StrType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::StrType::get(
self.getContext());
})
.def_property_readonly("basicpy_UnknownType",
[](BasicpyDialectHelper &self) -> PyType {
return Basicpy::UnknownType::get(
self.getContext());
})
.def("basicpy_exec_op",
[](BasicpyDialectHelper &self) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
auto op = opBuilder.create<Basicpy::ExecOp>(loc);
return py::make_tuple(PyOperationRef(op),
op.getBodyBuilder().saveInsertionPoint());
})
.def("basicpy_exec_discard_op",
[](BasicpyDialectHelper &self, std::vector<PyValue> pyOperands) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
llvm::SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end());
auto op =
opBuilder.create<Basicpy::ExecDiscardOp>(loc, operands);
return PyOperationRef(op);
})
.def("basicpy_slot_object_get_op",
[](BasicpyDialectHelper &self, PyValue slotObject,
unsigned index) -> PyOperationRef {
auto slotObjectType = slotObject.value.getType()
.dyn_cast<Basicpy::SlotObjectType>();
if (!slotObjectType) {
throw py::raiseValueError("Operand must be a SlotObject");
}
if (index >= slotObjectType.getSlotCount()) {
throw py::raiseValueError("Out of range slot index");
}
auto resultType = slotObjectType.getSlotTypes()[index];
auto indexAttr =
IntegerAttr::get(IndexType::get(self.getContext()), index);
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
auto op = opBuilder.create<Basicpy::SlotObjectGetOp>(
loc, resultType, slotObject, indexAttr);
return op.getOperation();
})
// ---------------------------------------------------------------------
// Numpy dialect
// ---------------------------------------------------------------------
.def("numpy_copy_to_tensor_op",
[](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef {
auto sourceType =
source.value.getType().dyn_cast<Numpy::NdArrayType>();
if (!sourceType) {
source.value.dump();
throw py::raiseValueError("expected ndarray type for "
"numpy_copy_to_tensor_op");
}
auto dtype = sourceType.getDtype();
auto optionalShape = sourceType.getOptionalShape();
TensorType tensorType;
if (optionalShape) {
tensorType = RankedTensorType::get(*optionalShape, dtype);
} else {
tensorType = UnrankedTensorType::get(dtype);
}
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
auto op = opBuilder.create<Numpy::CopyToTensorOp>(
loc, tensorType, source.value);
return op.getOperation();
})
.def("numpy_create_array_from_tensor_op",
[](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef {
auto sourceType = source.value.getType().dyn_cast<TensorType>();
if (!sourceType) {
throw py::raiseValueError("expected tensor type for "
"numpy_create_array_from_tensor_op");
}
auto dtype = sourceType.getElementType();
llvm::Optional<ArrayRef<int64_t>> optionalShape;
if (auto rankedTensorType =
sourceType.dyn_cast<RankedTensorType>()) {
optionalShape = rankedTensorType.getShape();
}
auto ndarrayType = Numpy::NdArrayType::get(dtype, optionalShape);
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = self.pyOpBuilder.getCurrentLoc();
auto op = opBuilder.create<Numpy::CreateArrayFromTensorOp>(
loc, ndarrayType, source.value);
return op.getOperation();
})
.def("numpy_NdArrayType",
[](BasicpyDialectHelper &self, PyType dtype) -> PyType {
return Numpy::NdArrayType::get(dtype.type);
});
}
};
} // namespace NPCOMP
} // namespace mlir
using namespace ::mlir::NPCOMP;
void mlir::npcomp::python::defineNpcompDialect(py::module m) {
BasicpyDialectHelper::bind(m);
}

View File

@ -75,9 +75,11 @@ target_link_libraries(NPCOMPNativePyExt
NPCOMP
NPCOMPInitAll
NPCOMPPythonCommon
${NPCOMP_PYEXT_LIBADD}
)
npcomp_python_target_compile_options(NPCOMPNativePyExt)
mlir_check_all_link_libraries(NPCOMPNativePyExt)
# Order dependent: Built artifacts add dependencies to the above targets.
add_subdirectory(npcomp/dialects)

View File

@ -1,4 +1,4 @@
//===- native.cpp - MLIR Python bindings ----------------------------------===//
//===- NpcompModule.cpp - MLIR Python bindings ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -9,85 +9,58 @@
#include <cstddef>
#include <unordered_map>
#include "npcomp/Python/MlirInit.h"
#include "npcomp/Python/NpcompModule.h"
#include "npcomp/Python/PybindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp-c/Registration.h"
#include "llvm/Support/CommandLine.h"
#include "npcomp-c/Types.h"
#include "npcomp/Python/PybindUtils.h"
#ifdef NPCOMP_ENABLE_REFJIT
#include "npcomp/Backend/RefJIT/PythonModule.h"
#endif
namespace mlir {
namespace npcomp {
namespace python {
namespace {
static void defineLLVMModule(pybind11::module m) {
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
m.def("add_option",
[](std::string name, llvm::Optional<std::string> value) {
auto options_map = llvm::cl::getRegisteredOptions();
auto found_it = options_map.find(name);
if (found_it == options_map.end()) {
std::string message = "Unknown LLVM option: ";
message.append(name);
throw py::raiseValueError(message.c_str());
}
std::string value_sr = value ? *value : "";
found_it->getValue()->addOccurrence(1, name, value_sr);
},
py::arg("name"), py::arg("value") = llvm::Optional<std::string>());
m.def("reset_option",
[](std::string name) {
auto options_map = llvm::cl::getRegisteredOptions();
auto found_it = options_map.find(name);
if (found_it == options_map.end()) {
std::string message = "Unknown LLVM option: ";
message.append(name);
throw py::raiseValueError(message.c_str());
}
found_it->getValue()->setDefault();
},
py::arg("name"));
MlirType shapedToNdArrayArrayType(MlirType shaped_type) {
if (!mlirTypeIsAShaped(shaped_type)) {
throw py::raiseValueError("type is not a shaped type");
}
return npcompNdArrayTypeGetFromShaped(shaped_type);
}
static void defineGlobals(py::module &m) {
m.def("register_dialects", [](MlirContext context) {
npcompRegisterAllDialects(context);
});
MlirType ndarrayToTensorType(MlirType ndarray_type) {
if (!npcompTypeIsANdArray(ndarray_type)) {
throw py::raiseValueError("type is not an ndarray type");
}
return npcompNdArrayTypeToTensor(ndarray_type);
}
MlirType slotObjectType(MlirContext context, const std::string &className,
const std::vector<MlirType> &slotTypes) {
MlirStringRef classNameSr{className.data(), className.size()};
return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(),
slotTypes.data());
}
// TODO: Move this upstream.
void emitError(MlirLocation loc, std::string message) {
::mlirEmitError(loc, message.c_str());
}
} // namespace
PYBIND11_MODULE(_npcomp, m) {
  // Guard the one-time init so that it happens once per process (vs. per
  // extension module, which in mondo builds can be initialized multiple times).
static bool llvm_init_baton = ([]() { return npcompMlirInitialize(); })();
(void)(llvm_init_baton);
m.doc() = "Npcomp native python bindings";
// TODO: Retire the llvm, mlir, passes, and dialect modules in favor of
// upstream Python bindings.
auto llvm_m = m.def_submodule("llvm", "LLVM interop");
defineLLVMModule(llvm_m);
// "mlir" module.
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
auto mlir_ir_m = mlir_m.def_submodule("ir");
defineMlirIrModule(mlir_ir_m);
// Note: not "pass" because it is a reserved word
auto mlir_pass_m = mlir_m.def_submodule("passes");
defineMlirPassModule(mlir_pass_m);
auto mlir_dialect_m = mlir_m.def_submodule("dialect");
defineMlirCoreDialects(mlir_dialect_m);
// Outer "_npcomp" module
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
defineNpcompDialect(npcomp_dialect);
// Globals.
defineGlobals(m);
m.def("register_all_dialects", ::npcompRegisterAllDialects);
m.def("_register_all_passes", ::npcompRegisterAllPasses);
m.def("_initialize_llvm_codegen", ::npcompInitializeLLVMCodegen);
m.def("shaped_to_ndarray_type", shapedToNdArrayArrayType);
m.def("ndarray_to_tensor_type", ndarrayToTensorType);
m.def("slot_object_type", slotObjectType);
m.def("emit_error", emitError);
// Optional backend modules.
auto backend_m = m.def_submodule("backend", "Backend support");
@ -99,7 +72,3 @@ PYBIND11_MODULE(_npcomp, m) {
::npcomp::python::defineBackendRefJitModule(refjit_m);
#endif
}
} // namespace python
} // namespace npcomp
} // namespace mlir
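
For orientation, a minimal sketch of how these newly bound helpers are meant to be reached from Python, assuming the upstream `mlir` package and this `npcomp` package are built and importable (the loader in python/npcomp/__init__.py below is what exposes the extension as `_cext`):

# Sketch only; everything other than the bound helpers comes from the upstream API.
from mlir import ir
from npcomp import _cext

ctx = ir.Context()
_cext.register_all_dialects(ctx)  # npcompRegisterAllDialects

with ctx:
  loc = ir.Location.unknown(context=ctx)
  f32 = ir.F32Type.get()
  tensor_ty = ir.UnrankedTensorType.get(f32, loc=loc)
  ndarray_ty = _cext.shaped_to_ndarray_type(tensor_ty)     # ShapedType -> NdArrayType
  roundtrip_ty = _cext.ndarray_to_tensor_type(ndarray_ty)  # NdArrayType -> TensorType
  slot_ty = _cext.slot_object_type(ctx, "example", [ndarray_ty, ndarray_ty])
  _cext.emit_error(loc, "example diagnostic")              # wraps mlirEmitError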

View File

@ -1,3 +1,25 @@
def _load_extension():
  # TODO: Remove the RTLD_GLOBAL hack once local, cross-module imports
# resolve symbols properly. Something is keeping the dynamic loader on
# Linux from treating the following vague symbols as the same across
# _mlir and _npcomp:
# mlir::detail::TypeIDExported::get<mlir::FuncOp>()::instance
import sys
import ctypes
flags = sys.getdlopenflags()
sys.setdlopenflags(flags | ctypes.RTLD_GLOBAL)
import _npcomp
sys.setdlopenflags(flags)
import mlir
mlir._cext.globals.append_dialect_search_prefix("npcomp.dialects")
return _npcomp
_cext = _load_extension()
_cext._register_all_passes()
_cext._initialize_llvm_codegen()
# Top-level symbols.
from .exporter import *
from .types import *
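
The net effect of the loader above, assuming both the upstream `mlir` package and this package are on PYTHONPATH: a plain `import npcomp` pulls in the extension under controlled dlopen flags, registers the npcomp passes, initializes LLVM codegen, and makes the ODS-generated dialect modules resolvable. A small sketch:

# Sketch: everything below relies only on the side effects of `import npcomp`.
import npcomp
from npcomp import _cext                    # the _npcomp extension loaded above
from npcomp.dialects import basicpy, numpy  # resolved via the dialect search prefix

from mlir import ir
ctx = ir.Context()
_cext.register_all_dialects(ctx)            # per-context dialect registration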

View File

@ -7,8 +7,6 @@ import subprocess
from mlir.ir import *
from mlir.passmanager import *
from _npcomp import register_dialects
from _npcomp import mlir as legacy_mlir
from npcomp.compiler.generic.backend import iree as iree_backend
from npcomp.compiler.utils import logging
@ -56,7 +54,7 @@ class CompilerBackend:
self._ireert = _get_iree()
self._debug = logging.debug_enabled()
def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp):
def compile(self, imported_module: Module):
"""Compiles an imported module.
Args:
@ -67,10 +65,12 @@ class CompilerBackend:
for IREE, it is a serialized VM flatbuffer) but the contract is that
it is operated on by methods on this class.
"""
# TODO: Once transitioned to new Python API, don't reparse the module.
with Context() as context:
register_dialects(context)
imported_module = Module.parse(legacy_imported_ir_module.to_asm())
with imported_module.context:
# Frontend.
if self._debug:
logging.debug("Input IR:\n{}", imported_module)
assert (
imported_module.operation.verify()), "Imported module does not verify"
# Frontend.
pm = PassManager.parse(",".join(FRONTEND_PASSES))
pm.run(imported_module)

View File

@ -6,8 +6,6 @@ import os
from mlir.ir import *
from mlir.passmanager import *
from _npcomp import register_dialects
from _npcomp import mlir as legacy_mlir
from npcomp.compiler.generic.backend import refjit as refjit_backend
from npcomp.compiler.utils import logging
@ -37,7 +35,7 @@ class CompilerBackend:
self._refjit = refjit_backend.get_refjit()
self._debug = logging.debug_enabled()
def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp):
def compile(self, imported_module: Module):
"""Compiles an imported module.
Args:
@ -49,15 +47,16 @@ class CompilerBackend:
for IREE, it is a serialized VM flatbuffer) but the contract is that
it is operated on by methods on this class.
"""
# TODO: Once transitioned to new Python API, don't reparse the module.
with Context() as context:
register_dialects(context)
imported_module = Module.parse(legacy_imported_ir_module.to_asm())
with imported_module.context as context:
# Frontend.
if self._debug:
logging.debug("Input IR:\n{}", imported_module)
assert (
imported_module.operation.verify()), "Imported module does not verify"
pm = PassManager.parse(",".join(FRONTEND_PASSES))
pm.run(imported_module)
if self._debug:
logging.debug("Frontend IR:{}", imported_module)
logging.debug("Frontend IR:\n{}", imported_module)
# Backend.
# Note that this is a separate pass manager purely to aid in debugging.
@ -65,7 +64,7 @@ class CompilerBackend:
self._refjit.build_backend_compilation_pipeline(pm)
pm.run(imported_module)
if self._debug:
logging.debug("Backend IR:{}", imported_module)
logging.debug("Backend IR:\n{}", imported_module)
jit_module = self._refjit.JITModule.from_compiled_module(
imported_module, refjit_backend.get_runtime_libs())

View File

@ -8,7 +8,10 @@ from typing import Callable, Iterator, Sequence, Tuple
import functools
import numpy as np
from _npcomp.mlir import ir
from mlir import ir as _ir
from npcomp import _cext
from npcomp.dialects import numpy as numpy_ops
from ....utils import logging
from ...interfaces import *
@ -71,7 +74,7 @@ class BuiltinUfuncLiveValueRef(LiveValueRef):
self._qualified_name = qualified_name
self._ufunc = ufunc
def resolve_call(self, env: Environment, args: Sequence[ir.Value],
def resolve_call(self, env: Environment, args: Sequence[_ir.Value],
keywords: Sequence[str]) -> PartialEvalResult:
if keywords:
return PartialEvalResult.error_message(
@ -80,14 +83,27 @@ class BuiltinUfuncLiveValueRef(LiveValueRef):
return PartialEvalResult.error_message(
"ufunc {} expected {} inputs but got {}".format(
self._qualified_name, self._ufunc.nin, len(args)))
ir_h = env.ir_h
ic = env.ic
# Because a ufunc call is defined in terms of tensors and, at this stage,
# all "public" values are ndarray, do appropriate conversions.
tensor_args = [ir_h.numpy_copy_to_tensor_op(arg).result for arg in args]
result_type = ir_h.numpy_unknown_tensor_type
tensor_result = ir_h.numpy_builtin_ufunc_call_op(
*tensor_args,
qualified_name=self._qualified_name,
result_type=result_type).result
array_result = ir_h.numpy_create_array_from_tensor_op(tensor_result).result
def copy_to_tensor(value):
tensor_type = _cext.ndarray_to_tensor_type(value.type)
return numpy_ops.CopyToTensorOp(tensor_type, value, loc=ic.loc,
ip=ic.ip).result
tensor_args = [copy_to_tensor(arg) for arg in args]
result_type = ic.unknown_tensor_type
tensor_result = numpy_ops.BuiltinUfuncCallOp(result_type,
_ir.StringAttr.get(
self._qualified_name,
context=ic.context),
tensor_args,
loc=ic.loc,
ip=ic.ip).result
array_result = numpy_ops.CreateArrayFromTensorOp(
_cext.shaped_to_ndarray_type(tensor_result.type),
tensor_result,
loc=ic.loc,
ip=ic.ip).result
return PartialEvalResult.yields_ir_value(array_result)

View File

@ -6,7 +6,11 @@
import numpy as np
from typing import Union
from _npcomp.mlir import ir
from mlir import ir as _ir
from mlir.dialects import std as std_ops
from npcomp import _cext
from npcomp.dialects import numpy as numpy_ops
from ....utils import logging
from ...interfaces import *
@ -23,15 +27,22 @@ class NdArrayValueCoder(ValueCoder):
__slots__ = []
def code_py_value_as_const(self, env: Environment,
py_value) -> Union[_NotImplementedType, ir.Value]:
py_value) -> Union[_NotImplementedType, _ir.Value]:
# TODO: Query for ndarray compat (for duck typed and such)
# TODO: Have a higher level name resolution signal which indicates const
ir_h = env.ir_h
ic = env.ic
if isinstance(py_value, np.ndarray):
dense_attr = ir_h.context.dense_elements_attr(py_value)
dense_attr = _ir.DenseElementsAttr.get(py_value, context=ic.context)
tensor_type = dense_attr.type
tensor_value = ir_h.constant_op(tensor_type, dense_attr).result
return ir_h.numpy_create_array_from_tensor_op(tensor_value).result
tensor_value = std_ops.ConstantOp(tensor_type,
dense_attr,
loc=ic.loc,
ip=ic.ip).result
ndarray_type = _cext.shaped_to_ndarray_type(tensor_type)
return numpy_ops.CreateArrayFromTensorOp(ndarray_type,
tensor_value,
loc=ic.loc,
ip=ic.ip).result
return NotImplemented

View File

@ -10,16 +10,14 @@ import inspect
import textwrap
from typing import Optional
from _npcomp.mlir import ir
from _npcomp.mlir.dialect import ScfDialectHelper
from npcomp.dialect import Numpy
from mlir import ir as _ir
from ..utils import logging
from .importer import *
from .interfaces import *
from .name_resolver_base import *
from .value_coder_base import *
from .target import *
from ..utils.mlir_utils import *
__all__ = [
"ImportFrontend",
@ -29,34 +27,27 @@ __all__ = [
class ImportFrontend:
"""Frontend for importing various entities into a Module."""
__slots__ = [
"_ir_context",
"_ir_module",
"_ir_h",
"_config",
"_ic",
]
def __init__(self,
*,
config: Configuration,
ir_context: ir.MLIRContext = None):
ir_context: Optional[_ir.Context] = None):
super().__init__()
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
self._ir_module = self._ir_context.new_module()
self._ir_h = AllDialectHelper(self._ir_context,
ir.OpBuilder(self._ir_context))
ic = self._ic = ImportContext(ir_context)
self._ic.module = _ir.Module.create(loc=ic.loc)
self._config = config
@property
def ir_context(self):
return self._ir_context
def ir_context(self) -> _ir.Context:
return self._ic.context
@property
def ir_module(self):
return self._ir_module
@property
def ir_h(self):
return self._ir_h
def ir_module(self) -> _ir.Module:
return self._ic.module
def import_global_function(self, f):
"""Imports a global function.
@ -70,10 +61,8 @@ class ImportFrontend:
Args:
f: The python callable.
"""
h = self.ir_h
ir_c = self.ir_context
ir_m = self.ir_module
target = self._config.target_factory(h)
ic = self._ic
target = self._config.target_factory(ic)
filename = inspect.getsourcefile(f)
source_lines, start_lineno = inspect.getsourcelines(f)
source = "".join(source_lines)
@ -81,7 +70,6 @@ class ImportFrontend:
ast_root = ast.parse(source, filename=filename)
ast.increment_lineno(ast_root, start_lineno - 1)
ast_fd = ast_root.body[0]
filename_ident = ir_c.identifier(filename)
# Define the function.
# TODO: Much more needs to be done here (arg/result mapping, etc)
@ -98,27 +86,21 @@ class ImportFrontend:
]
f_return_type = self._resolve_signature_annotation(
target, f_signature.return_annotation)
ir_f_type = h.function_type(f_input_types, [f_return_type])
ir_f_type = _ir.FunctionType.get(f_input_types, [f_return_type],
context=ic.context)
h.builder.set_file_line_col(filename_ident, ast_fd.lineno,
ast_fd.col_offset)
h.builder.insert_before_terminator(ir_m.first_block)
# TODO: Do not hardcode this IREE attribute.
attrs = ir_c.dictionary_attr({"iree.module.export": ir_c.unit_attr})
ir_f = h.func_op(ast_fd.name,
ir_f_type,
create_entry_block=True,
attrs=attrs)
ic.set_file_line_col(filename, ast_fd.lineno, ast_fd.col_offset)
ic.insert_before_terminator(ic.module.body)
ir_f, entry_block = ic.FuncOp(ast_fd.name,
ir_f_type,
create_entry_block=True)
ic.insert_end_of_block(entry_block)
env = self._create_const_global_env(f,
parameter_bindings=zip(
f_params.keys(),
ir_f.first_block.args),
entry_block.arguments),
target=target)
fctx = FunctionContext(ir_c=ir_c,
ir_f=ir_f,
ir_h=h,
filename_ident=filename_ident,
environment=env)
fctx = FunctionContext(ic=ic, ir_f=ir_f, filename=filename, environment=env)
fdimport = FunctionDefImporter(fctx, ast_fd)
fdimport.import_body()
@ -131,7 +113,7 @@ class ImportFrontend:
for advanced cases, including mutable global state, closures, etc.
Globals from the module are considered immutable.
"""
ir_h = self._ir_h
ic = self._ic
try:
code = f.__code__
globals_dict = f.__globals__
@ -149,7 +131,7 @@ class ImportFrontend:
ConstModuleNameResolver(globals_dict, as_dict=True),
ConstModuleNameResolver(builtins_module),
)
env = Environment(config=self._config, ir_h=ir_h, name_resolvers=resolvers)
env = Environment(config=self._config, ic=ic, name_resolvers=resolvers)
# Bind parameters.
for name, value in parameter_bindings:
@ -158,9 +140,9 @@ class ImportFrontend:
return env
def _resolve_signature_annotation(self, target: Target, annot):
ir_h = self._ir_h
ic = self._ic
if annot is inspect.Signature.empty:
return ir_h.basicpy_UnknownType
return ic.unknown_type
# TODO: Do something real here once we need more than the primitive types.
if annot is int:
@ -168,23 +150,8 @@ class ImportFrontend:
elif annot is float:
return target.impl_float_type
elif annot is bool:
return ir_h.basicpy_BoolType
return ic.bool_type
elif annot is str:
return ir_h.basicpy_StrType
return ic.str_type
else:
return ir_h.basicpy_UnknownType
################################################################################
# Support
################################################################################
# TODO: Remove this hack in favor of a helper function that combines
# multiple dialect helpers so that we don't need to deal with the sharp
# edge of initializing multiple native base classes.
class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
def __init__(self, *args, **kwargs):
Numpy.DialectHelper.__init__(self, *args, **kwargs)
ScfDialectHelper.__init__(self, *args, **kwargs)
return ic.unknown_type

View File

@ -8,7 +8,11 @@ import ast
import sys
import traceback
from _npcomp.mlir import ir
from mlir import ir as _ir
from mlir.dialects import std as std_ops
from npcomp import _cext
from npcomp.dialects import basicpy as basicpy_ops
from ..utils import logging
from .interfaces import *
@ -23,24 +27,23 @@ __all__ = [
class FunctionContext:
"""Accounting information for importing a function."""
__slots__ = [
"ir_c",
"ic",
"ir_f",
"ir_h",
"filename_ident",
"filename",
"environment",
]
def __init__(self, ir_c, ir_f, ir_h, filename_ident, environment):
self.ir_c = ir_c
def __init__(self, *, ic: ImportContext, ir_f: _ir.Operation, filename: str,
environment: Environment):
self.ic = ic
self.ir_f = ir_f
self.ir_h = ir_h
self.filename_ident = filename_ident
self.filename = filename
self.environment = environment
def abort(self, message):
"""Emits an error diagnostic and raises an exception to abort."""
loc = self.current_loc
ir.emit_error(loc, message)
_cext.emit_error(loc, message)
raise EmittedError(loc, message)
def check_partial_evaluated(self, result: PartialEvalResult):
@ -53,18 +56,20 @@ class FunctionContext:
else:
message = ("Error while evaluating value from environment:\n" +
"".join(traceback.format_exception(exc_type, exc_value, tb)))
ir.emit_error(loc, message)
# TODO: Add this to the python API.
_cext.emit_error(loc, message)
raise EmittedError(loc, message)
if result.type == PartialEvalType.NOT_EVALUATED:
self.abort("Unable to evaluate expression")
@property
def current_loc(self):
return self.ir_h.builder.current_loc
return self.ic.loc
def update_loc(self, ast_node):
self.ir_h.builder.set_file_line_col(self.filename_ident, ast_node.lineno,
ast_node.col_offset)
self.ic.set_file_line_col(self.filename, ast_node.lineno,
ast_node.col_offset)
def lookup_name(self, name) -> NameReference:
"""Lookup a name in the environment, requiring it to have evaluated."""
@ -74,7 +79,7 @@ class FunctionContext:
logging.debug("Map name({}) -> {}", name, ref)
return ref
def emit_const_value(self, py_value) -> ir.Value:
def emit_const_value(self, py_value) -> _ir.Value:
"""Codes a value as a constant, returning an ir Value."""
env = self.environment
result = env.code_py_value_as_const(py_value)
@ -83,7 +88,7 @@ class FunctionContext:
return result
def emit_partial_eval_result(self,
partial_result: PartialEvalResult) -> ir.Value:
partial_result: PartialEvalResult) -> _ir.Value:
"""Emits a partial eval result either as a direct IR value or a constant."""
self.check_partial_evaluated(partial_result)
if partial_result.type == PartialEvalType.YIELDS_IR_VALUE:
@ -134,17 +139,20 @@ class FunctionDefImporter(BaseNodeVisitor):
self._last_was_return = False
def import_body(self):
ir_h = self.fctx.ir_h
ic = self.fctx.ic
for ast_stmt in self.ast_fd.body:
self._last_was_return = False
logging.debug("STMT: {}", ast.dump(ast_stmt, include_attributes=True))
self.visit(ast_stmt)
if not self._last_was_return:
# Add a default terminator.
none_value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
none_value).result
ir_h.return_op([none_cast])
none_value = basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc,
ip=ic.ip).result
none_cast = basicpy_ops.UnknownCastOp(ic.unknown_type,
none_value,
loc=ic.loc,
ip=ic.ip).result
std_ops.ReturnOp([none_cast], loc=ic.loc, ip=ic.ip)
def visit_Assign(self, ast_node):
expr = ExpressionImporter(self.fctx)
@ -164,27 +172,27 @@ class FunctionDefImporter(BaseNodeVisitor):
"Cannot assign to '{}': Store not supported".format(name_ref))
def visit_Expr(self, ast_node):
ir_h = self.fctx.ir_h
_, ip = ir_h.basicpy_exec_op()
ic = self.fctx.ic
exec_ip = ic.basicpy_ExecOp()
# Evaluate the expression in the exec body.
orig_ip = ir_h.builder.insertion_point
ir_h.builder.insertion_point = ip
ic.push_ip(exec_ip)
expr = ExpressionImporter(self.fctx)
expr.visit(ast_node.value)
ir_h.basicpy_exec_discard_op([expr.value])
ir_h.builder.insertion_point = orig_ip
basicpy_ops.ExecDiscardOp([expr.value], loc=ic.loc, ip=ic.ip)
ic.pop_ip()
def visit_Pass(self, ast_node):
pass
def visit_Return(self, ast_node):
ir_h = self.fctx.ir_h
expr = ExpressionImporter(self.fctx)
expr.visit(ast_node.value)
casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
expr.value).result
ir_h.return_op([casted])
self._last_was_return = True
ic = self.fctx.ic
with ic.loc, ic.ip:
expr = ExpressionImporter(self.fctx)
expr.visit(ast_node.value)
casted = basicpy_ops.UnknownCastOp(ic.unknown_type, expr.value).result
std_ops.ReturnOp([casted])
self._last_was_return = True
class ExpressionImporter(BaseNodeVisitor):
@ -233,15 +241,20 @@ class ExpressionImporter(BaseNodeVisitor):
ast.dump(ast_node)))
def visit_BinOp(self, ast_node):
ir_h = self.fctx.ir_h
ic = self.fctx.ic
left = self.sub_evaluate(ast_node.left)
right = self.sub_evaluate(ast_node.right)
self.value = ir_h.basicpy_binary_expr_op(
ir_h.basicpy_UnknownType, left, right,
ast_node.op.__class__.__name__).result
self.value = basicpy_ops.BinaryExprOp(ic.unknown_type,
left,
right,
_ir.StringAttr.get(
ast_node.op.__class__.__name__,
context=ic.context),
ip=ic.ip,
loc=ic.loc).result
def visit_BoolOp(self, ast_node):
ir_h = self.fctx.ir_h
ic = self.fctx.ic
if isinstance(ast_node.op, ast.And):
return_first_true = False
elif isinstance(ast_node.op, ast.Or):
@ -255,25 +268,30 @@ class ExpressionImporter(BaseNodeVisitor):
next_value = self.sub_evaluate(next_node)
if not next_nodes:
return next_value
condition_value = ir_h.basicpy_as_i1_op(next_value).result
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
condition_value, True)
orig_ip = ir_h.builder.insertion_point
condition_value = basicpy_ops.AsI1Op(ic.i1_type, next_value,
ip=ic.ip).result
if_op, then_ip, else_ip = ic.scf_IfOp([ic.unknown_type], condition_value,
True)
# Short-circuit return case.
ir_h.builder.insertion_point = then_ip if return_first_true else else_ip
next_value_casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
next_value).result
ir_h.scf_yield_op([next_value_casted])
ic.push_ip(then_ip if return_first_true else else_ip)
next_value_casted = basicpy_ops.UnknownCastOp(ic.unknown_type,
next_value,
ip=ic.ip).result
ic.scf_YieldOp([next_value_casted])
ic.pop_ip()
# Nested evaluate next case.
ir_h.builder.insertion_point = else_ip if return_first_true else then_ip
ic.push_ip(else_ip if return_first_true else then_ip)
nested_value = emit_next(next_nodes)
nested_value_casted = next_value_casted = ir_h.basicpy_unknown_cast_op(
ir_h.basicpy_UnknownType, nested_value).result
ir_h.scf_yield_op([nested_value_casted])
ir_h.builder.insertion_point = orig_ip
      nested_value_casted = basicpy_ops.UnknownCastOp(
          ic.unknown_type, nested_value, ip=ic.ip).result
ic.scf_YieldOp([nested_value_casted])
ic.pop_ip()
return if_op.result
self.value = emit_next(ast_node.values)
with ic.loc:
self.value = emit_next(ast_node.values)
def visit_Call(self, ast_node):
# Evaluate positional args.
@ -311,15 +329,23 @@ class ExpressionImporter(BaseNodeVisitor):
def visit_Compare(self, ast_node):
# Short-circuit comparison (degenerates to binary comparison when just
# two operands).
ir_h = self.fctx.ir_h
false_value = ir_h.basicpy_bool_constant_op(False).result
ic = self.fctx.ic
false_value = basicpy_ops.BoolConstantOp(ic.bool_type,
ic.i1_false,
ip=ic.ip,
loc=ic.loc).result
def emit_next(left_value, comparisons):
operation, right_node = comparisons[0]
comparisons = comparisons[1:]
right_value = self.sub_evaluate(right_node)
compare_result = ir_h.basicpy_binary_compare_op(
left_value, right_value, operation.__class__.__name__).result
compare_result = basicpy_ops.BinaryCompareOp(
ic.bool_type,
left_value,
right_value,
_ir.StringAttr.get(operation.__class__.__name__),
ip=ic.ip,
loc=ic.loc).result
# Terminate by yielding the final compare result.
if not comparisons:
return compare_result
@ -327,47 +353,56 @@ class ExpressionImporter(BaseNodeVisitor):
# Emit 'if' op and recurse. The if op takes an i1 (core dialect
# requirement) and returns a basicpy.BoolType. Since this is an 'and',
# all else clauses yield a false value.
compare_result_i1 = ir_h.basicpy_bool_cast_op(ir_h.i1_type,
compare_result).result
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_BoolType],
compare_result_i1, True)
orig_ip = ir_h.builder.insertion_point
compare_result_i1 = basicpy_ops.BoolCastOp(ic.i1_type,
compare_result,
ip=ic.ip,
loc=ic.loc).result
if_op, then_ip, else_ip = ic.scf_IfOp([ic.bool_type], compare_result_i1,
True)
# Build the else clause.
ir_h.builder.insertion_point = else_ip
ir_h.scf_yield_op([false_value])
ic.push_ip(else_ip)
ic.scf_YieldOp([false_value])
ic.pop_ip()
# Build the then clause.
ir_h.builder.insertion_point = then_ip
ic.push_ip(then_ip)
nested_result = emit_next(right_value, comparisons)
ir_h.scf_yield_op([nested_result])
ir_h.builder.insertion_point = orig_ip
ic.scf_YieldOp([nested_result])
ic.pop_ip()
return if_op.result
self.value = emit_next(self.sub_evaluate(ast_node.left),
list(zip(ast_node.ops, ast_node.comparators)))
def visit_IfExp(self, ast_node):
ir_h = self.fctx.ir_h
test_result = ir_h.basicpy_as_i1_op(self.sub_evaluate(ast_node.test)).result
if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
test_result, True)
orig_ip = ir_h.builder.insertion_point
ic = self.fctx.ic
test_result = basicpy_ops.AsI1Op(ic.i1_type,
self.sub_evaluate(ast_node.test),
ip=ic.ip,
loc=ic.loc).result
if_op, then_ip, else_ip = ic.scf_IfOp([ic.unknown_type], test_result, True)
# Build the then clause
ir_h.builder.insertion_point = then_ip
ic.push_ip(then_ip)
then_result = self.sub_evaluate(ast_node.body)
ir_h.scf_yield_op([
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
then_result).result
ic.scf_YieldOp([
basicpy_ops.UnknownCastOp(ic.unknown_type,
then_result,
ip=ic.ip,
loc=ic.loc).result
])
# Build the then clause.
ir_h.builder.insertion_point = else_ip
orelse_result = self.sub_evaluate(ast_node.orelse)
ir_h.scf_yield_op([
ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
orelse_result).result
])
ir_h.builder.insertion_point = orig_ip
ic.pop_ip()
    # Build the else clause.
ic.push_ip(else_ip)
orelse_result = self.sub_evaluate(ast_node.orelse)
ic.scf_YieldOp([
basicpy_ops.UnknownCastOp(ic.unknown_type,
orelse_result,
ip=ic.ip,
loc=ic.loc).result
])
ic.pop_ip()
self.value = if_op.result
def visit_Name(self, ast_node):
@ -380,18 +415,20 @@ class ExpressionImporter(BaseNodeVisitor):
self.value = self.fctx.emit_partial_eval_result(pe_result)
def visit_UnaryOp(self, ast_node):
ir_h = self.fctx.ir_h
op = ast_node.op
operand_value = self.sub_evaluate(ast_node.operand)
if isinstance(op, ast.Not):
# Special handling for logical-not.
condition_value = ir_h.basicpy_as_i1_op(operand_value).result
true_value = ir_h.basicpy_bool_constant_op(True).result
false_value = ir_h.basicpy_bool_constant_op(False).result
self.value = ir_h.select_op(condition_value, false_value,
true_value).result
else:
self.fctx.abort("Unknown unary op %r", (ast.dump(op)))
ic = self.fctx.ic
with ic.ip, ic.loc:
op = ast_node.op
operand_value = self.sub_evaluate(ast_node.operand)
if isinstance(op, ast.Not):
# Special handling for logical-not.
condition_value = basicpy_ops.AsI1Op(ic.i1_type, operand_value).result
true_value = basicpy_ops.BoolConstantOp(ic.bool_type, ic.i1_true).result
false_value = basicpy_ops.BoolConstantOp(ic.bool_type,
ic.i1_false).result
self.value = std_ops.SelectOp(ic.bool_type, condition_value,
false_value, true_value).result
else:
self.fctx.abort("Unknown unary op %r", (ast.dump(op)))
if sys.version_info < (3, 8, 0):
# <3.8 breaks these out into separate AST classes.

View File

@ -6,15 +6,18 @@
from collections import namedtuple
from enum import Enum
import sys
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union
from mlir import ir as _ir
from _npcomp.mlir import ir
from .target import *
from ..utils.mlir_utils import *
__all__ = [
"Configuration",
"EmittedError",
"Environment",
"ImportContext",
"NameReference",
"NameResolver",
"PartialEvalHook",
@ -82,21 +85,18 @@ class NameReference:
super().__init__()
self.name = name
def load(self, env: "Environment",
ir_h: ir.DialectHelper) -> "PartialEvalResult":
def load(self, env: "Environment") -> "PartialEvalResult":
"""Loads the IR Value associated with the name.
The load may either be direct, returning an existing value or
side-effecting, causing a read from an external context.
Args:
ir_h: The dialect helper used to emit code.
Returns:
A partial evaluation result.
"""
return PartialEvalResult.not_evaluated()
def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper):
def store(self, env: "Environment", value: _ir.Value):
"""Stores a new value into the name.
A subsequent call to 'load' should yield the same value, subject to
@ -104,7 +104,6 @@ class NameReference:
Args:
value: The new value to store into the name.
ir_h: The dialect helper used to emit code.
Raises:
NotImplementedError if store is not supported for this name.
"""
@ -143,7 +142,7 @@ class ValueCoder:
__slots__ = []
def code_py_value_as_const(self, env: "Environment",
py_value) -> Union[_NotImplementedType, ir.Value]:
py_value) -> Union[_NotImplementedType, _ir.Value]:
return NotImplemented
@ -158,7 +157,7 @@ class ValueCoderChain(ValueCoder):
return "ValueCoderChain({})".format(self._sub_coders)
def code_py_value_as_const(self, env: "Environment",
py_value) -> Union[_NotImplementedType, ir.Value]:
py_value) -> Union[_NotImplementedType, _ir.Value]:
for sc in self._sub_coders:
result = sc.code_py_value_as_const(env, py_value)
if result is not NotImplemented:
@ -207,8 +206,8 @@ class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
return PartialEvalResult(PartialEvalType.YIELDS_LIVE_VALUE, live_value)
@staticmethod
def yields_ir_value(ir_value: ir.Value) -> "PartialEvalResult":
assert isinstance(ir_value, ir.Value)
def yields_ir_value(ir_value: _ir.Value) -> "PartialEvalResult":
assert isinstance(ir_value, _ir.Value)
return PartialEvalResult(PartialEvalType.YIELDS_IR_VALUE, ir_value)
@staticmethod
@ -247,7 +246,7 @@ class LiveValueRef:
"""Gets a named attribute from the live value."""
return PartialEvalResult.not_evaluated()
def resolve_call(self, env: "Environment", args: Sequence[ir.Value],
def resolve_call(self, env: "Environment", args: Sequence[_ir.Value],
keywords: Sequence[str]) -> PartialEvalResult:
"""Resolves a function call given 'args' and 'keywords'."""
return PartialEvalResult.not_evaluated()
@ -302,7 +301,6 @@ class Environment:
"""Instantiated configuration for emitting code in a specific context.
This brings together:
- The code generation context (ir_h)
- An instantiated target
- Delegating interfaces for other configuration objects.
@ -313,7 +311,7 @@ class Environment:
"""
__slots__ = [
"config",
"ir_h",
"ic",
"_name_resolvers",
"target",
]
@ -321,12 +319,12 @@ class Environment:
def __init__(self,
*,
config: Configuration,
ir_h: ir.DialectHelper,
ic: ImportContext,
name_resolvers: Sequence[NameResolver] = ()):
super().__init__()
self.config = config
self.ir_h = ir_h
self.target = config.target_factory(self.ir_h)
self.ic = ic
self.target = config.target_factory(ic)
self._name_resolvers = (tuple(name_resolvers) +
self.config.base_name_resolvers)
@ -341,5 +339,5 @@ class Environment:
return self.config.partial_eval_hook.partial_evaluate(py_value)
def code_py_value_as_const(self,
py_value) -> Union[_NotImplementedType, ir.Value]:
py_value) -> Union[_NotImplementedType, _ir.Value]:
return self.config.value_coder.code_py_value_as_const(self, py_value)

View File

@ -5,8 +5,7 @@
from typing import Optional
from _npcomp.mlir import ir
from mlir import ir as _ir
from .interfaces import *
__all__ = [
@ -36,7 +35,7 @@ class LocalNameReference(NameReference):
"Attempt to access local '{}' before assignment".format(self.name))
return PartialEvalResult.yields_ir_value(self._current_value)
def store(self, env: Environment, value: ir.Value):
def store(self, env: Environment, value: _ir.Value):
self._current_value = value
def __repr__(self):

View File

@ -52,9 +52,9 @@ class TemplateCallLiveValueRef(LiveValueRef):
kw_arg_names.append(kw_name)
linear_args.append(kw_value)
ir_h = env.ir_h
result_ir_value = ir_h.basicpy_func_template_call_op(
result_type=ir_h.basicpy_UnknownType,
ic = env.ic
result_ir_value = ic.basicpy_FuncTemplateCallOp(
result_type=ic.unknown_type,
callee_symbol=self.callee_name,
args=linear_args,
arg_names=kw_arg_names).result

View File

@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import *
from _npcomp.mlir import ir
from mlir import ir as _ir
from ..utils.mlir_utils import *
__all__ = [
"GenericTarget32",
@ -19,32 +21,23 @@ class Target:
target.
"""
__slots__ = [
"_mlir_helper",
"ic",
]
def __init__(self, mlir_helper: ir.DialectHelper):
super().__init__()
self._mlir_helper = mlir_helper
@property
def mlir_helper(self):
return self._mlir_helper
@property
def mlir_context(self):
return self._mlir_helper.context
def __init__(self, ic):
self.ic = ic
@property
def target_name(self) -> str:
    raise NotImplementedError()
@property
def impl_int_type(self) -> ir.Type:
def impl_int_type(self) -> _ir.Type:
"""Gets the default int type for the backend for the Python 'int' type."""
raise NotImplementedError()
@property
def impl_float_type(self) -> ir.Type:
def impl_float_type(self) -> _ir.Type:
"""Gets the implementation's type for the python 'float' type."""
raise NotImplementedError()
@ -57,14 +50,14 @@ class GenericTarget64(Target):
return "generic64"
@property
def impl_int_type(self) -> ir.Type:
def impl_int_type(self) -> _ir.Type:
"""Gets the default int type for the backend for the Python 'int' type."""
return self.mlir_helper.i64_type
return _ir.IntegerType.get_signless(64, context=self.ic.context)
@property
def impl_float_type(self) -> ir.Type:
def impl_float_type(self) -> _ir.Type:
"""Gets the implementation's type for the python 'float' type."""
return self.mlir_helper.f64_type
return _ir.F64Type.get(context=self.ic.context)
class GenericTarget32(Target):
@ -75,15 +68,15 @@ class GenericTarget32(Target):
return "generic32"
@property
def impl_int_type(self) -> ir.Type:
def impl_int_type(self) -> _ir.Type:
"""Gets the default int type for the backend for the Python 'int' type."""
return self.mlir_helper.i32_type
return _ir.IntegerType.get_signless(32, context=self.ic.context)
@property
def impl_float_type(self) -> ir.Type:
def impl_float_type(self) -> _ir.Type:
"""Gets the implementation's type for the python 'float' type."""
return self.mlir_helper.f32_type
return _ir.F32Type.get(context=self.ic.context)
# Factory for producing a target (matches the Target constructor).
TargetFactory = Callable[[ir.DialectHelper], Target]
TargetFactory = Callable[[ImportContext], Target]
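
A small sketch of the new wiring: a target is constructed from an ImportContext and mints its implementation types from that context. The absolute module paths below are assumptions, since this patch only shows relative imports.

# Sketch only; both import paths are assumptions.
from npcomp.compiler.utils.mlir_utils import ImportContext  # assumed path
from npcomp.compiler.target import GenericTarget64          # assumed path

ic = ImportContext(None)        # creates and owns a fresh mlir.ir.Context
target = GenericTarget64(ic)
print(target.impl_int_type)     # i64
print(target.impl_float_type)   # f64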

View File

@ -25,7 +25,7 @@ def create_import_dump_decorator(*,
fe = ImportFrontend(config=config)
fe.import_global_function(f)
print("// -----")
print(fe.ir_module.to_asm())
print(fe.ir_module.operation.get_asm())
return f
def decorator(*args, expect_error=None):

View File

@ -5,7 +5,9 @@
from typing import Union
from _npcomp.mlir import ir
from mlir import ir as _ir
from mlir.dialects import std as std_ops
from npcomp.dialects import basicpy as basicpy_ops
from .interfaces import *
@ -21,27 +23,29 @@ class BuiltinsValueCoder(ValueCoder):
__slots__ = []
def code_py_value_as_const(self, env: Environment,
py_value) -> Union[_NotImplementedType, ir.Value]:
ir_h = env.ir_h
ir_c = ir_h.context
if py_value is True:
return ir_h.basicpy_bool_constant_op(True).result
elif py_value is False:
return ir_h.basicpy_bool_constant_op(False).result
elif py_value is None:
return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
elif isinstance(py_value, int):
ir_type = env.target.impl_int_type
ir_attr = ir_c.integer_attr(ir_type, py_value)
return ir_h.constant_op(ir_type, ir_attr).result
elif isinstance(py_value, float):
ir_type = env.target.impl_float_type
ir_attr = ir_c.float_attr(ir_type, py_value)
return ir_h.constant_op(ir_type, ir_attr).result
elif isinstance(py_value, str):
return ir_h.basicpy_str_constant_op(py_value).result
elif isinstance(py_value, bytes):
return ir_h.basicpy_bytes_constant_op(py_value).result
elif isinstance(py_value, type(...)):
return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
return NotImplemented
py_value) -> Union[_NotImplementedType, _ir.Value]:
ic = env.ic
with ic.loc, ic.ip:
if py_value is True:
return basicpy_ops.BoolConstantOp(ic.bool_type, ic.i1_true).result
elif py_value is False:
return basicpy_ops.BoolConstantOp(ic.bool_type, ic.i1_false).result
elif py_value is None:
return basicpy_ops.SingletonOp(ic.none_type).result
elif isinstance(py_value, int):
ir_type = env.target.impl_int_type
ir_attr = _ir.IntegerAttr.get(ir_type, py_value)
return std_ops.ConstantOp(ir_type, ir_attr).result
elif isinstance(py_value, float):
ir_type = env.target.impl_float_type
ir_attr = _ir.FloatAttr.get(ir_type, py_value)
return std_ops.ConstantOp(ir_type, ir_attr).result
elif isinstance(py_value, str):
return basicpy_ops.StrConstantOp(ic.str_type,
_ir.StringAttr.get(py_value)).result
elif isinstance(py_value, bytes):
return basicpy_ops.BytesConstantOp(ic.bytes_type,
_ir.StringAttr.get(py_value)).result
elif isinstance(py_value, type(...)):
return basicpy_ops.SingletonOp(ic.ellipsis_type).result
return NotImplemented

View File

@ -0,0 +1,167 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""General utilities for working with MLIR."""
from typing import Optional, Tuple
from mlir import ir as _ir
from npcomp import _cext
__all__ = [
"ImportContext",
]
class ImportContext:
"""Simple container for things that we update while importing.
This is also where we stash various helpers to work around awkward/missing
MLIR Python API features.
"""
__slots__ = [
"context",
"loc",
"module",
"_ip_stack",
# Cached types.
"unknown_type",
"bool_type",
"bytes_type",
"ellipsis_type",
"i1_type",
"index_type",
"none_type",
"str_type",
"unknown_array_type",
"unknown_tensor_type",
# Cached attributes.
"i1_true",
"i1_false",
]
def __init__(self, context: Optional[_ir.Context]):
self.context = _ir.Context() if not context else context
_cext.register_all_dialects(self.context)
self.loc = _ir.Location.unknown(context=self.context) # type: _ir.Location
self.module = None # type: Optional[_ir.Module]
self._ip_stack = []
# Cache some types and attributes.
with self.context:
# Types.
# TODO: Consolidate numpy.any_dtype and basicpy.UnknownType.
self.unknown_type = _ir.Type.parse("!basicpy.UnknownType")
self.bool_type = _ir.Type.parse("!basicpy.BoolType")
self.bytes_type = _ir.Type.parse("!basicpy.BytesType")
self.ellipsis_type = _ir.Type.parse("!basicpy.EllipsisType")
self.none_type = _ir.Type.parse("!basicpy.NoneType")
self.str_type = _ir.Type.parse("!basicpy.StrType")
self.i1_type = _ir.IntegerType.get_signless(1)
self.index_type = _ir.IndexType.get()
self.unknown_tensor_type = _ir.UnrankedTensorType.get(self.unknown_type,
loc=self.loc)
self.unknown_array_type = _cext.shaped_to_ndarray_type(
self.unknown_tensor_type)
# Attributes.
self.i1_true = _ir.IntegerAttr.get(self.i1_type, 1)
self.i1_false = _ir.IntegerAttr.get(self.i1_type, 0)
def set_file_line_col(self, file: str, line: int, col: int):
self.loc = _ir.Location.file(file, line, col, context=self.context)
def push_ip(self, new_ip: _ir.InsertionPoint):
self._ip_stack.append(new_ip)
def pop_ip(self):
assert self._ip_stack, "Mismatched push_ip/pop_ip: stack is empty on pop"
del self._ip_stack[-1]
@property
def ip(self):
assert self._ip_stack, "InsertionPoint requested but stack is empty"
return self._ip_stack[-1]
def insert_before_terminator(self, block: _ir.Block):
self.push_ip(_ir.InsertionPoint.at_block_terminator(block))
def insert_end_of_block(self, block: _ir.Block):
self.push_ip(_ir.InsertionPoint(block))
def FuncOp(self, name: str, func_type: _ir.Type,
create_entry_block: bool) -> Tuple[_ir.Operation, _ir.Block]:
"""Creates a |func| op.
TODO: This should really be in the MLIR API.
Returns:
(operation, entry_block)
"""
with self.context, self.loc:
attrs = {
"type": _ir.TypeAttr.get(func_type),
"sym_name": _ir.StringAttr.get(name),
}
op = _ir.Operation.create("func", regions=1, attributes=attrs, ip=self.ip)
body_region = op.regions[0]
entry_block = body_region.blocks.append(*func_type.inputs)
return op, entry_block
def basicpy_ExecOp(self):
"""Creates a basicpy.exec op.
Returns:
Insertion point to the body.
"""
op = _ir.Operation.create("basicpy.exec",
regions=1,
ip=self.ip,
loc=self.loc)
b = op.regions[0].blocks.append()
return _ir.InsertionPoint(b)
def basicpy_FuncTemplateCallOp(self, result_type, callee_symbol, args,
arg_names):
with self.loc, self.ip:
attributes = {
"callee":
_ir.FlatSymbolRefAttr.get(callee_symbol),
"arg_names":
_ir.ArrayAttr.get([_ir.StringAttr.get(n) for n in arg_names]),
}
op = _ir.Operation.create("basicpy.func_template_call",
results=[result_type],
operands=args,
attributes=attributes,
ip=self.ip)
return op
def scf_IfOp(self, results, condition: _ir.Value, with_else_region: bool):
"""Creates an SCF if op.
Returns:
(if_op, then_ip, else_ip) if with_else_region, otherwise (if_op, then_ip)
"""
op = _ir.Operation.create("scf.if",
results=results,
operands=[condition],
regions=2 if with_else_region else 1,
loc=self.loc,
ip=self.ip)
then_region = op.regions[0]
then_block = then_region.blocks.append()
if with_else_region:
else_region = op.regions[1]
else_block = else_region.blocks.append()
return op, _ir.InsertionPoint(then_block), _ir.InsertionPoint(else_block)
else:
return op, _ir.InsertionPoint(then_block)
def scf_YieldOp(self, operands):
return _ir.Operation.create("scf.yield",
operands=operands,
loc=self.loc,
ip=self.ip)
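
Since ImportContext is the piece most of the importer code above now threads around, a compact usage sketch of this file's helpers follows. Only the import path of ImportContext is an assumption; the calls mirror what ImportFrontend does.

# Sketch: drive ImportContext by hand, the way ImportFrontend does.
from mlir import ir as _ir
from mlir.dialects import std as std_ops
from npcomp.compiler.utils.mlir_utils import ImportContext  # assumed path

ic = ImportContext(None)                    # creates a Context, registers dialects
ic.module = _ir.Module.create(loc=ic.loc)
ic.set_file_line_col("example.py", 1, 0)

ic.insert_before_terminator(ic.module.body)
f_type = _ir.FunctionType.get([], [], context=ic.context)
f_op, entry_block = ic.FuncOp("f", f_type, create_entry_block=True)

ic.insert_end_of_block(entry_block)
std_ops.ReturnOp([], loc=ic.loc, ip=ic.ip)  # terminate the entry block
ic.pop_ip()                                 # leave the entry block
ic.pop_ip()                                 # leave the module-level insertion point
print(ic.module)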

View File

@ -1,112 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from _npcomp.dialect import BasicpyDialectHelper as _BaseDialectHelper
from _npcomp.mlir import ir
__all__ = [
"DialectHelper",
]
class DialectHelper(_BaseDialectHelper):
r"""Dialect helper for the Basicpy dialect.
>>> c = ir.MLIRContext()
>>> h = DialectHelper(c, ir.OpBuilder(c))
Dialect Types:
>>> h.basicpy_NoneType
!basicpy.NoneType
>>> h.basicpy_EllipsisType
!basicpy.EllipsisType
>>> h.basicpy_SlotObject_type(
... "foobar", h.basicpy_NoneType, h.basicpy_NoneType)
!basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
singleton op:
>>> m = c.new_module()
>>> h.builder.insert_block_start(m.first_block)
>>> _ = h.basicpy_singleton_op(h.basicpy_NoneType)
>>> m.to_asm().strip()
'module {\n %0 = basicpy.singleton : !basicpy.NoneType\n}'
slot_object ops:
>>> m = c.new_module()
>>> h.builder.insert_block_start(m.first_block)
>>> v0 = h.basicpy_singleton_op(h.basicpy_NoneType).result
>>> slot_object = h.basicpy_slot_object_make_op("foobar", v0, v0).result
>>> _ = h.basicpy_slot_object_get_op(slot_object, 0)
>>> print(m.to_asm().strip())
module {
%0 = basicpy.singleton : !basicpy.NoneType
%1 = basicpy.slot_object_make(%0, %0) -> !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
%2 = basicpy.slot_object_get %1[0] : !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
}
"""
def basicpy_binary_expr_op(self, result_type, lhs, rhs, operation_name):
c = self.context
attrs = c.dictionary_attr({"operation": c.string_attr(operation_name)})
return self.op("basicpy.binary_expr", [result_type], [lhs, rhs], attrs)
def basicpy_bool_cast_op(self, result_type, value):
return self.op("basicpy.bool_cast", [result_type], [value])
def basicpy_bool_constant_op(self, value):
c = self.context
ival = 1 if value else 0
attrs = c.dictionary_attr({"value": c.integer_attr(self.i1_type, ival)})
return self.op("basicpy.bool_constant", [self.basicpy_BoolType], [], attrs)
def basicpy_bytes_constant_op(self, value):
c = self.context
attrs = c.dictionary_attr({"value": c.string_attr(value)})
return self.op("basicpy.bytes_constant", [self.basicpy_BytesType], [],
attrs)
def basicpy_binary_compare_op(self, lhs, rhs, operation_name):
c = self.context
attrs = c.dictionary_attr({"operation": c.string_attr(operation_name)})
return self.op("basicpy.binary_compare", [self.basicpy_BoolType],
[lhs, rhs], attrs)
def basicpy_singleton_op(self, singleton_type):
return self.op("basicpy.singleton", [singleton_type], [])
def basicpy_slot_object_make_op(self, class_name, *slot_values):
c = self.context
class_name_attr = c.string_attr(class_name)
object_type = self.basicpy_SlotObject_type(class_name,
*[v.type for v in slot_values])
attrs = c.dictionary_attr({"className": class_name_attr})
return self.op("basicpy.slot_object_make", [object_type], slot_values,
attrs)
def basicpy_str_constant_op(self, value):
c = self.context
attrs = c.dictionary_attr({"value": c.string_attr(value.encode("utf-8"))})
return self.op("basicpy.str_constant", [self.basicpy_StrType], [], attrs)
def basicpy_as_i1_op(self, value):
return self.op("basicpy.as_i1", [self.i1_type], [value])
def basicpy_unknown_cast_op(self, result_type, operand):
return self.op("basicpy.unknown_cast", [result_type], [operand])
def basicpy_func_template_call_op(self, result_type, callee_symbol, args,
arg_names):
"""Creates a basicpy.func_template_call op."""
c = self.context
attrs = c.dictionary_attr({
"callee": c.flat_symbol_ref_attr(callee_symbol),
"arg_names": c.array_attr([c.string_attr(n) for n in arg_names]),
})
return self.op("basicpy.func_template_call", [result_type], args, attrs)
if __name__ == "__main__":
import doctest
doctest.testmod()

View File

@ -1,80 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
from npcomp.dialect import Basicpy
from _npcomp.mlir import ir
__all__ = [
"load_builtin_module",
"DialectHelper",
]
class DialectHelper(Basicpy.DialectHelper):
r"""Dialect helper.
>>> c = ir.MLIRContext()
>>> h = DialectHelper(c, ir.OpBuilder(c))
DenseElementsAttrs:
>>> c.dense_elements_attr(np.asarray([1, 2, 3, 4], dtype=np.int32))
dense<[1, 2, 3, 4]> : tensor<4xsi32>
>>> c.dense_elements_attr(np.asarray([[1, 2], [3, 4]], dtype=np.int32))
dense<[[1, 2], [3, 4]]> : tensor<2x2xsi32>
>>> c.dense_elements_attr(np.asarray([[1., 2.], [3., 4.]]))
dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf64>
>>> c.dense_elements_attr(np.asarray([[1., 2.], [3., 4.]], dtype=np.float32))
dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>
Types:
>>> c = ir.MLIRContext()
>>> t = DialectHelper(c, ir.OpBuilder(c))
>>> t.numpy_any_dtype
!basicpy.UnknownType
>>> t.tensor_type(t.numpy_any_dtype, [1, 2, 3])
tensor<1x2x3x!basicpy.UnknownType>
>>> t.tensor_type(t.numpy_any_dtype)
tensor<*x!basicpy.UnknownType>
>>> t.tensor_type(t.numpy_any_dtype, [-1, 2])
tensor<?x2x!basicpy.UnknownType>
>>> t.tensor_type(t.f32_type)
tensor<*xf32>
>>> t.function_type([t.i32_type], [t.f32_type])
(i32) -> f32
>>> t.numpy_unknown_tensor_type
tensor<*x!basicpy.UnknownType>
"""
@property
def numpy_any_dtype(self):
return self.basicpy_UnknownType
@property
def numpy_unknown_tensor_type(self):
return self.tensor_type(self.basicpy_UnknownType)
@property
def unknown_array_type(self):
return self.numpy_NdArrayType(self.basicpy_UnknownType)
def numpy_builtin_ufunc_call_op(self, *args, qualified_name, result_type):
"""Creates a numpy.builtin_ufunc_call op."""
c = self.context
attrs = c.dictionary_attr({"qualified_name": c.string_attr(qualified_name)})
return self.op("numpy.builtin_ufunc_call", [result_type], args, attrs)
def numpy_narrow_op(self, result_type, operand):
"""Creates a numpy.narrow op."""
return self.op("numpy.narrow", [result_type], [operand])
def numpy_get_slice_op(self, result_type, array, *slice_elements):
return self.op("numpy.get_slice", [result_type],
[array] + list(slice_elements))
if __name__ == "__main__":
import doctest
doctest.testmod()

View File

@@ -0,0 +1,15 @@
//===-- ATenBind.td - Aten dialect bind --------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_DIALECTS_ATEN_BIND
#define NPCOMP_PYTHON_DIALECTS_ATEN_BIND
include "mlir/Bindings/Python/Attributes.td"
include "npcomp/Dialect/ATen/IR/ATenOps.td"
#endif

View File

@@ -0,0 +1,15 @@
//===-- BasicpyBind.td - Basicpy dialect bind --------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_DIALECTS_BASICPY_BIND
#define NPCOMP_PYTHON_DIALECTS_BASICPY_BIND
include "mlir/Bindings/Python/Attributes.td"
include "npcomp/Dialect/Basicpy/IR/BasicpyOps.td"
#endif

View File

@@ -0,0 +1,12 @@
function(_add_dialect target td_file bind_name)
set(LLVM_TARGET_DEFINITIONS ${td_file})
mlir_tablegen("${bind_name}.py" -gen-python-op-bindings -bind-dialect=${bind_name})
add_public_tablegen_target(${target})
add_dependencies(NPCOMPNativePyExt ${target})
endfunction()
_add_dialect(NPCOMPPyDialectATen ATenBind.td "aten")
_add_dialect(NPCOMPPyDialectBasicpy BasicpyBind.td "basicpy")
_add_dialect(NPCOMPPyDialectNumpy NumpyBind.td "numpy")
_add_dialect(NPCOMPPyDialectTCF TCFBind.td "tcf")
_add_dialect(NPCOMPPyDialectTorch TorchBind.td "torch")

View File

@@ -0,0 +1,15 @@
//===-- NumpyBind.td - Numpy dialect bind ------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_DIALECTS_NUMPY_BIND
#define NPCOMP_PYTHON_DIALECTS_NUMPY_BIND
include "mlir/Bindings/Python/Attributes.td"
include "npcomp/Dialect/Numpy/IR/NumpyOps.td"
#endif

View File

@@ -0,0 +1,15 @@
//===-- TCFBind.td - TCF dialect bind ----------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_DIALECTS_TCF_BIND
#define NPCOMP_PYTHON_DIALECTS_TCF_BIND
include "mlir/Bindings/Python/Attributes.td"
include "npcomp/Dialect/TCF/IR/TCFOps.td"
#endif

View File

@@ -0,0 +1,15 @@
//===-- TorchBind.td - Torch dialect bind ------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_PYTHON_DIALECTS_TORCH_BIND
#define NPCOMP_PYTHON_DIALECTS_TORCH_BIND
include "mlir/Bindings/Python/Attributes.td"
include "npcomp/Dialect/Torch/IR/TorchOps.td"
#endif

View File

@@ -0,0 +1,7 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Generated tablegen dialects expect to be able to find some symbols from
# the mlir.dialects package.
from mlir.dialects import _cext, _segmented_accessor, _equally_sized_accessor, _get_default_loc_context
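# For orientation, a minimal usage sketch of what the generated dialect modules
# provide, assuming the npcomp dialects are registered on the context. The
# `narrow_to` helper below is hypothetical, but the op-class call style
# (positional result types/operands plus `loc`/`ip` keywords) mirrors the calls
# made elsewhere in this patch.
from mlir import ir as _ir
from npcomp.dialects import numpy as numpy_ops

def narrow_to(result_type, value, loc, ip):
  # Generated op classes accept result types and operands positionally,
  # take `loc`/`ip` keywords, and expose `.result` / `.results`.
  return numpy_ops.NarrowOp(result_type, value, loc=loc, ip=ip).result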

View File

@@ -7,6 +7,10 @@ import numpy as np
from collections import namedtuple
from enum import Enum
from mlir import ir as _ir
from npcomp.dialects import numpy as numpy_ops
class Protocol(Enum):
UFUNC = 1
@@ -39,8 +43,7 @@ TraceInvocation.__new__.__defaults__ = (Protocol.ARRAY_FUNC, "__call__")
class EmissionRequest(
namedtuple("EmissionRequest",
["input_ssa_values", "dialect_helper", "extra"])):
namedtuple("EmissionRequest", ["input_ssa_values", "ic", "extra"])):
"""Represents the result of processing inputs from an invocation.
The `input_ssa_values` are mlir.ir.Value instances corresponding to
@@ -173,11 +176,14 @@ class GenericCallUfuncEmitter(FuncEmitter):
return py_results[0]
def emit(self, request: EmissionRequest):
h = request.dialect_helper
op_result_type = h.tensor_type(h.numpy_any_dtype)
call_op = h.numpy_builtin_ufunc_call_op(*request.input_ssa_values,
qualified_name=self._ufunc_name,
result_type=op_result_type)
ic = request.ic
name_attr = _ir.StringAttr.get(self._ufunc_name)
result_type = ic.unknown_tensor_type
call_op = numpy_ops.BuiltinUfuncCallOp(result_type,
qualified_name=name_attr,
inputs=request.input_ssa_values,
loc=ic.loc,
ip=ic.ip)
return call_op.results
@@ -219,9 +225,13 @@ class GenericArrayFuncEmitter(FuncEmitter):
return tuple(py_results)
def emit(self, request: EmissionRequest):
h = request.dialect_helper
op_result_types = [h.tensor_type(h.numpy_any_dtype)] * self._nresults
op = h.op(self._op_name, op_result_types, request.input_ssa_values)
ic = request.ic
op_result_types = [ic.unknown_tensor_type] * self._nresults
op = _ir.Operation.create(self._op_name,
results=op_result_types,
operands=request.input_ssa_values,
loc=ic.loc,
ip=ic.ip)
return op.results
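# For reference, a condensed sketch of the generic path above: when no
# generated op class is available, the emitter assembles the op by name through
# the upstream `Operation.create` API. The `ic` argument is assumed to be the
# ImportContext used throughout this patch, supplying `unknown_tensor_type`,
# `loc`, and `ip`.
from mlir import ir as _ir

def emit_by_name(op_name, operands, nresults, ic):
  # All results are unknown-dtype tensors; the op is created at the
  # insertion point currently tracked by the ImportContext.
  result_types = [ic.unknown_tensor_type] * nresults
  op = _ir.Operation.create(op_name,
                            results=result_types,
                            operands=operands,
                            loc=ic.loc,
                            ip=ic.ip)
  return op.results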

View File

@@ -3,28 +3,44 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import re
from typing import Iterable
from typing import Iterable, Optional
import numpy as np
from _npcomp.mlir import ir
from mlir import ir as _ir
from mlir.dialects import std as std_ops
from npcomp.dialect import Numpy
from npcomp.exporter import *
from npcomp.types import *
from npcomp.tracing.context import *
from npcomp.tracing.emitters import *
from npcomp import _cext
from npcomp.dialects import basicpy as basicpy_ops
from npcomp.dialects import numpy as numpy_ops
from ..exporter import *
from ..types import *
from ..compiler.utils.mlir_utils import *
from .context import *
from .emitters import *
class ModuleBuilder:
"""Builds an MLIR module by tracing functions."""
def __init__(self, mlir_context=None, emitter_registry=None):
self.context = mlir_context if mlir_context else ir.MLIRContext()
self.module = self.context.new_module()
self.helper = Numpy.DialectHelper(self.context, ir.OpBuilder(self.context))
__slots__ = [
"emitters",
"ic",
]
def __init__(self,
mlir_context: Optional[_ir.Context] = None,
emitter_registry=None):
ic = self.ic = ImportContext(mlir_context)
ic.module = _ir.Module.create(loc=ic.loc)
self.emitters = (emitter_registry
if emitter_registry else EmitterRegistry.create_default())
@property
def module(self):
return self.ic.module
def trace(self, *export_py_funcs: ExportPyFunction):
"""Traces exported py functions."""
for export_py_func in export_py_funcs:
@@ -43,9 +59,7 @@ class FunctionTracer(TraceContext):
"_args_array_params",
"_f",
"_f_types",
"_helper",
"_mlir_m",
"_mlir_c",
"_ic",
"_python_args",
"_result_array_params",
"_traced_arrays",
@@ -61,35 +75,39 @@
self._validate()
# Alias some parent members for convenience.
self._mlir_m = module_builder.module
self._mlir_c = module_builder.context
self._helper = module_builder.helper
self._ic = module_builder.ic
with self._ic.context:
# Extract ArrayParams for all args and results.
self._args_array_params = [
ArrayParams.from_constraints(arg.constraints)
for arg in self.epf.sig.args
]
self._python_args = [None] * len(self._args_array_params)
self._result_array_params = ArrayParams.from_constraints(
self.epf.sig.result.constraints)
# Extract ArrayParams for all args and results.
self._args_array_params = [
ArrayParams.from_constraints(arg.constraints)
for arg in self.epf.sig.args
]
self._python_args = [None] * len(self._args_array_params)
self._result_array_params = ArrayParams.from_constraints(
self.epf.sig.result.constraints)
# Create the MLIR function.
self._f, self._f_types = self._create_mlir_function()
self._create_trace_roots()
# Create the MLIR function.
self._f, self._f_types = self._create_mlir_function()
self._create_trace_roots()
@property
def entry_block(self) -> _ir.Block:
return self._f.regions[0].blocks[0]
def trace(self):
# Invoke the python function with placeholders.
# TODO: More sophisticated signature merging
# TODO: Multiple results
# TODO: Error reporting
h = self._helper
py_results = (self.epf.pyfunc(*self._python_args),)
if len(py_results) != len(self._f_types):
raise TracingError("Traced function returned != %d results: %r" % (
len(self._f_types),
py_results,
))
ic = self._ic
ic.insert_end_of_block(self.entry_block)
with ic.context:
py_results = (self.epf.pyfunc(*self._python_args),)
if len(py_results) != len(self._f_types):
raise TracingError("Traced function returned != %d results: %r" % (
len(self._f_types),
py_results,
))
# Narrow all results to the declared return types.
return_operands = []
@@ -97,8 +115,12 @@ class FunctionTracer(TraceContext):
mlir_result = self.get_traced_array_value(py_result)
# narrow to declared result type.
return_operands.extend(
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
h.return_op(return_operands)
numpy_ops.NarrowOp(mlir_result_type,
mlir_result,
loc=ic.loc,
ip=ic.ip).results)
std_ops.ReturnOp(return_operands, loc=ic.loc, ip=ic.ip)
ic.pop_ip()
def set_traced_array(self, traced_array, value):
"""Sets the current SSA value for a traced_array."""
@@ -117,15 +139,18 @@
return traced_value
def _get_external_array_value(self, external_array):
h = self._helper
ic = self._ic
if not isinstance(external_array, np.ndarray):
raise TracingError("Expected ndarray but got: %r" % (external_array,))
found_it = self._external_arrays.get(id(external_array))
if found_it:
return found_it[1]
# Import it.
dense_attr = h.context.dense_elements_attr(external_array)
const_value = h.constant_op(dense_attr.type, dense_attr).result
dense_attr = _ir.DenseElementsAttr.get(external_array, context=ic.context)
const_value = std_ops.ConstantOp(dense_attr.type,
dense_attr,
loc=ic.loc,
ip=ic.ip).result
self._external_arrays[id(external_array)] = (external_array, const_value)
return const_value
@@ -138,28 +163,24 @@
(self.epf.sig.result,))
def _create_mlir_function(self):
mlir_c = self._mlir_c
mlir_m = self._mlir_m
h = self._helper
ic = self._ic
epf = self.epf
f_args = [
mlir_c.parse_type(ap.mlir_tensor_type_asm)
_ir.Type.parse(ap.mlir_tensor_type_asm)
for ap in self._args_array_params
]
f_types = [
mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
]
h.builder.insert_before_terminator(mlir_m.first_block)
f_type = h.function_type(f_args, f_types)
f = h.func_op(epf.__name__, f_type, create_entry_block=True)
f_types = [_ir.Type.parse(self._result_array_params.mlir_tensor_type_asm)]
ic.insert_before_terminator(ic.module.body)
f_type = _ir.FunctionType.get(f_args, f_types)
f, _ = ic.FuncOp(epf.__name__, f_type, create_entry_block=True)
return f, f_types
def _create_trace_roots(self):
entry_block = self._f.first_block
entry_block = self.entry_block
for index, ap in enumerate(self._args_array_params):
if ap is not None:
ta = TracedArray(self)
self.set_traced_array(ta, entry_block.args[index])
self.set_traced_array(ta, entry_block.arguments[index])
self._python_args[index] = ta
def _resolve_input_ssa_values(self, trace_values: Iterable[TraceValue]):
@@ -190,9 +211,7 @@
def _emit_invocation(self, emitter: FuncEmitter, invocation: TraceInvocation):
tv_map = emitter.map_invocation(invocation)
input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
request = EmissionRequest(input_ssa_values,
dialect_helper=self._helper,
extra=tv_map.extra)
request = EmissionRequest(input_ssa_values, ic=self._ic, extra=tv_map.extra)
result_ssa_values = emitter.emit(request)
py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,
result_ssa_values)
@@ -213,14 +232,18 @@
return self._emit_invocation(emitter, invocation)
def _emit_slice_value(self, slice_element):
h = self._helper
ic = self._ic
if slice_element == None:
return h.basicpy_singleton_op(h.basicpy_NoneType).result
return basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc, ip=ic.ip).result
elif slice_element == Ellipsis:
return h.basicpy_singleton_op(h.basicpy_EllipsisType).result
return basicpy_ops.SingletonOp(ic.ellipsis_type, loc=ic.loc,
ip=ic.ip).result
elif isinstance(slice_element, int):
return h.constant_op(h.index_type,
h.context.index_attr(slice_element)).result
return std_ops.ConstantOp(ic.index_type,
_ir.IntegerAttr.get(ic.index_type,
slice_element),
loc=ic.loc,
ip=ic.ip).result
elif isinstance(slice_element, slice):
return self._emit_slice_object(slice_element)
else:
@@ -229,29 +252,40 @@
"TODO: Slicing with generic arrays not yet implemented")
def _emit_slice_object(self, slice_object: slice):
h = self._helper
ic = self._ic
def emit_index(index):
if index is None:
return h.basicpy_singleton_op(h.basicpy_NoneType).result
return basicpy_ops.SingletonOp(ic.none_type, loc=ic.loc,
ip=ic.ip).result
else:
return h.constant_op(h.index_type,
h.context.index_attr(int(index))).result
return std_ops.ConstantOp(ic.index_type,
_ir.IntegerAttr.get(ic.index_type,
int(index)),
loc=ic.loc,
ip=ic.ip).result
start = emit_index(slice_object.start)
stop = emit_index(slice_object.stop)
step = emit_index(slice_object.step)
return h.basicpy_slot_object_make_op("slice", start, stop, step).result
result_type = _cext.slot_object_type(ic.context, "slice",
[start.type, stop.type, step.type])
return basicpy_ops.SlotObjectMakeOp(result_type, [start, stop, step],
loc=ic.loc,
ip=ic.ip).result
def _handle_array_getitem(self, array, key):
h = self._helper
ic = self._ic
array_value = self.get_traced_array_value(array)
# Array slicing is always based on a tuple.
slice_tuple = key if isinstance(key, tuple) else (key,)
# Resolve and emit each slice element.
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
result_value = h.numpy_get_slice_op(h.unknown_array_type, array_value,
*slice_values).result
result_value = numpy_ops.GetSliceOp(ic.unknown_array_type,
array_value,
slice_values,
loc=ic.loc,
ip=ic.ip).result
result_array = TracedArray(self)
self.set_traced_array(result_array, result_value)
return result_array
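# As a reading aid for the tracer changes above, a hedged, condensed sketch of
# the function-creation path; it assumes `ImportContext` (from
# npcomp.compiler.utils.mlir_utils) exposes the context, module, and
# insertion-point helpers seen in this diff.
from mlir import ir as _ir

def create_traced_func(ic, name, arg_type_asms, result_type_asm):
  with ic.context:
    f_args = [_ir.Type.parse(asm) for asm in arg_type_asms]
    f_results = [_ir.Type.parse(result_type_asm)]
    # Insert the func before the module terminator, then hand back the
    # entry block so the traced body (ending in std.return) can be emitted.
    ic.insert_before_terminator(ic.module.body)
    f_type = _ir.FunctionType.get(f_args, f_results)
    f, _ = ic.FuncOp(name, f_type, create_entry_block=True)
    return f, f.regions[0].blocks[0]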

View File

@@ -1,23 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
bias = np.random.uniform(size=(4,)).astype(np.float32)
def constants(a: np.ndarray) -> np.ndarray:
return np.dot(a, weights) + bias
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.constants = constants
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.constants)
print(mb.module.to_asm())

View File

@@ -1,35 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
return a * b + a + b
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.simple_mul = simple_mul
exp.simple_mul.sig.args["a"] += Shape(1, 4)
exp.simple_mul.sig.args["a"] += DynamicDim(0)
exp.simple_mul.sig.args["a"] += DType(np.float32)
exp.simple_mul.sig.args["b"] += Shape(1)
exp.simple_mul.sig.args["b"] += DType(np.float32)
exp.simple_mul.sig.result += Shape(1, 4)
exp.simple_mul.sig.result += DynamicDim(0)
exp.simple_mul.sig.result += DType(np.float32)
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.simple_mul)
# CHECK: func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
# CHECK: %0 = numpy.ufunc_call @numpy.multiply(%arg0, %arg1) : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
# CHECK: %1 = numpy.ufunc_call @numpy.add(%0, %arg0) : (tensor<*x!numpy.any_dtype>, tensor<?x4xf32>) -> tensor<*x!numpy.any_dtype>
# CHECK: %2 = numpy.ufunc_call @numpy.add(%1, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
# CHECK: %3 = numpy.narrow %2 : (tensor<*x!numpy.any_dtype>) -> tensor<?x4xf32>
# CHECK: return %3 : tensor<?x4xf32>
# CHECK: }
print(mb.module.to_asm())

View File

@@ -1,20 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def slice_array1(a: np.ndarray) -> np.ndarray:
return a[1, 2:10:2, 3:4, ..., :, 0]
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.slice_array1 = slice_array1
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.slice_array1)
print(mb.module.to_asm())

View File

@@ -1,25 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def transpose_attribute(a: np.ndarray) -> np.ndarray:
return a.T
def transpose(a: np.ndarray) -> np.ndarray:
return np.transpose(a)
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.transpose_attribute = transpose_attribute
exp.transpose = transpose
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.transpose_attribute, exp.transpose)
print(mb.module.to_asm())

View File

@@ -8,9 +8,6 @@ from npcomp.compiler.numpy import test_config
from npcomp.compiler.numpy.target import *
from npcomp.compiler.utils import logging
# TODO: This should all exist in a high level API somewhere.
from _npcomp import mlir
logging.enable()

View File

@@ -1,43 +0,0 @@
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Test for the MLIR IR Python bindings.
TODO: These tests were just for bootstrapping and are not authoritative at this
point.
"""
from _npcomp.mlir import ir
c = ir.MLIRContext()
# CHECK-LABEL: module @parseSuccess
m = c.parse_asm(r"""
module @parseSuccess {
func @f() {
return
}
}
""")
# CHECK: func @f
print(m.to_asm())
# CHECK: OP NAME: module
print("OP NAME:", m.name)
# CHECK: NUM_REGIONS: 1
print("NUM_REGIONS:", m.num_regions)
region = m.region(0)
# CHECK: CONTAINED OP: func
# CHECK: CONTAINED OP: module_terminator
for block in region.blocks:
for op in block.operations:
print("CONTAINED OP:", op.name)
# CHECK-LABEL: PARSE_FAILURE
print("PARSE_FAILURE")
try:
m = c.parse_asm("{{ILLEGAL SYNTAX}}")
except ValueError as e:
# CHECK: [ERROR]: expected operation name in quotes
print(e)

View File

@@ -1,40 +0,0 @@
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Test for the MLIR Pass Python bindings"""
from _npcomp.mlir import ir
from _npcomp.mlir import passes
c = ir.MLIRContext()
pm = passes.PassManager(c)
# CHECK-LABEL: module @parseSuccess
m = c.parse_asm(r"""
module @parseSuccess {
func @notUsed() attributes { sym_visibility = "private" }
func @f() {
return
}
}
""")
# CHECK: func private @notUsed
# CHECK: func @f
print(m.to_asm())
# CHECK: PASS COUNT: 0
print("PASS COUNT:", len(pm))
pm.addPassPipelines("canonicalize", "symbol-dce")
# Note: not checking the actual count since these may expand to more than
# two passes.
# CHECK: PASS COUNT:
print("PASS COUNT:", len(pm))
# CHECK: PASSES: canonicalize, symbol-dce
print("PASSES:", str(pm))
pm.run(m)
print(m.to_asm())
# CHECK-NOT: func @notUsed

View File

@@ -0,0 +1,34 @@
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
bias = np.random.uniform(size=(4,)).astype(np.float32)
def constants(a: np.ndarray) -> np.ndarray:
return np.dot(a, weights) + bias
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.constants = constants
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.constants)
# CHECK-LABEL: func @constants(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
# CHECK: %[[VAL_1:.*]] = constant dense<{{.*}}> : tensor<16x4xf32>
# CHECK: %[[VAL_2:.*]] = numpy.dot %[[VAL_0]], %[[VAL_1]] : (tensor<*x!numpy.any_dtype>, tensor<16x4xf32>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_3:.*]] = constant dense<{{.*}}> : tensor<4xf32>
# CHECK: %[[VAL_4:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[VAL_2]], %[[VAL_3]]) : (tensor<*x!basicpy.UnknownType>, tensor<4xf32>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_5:.*]] = numpy.narrow %[[VAL_4]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
# CHECK: return %[[VAL_5]] : tensor<*x!numpy.any_dtype>
# CHECK: }
print(mb.module)

View File

@@ -1,3 +1,5 @@
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -25,4 +27,12 @@ exp.dot2d.sig.result += DType(np.float32)
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.dot2d)
print(mb.module.to_asm())
# CHECK-LABEL: func @dot2d(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<?x16xf32>,
# CHECK-SAME: %[[VAL_1:.*]]: tensor<16x32xf32>) -> tensor<?x32xf32> {
# CHECK: %[[VAL_2:.*]] = numpy.dot %[[VAL_0]], %[[VAL_1]] : (tensor<?x16xf32>, tensor<16x32xf32>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_3:.*]] = numpy.narrow %[[VAL_2]] : (tensor<*x!basicpy.UnknownType>) -> tensor<?x32xf32>
# CHECK: return %[[VAL_3]] : tensor<?x32xf32>
# CHECK: }
print(mb.module)

View File

@@ -30,6 +30,7 @@ exp.simple_mul.sig.result += DType(np.float32)
mb = ModuleBuilder()
mb.trace(exp.simple_mul)
# This test exercises the full tracing path and incidentally checks the ops.
# CHECK: func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
# CHECK: %0 = numpy.builtin_ufunc_call<"numpy.multiply"> (%arg0, %arg1) : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %1 = numpy.builtin_ufunc_call<"numpy.add"> (%0, %arg0) : (tensor<*x!basicpy.UnknownType>, tensor<?x4xf32>) -> tensor<*x!basicpy.UnknownType>
@@ -37,4 +38,4 @@ mb.trace(exp.simple_mul)
# CHECK: %3 = numpy.narrow %2 : (tensor<*x!basicpy.UnknownType>) -> tensor<?x4xf32>
# CHECK: return %3 : tensor<?x4xf32>
# CHECK: }
print(mb.module.to_asm())
print(str(mb.module))

View File

@@ -0,0 +1,48 @@
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def slice_array1(a: np.ndarray) -> np.ndarray:
return a[1, 2:10:2, 3:4, ..., :, 0]
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.slice_array1 = slice_array1
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.slice_array1)
# TODO: The numpy.get_slice op emission should be analyzed: it probably
# needs to both accept and produce either arrays or tensors and the following
# narrow should do likewise.
# CHECK-LABEL: func @slice_array1(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
# CHECK: %[[VAL_1:.*]] = constant 1 : index
# CHECK: %[[VAL_2:.*]] = constant 2 : index
# CHECK: %[[VAL_3:.*]] = constant 10 : index
# CHECK: %[[VAL_4:.*]] = constant 2 : index
# CHECK: %[[VAL_5:.*]] = basicpy.slot_object_make(%[[VAL_2]], %[[VAL_3]], %[[VAL_4]]) -> !basicpy.SlotObject<slice, index, index, index>
# CHECK: %[[VAL_6:.*]] = constant 3 : index
# CHECK: %[[VAL_7:.*]] = constant 4 : index
# CHECK: %[[VAL_8:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[VAL_9:.*]] = basicpy.slot_object_make(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) -> !basicpy.SlotObject<slice, index, index, !basicpy.NoneType>
# CHECK: %[[VAL_10:.*]] = basicpy.singleton : !basicpy.EllipsisType
# CHECK: %[[VAL_11:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[VAL_12:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[VAL_13:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[VAL_14:.*]] = basicpy.slot_object_make(%[[VAL_11]], %[[VAL_12]], %[[VAL_13]]) -> !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>
# CHECK: %[[VAL_15:.*]] = constant 0 : index
# CHECK: %[[VAL_16:.*]] = numpy.get_slice %[[VAL_0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_9]], %[[VAL_10]], %[[VAL_14]], %[[VAL_15]] : (tensor<*x!numpy.any_dtype>, index, !basicpy.SlotObject<slice, index, index, index>, !basicpy.SlotObject<slice, index, index, !basicpy.NoneType>, !basicpy.EllipsisType, !basicpy.SlotObject<slice, !basicpy.NoneType, !basicpy.NoneType, !basicpy.NoneType>, index) -> !numpy.ndarray<*:?>
# CHECK: %[[VAL_17:.*]] = numpy.narrow %[[VAL_16]] : (!numpy.ndarray<*:?>) -> tensor<*x!numpy.any_dtype>
# CHECK: return %[[VAL_17]] : tensor<*x!numpy.any_dtype>
# CHECK: }
print(mb.module)

View File

@@ -0,0 +1,42 @@
# RUN: %PYTHON %s | FileCheck %s --dump-input=fail
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def transpose_attribute(a: np.ndarray) -> np.ndarray:
return a.T
def transpose(a: np.ndarray) -> np.ndarray:
return np.transpose(a)
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.transpose_attribute = transpose_attribute
exp.transpose = transpose
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.transpose_attribute, exp.transpose)
# TODO: Consolidate any_dtype -> UnknownType.
# CHECK-LABEL: func @transpose_attribute(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
# CHECK: }
# CHECK-LABEL: func @transpose(
# CHECK-SAME: %[[VAL_0:.*]]: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
# CHECK: %[[VAL_1:.*]] = numpy.transpose %[[VAL_0]] : (tensor<*x!numpy.any_dtype>) -> tensor<*x!basicpy.UnknownType>
# CHECK: %[[VAL_2:.*]] = numpy.narrow %[[VAL_1]] : (tensor<*x!basicpy.UnknownType>) -> tensor<*x!numpy.any_dtype>
# CHECK: return %[[VAL_2]] : tensor<*x!numpy.any_dtype>
# CHECK: }
print(mb.module)

View File

@@ -24,8 +24,6 @@ def run_doctest(mod):
TEST_MODULES = (
"npcomp.compiler.numpy.py_value_utils",
"npcomp.dialect.Basicpy",
"npcomp.dialect.Numpy",
"npcomp.tracing.context",
"npcomp.tracing.emitters",
"npcomp.tracing.mlir_trace",