// Mirrored from https://github.com/llvm/torch-mlir (original listing: 178 lines, 5.1 KiB).
//===- MlirIr.h - MLIR IR Bindings ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#ifndef NPCOMP_PYTHON_NATIVE_MLIR_IR_H
|
||
|
#define NPCOMP_PYTHON_NATIVE_MLIR_IR_H
|
||
|
|
||
|
#include "PybindUtils.h"
|
||
|
|
||
|
#include "mlir/IR/Block.h"
|
||
|
#include "mlir/IR/Builders.h"
|
||
|
#include "mlir/IR/MLIRContext.h"
|
||
|
#include "mlir/IR/Module.h"
|
||
|
#include "mlir/IR/Operation.h"
|
||
|
#include "mlir/IR/Region.h"
|
||
|
#include "mlir/IR/SymbolTable.h"
|
||
|
|
||
|
namespace mlir {
|
||
|
|
||
|
struct PyContext;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
// Utility types
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
template <typename ListTy, typename ItemWrapperTy> class PyIpListWrapper {
|
||
|
public:
|
||
|
using ThisTy = PyIpListWrapper<ListTy, ItemWrapperTy>;
|
||
|
static void bind(py::module m, const char *className);
|
||
|
PyIpListWrapper(ListTy &list) : list(list) {}
|
||
|
|
||
|
private:
|
||
|
ListTy &list;
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
// Wrapper types
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Wrapper around an Operation*.
|
||
|
struct PyBaseOperation {
|
||
|
virtual ~PyBaseOperation();
|
||
|
static void bind(py::module m);
|
||
|
virtual Operation *getOperation() = 0;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around Module, capturing a PyContext reference.
|
||
|
struct PyModuleOp : PyBaseOperation {
|
||
|
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
|
||
|
: context(context), moduleOp(moduleOp) {
|
||
|
assert(moduleOp);
|
||
|
}
|
||
|
~PyModuleOp();
|
||
|
static void bind(py::module m);
|
||
|
Operation *getOperation() override;
|
||
|
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
||
|
int64_t largeElementLimit);
|
||
|
|
||
|
std::shared_ptr<PyContext> context;
|
||
|
ModuleOp moduleOp;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around an Operation*.
|
||
|
struct PyOperationRef : PyBaseOperation {
|
||
|
PyOperationRef(Operation *operation) : operation(operation) {
|
||
|
assert(operation);
|
||
|
}
|
||
|
PyOperationRef(Operation &operation) : operation(&operation) {}
|
||
|
~PyOperationRef();
|
||
|
static void bind(py::module m);
|
||
|
Operation *getOperation() override;
|
||
|
|
||
|
Operation *operation;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around SymbolTable.
|
||
|
struct PySymbolTable {
|
||
|
PySymbolTable(SymbolTable &symbolTable) : symbolTable(symbolTable) {}
|
||
|
static void bind(py::module m);
|
||
|
SymbolTable &symbolTable;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around Value.
|
||
|
struct PyValue {
|
||
|
PyValue(Value value) : value(value) { assert(value); }
|
||
|
static void bind(py::module m);
|
||
|
operator Value() { return value; }
|
||
|
Value value;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around Attribute.
|
||
|
struct PyAttribute {
|
||
|
PyAttribute(Attribute attr) : attr(attr) { assert(attr); }
|
||
|
static void bind(py::module m);
|
||
|
Attribute attr;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around MLIRContext.
|
||
|
struct PyContext : std::enable_shared_from_this<PyContext> {
|
||
|
static void bind(py::module m);
|
||
|
PyModuleOp parseAsm(const std::string &asm_text);
|
||
|
MLIRContext context;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around a Block&.
|
||
|
struct PyBlockRef {
|
||
|
PyBlockRef(Block &block) : block(block) {}
|
||
|
static void bind(py::module m);
|
||
|
Block █
|
||
|
};
|
||
|
|
||
|
/// Wrapper around a Region&.
|
||
|
struct PyRegionRef {
|
||
|
PyRegionRef(Region ®ion) : region(region) {}
|
||
|
static void bind(py::module m);
|
||
|
Region ®ion;
|
||
|
};
|
||
|
|
||
|
struct PyType {
|
||
|
PyType() = default;
|
||
|
PyType(Type type) : type(type) {}
|
||
|
static void bind(py::module m);
|
||
|
operator Type() { return type; }
|
||
|
Type type;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around an OpBuilder reference.
|
||
|
/// This class is inherently dangerous because it does not track ownership
|
||
|
/// of IR objects that it may be operating on and incorrect usage can cause
|
||
|
/// memory access errors, just as it can in C++. It is intended for use by
|
||
|
/// higher level constructs that are specifically coded to satisfy object
|
||
|
/// lifetime needs.
|
||
|
class PyBaseOpBuilder {
|
||
|
public:
|
||
|
virtual ~PyBaseOpBuilder();
|
||
|
static void bind(py::module m);
|
||
|
virtual OpBuilder &getBuilder(bool requirePosition = false) = 0;
|
||
|
};
|
||
|
|
||
|
/// Wrapper around an instance of an OpBuilder.
|
||
|
class PyOpBuilder : public PyBaseOpBuilder {
|
||
|
public:
|
||
|
PyOpBuilder(PyContext &context) : builder(&context.context) {}
|
||
|
~PyOpBuilder() override;
|
||
|
static void bind(py::module m);
|
||
|
OpBuilder &getBuilder(bool requirePosition = false) override;
|
||
|
|
||
|
private:
|
||
|
OpBuilder builder;
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
// Custom types
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Helper for creating (possibly dialect specific) IR objects. This class
|
||
|
/// is intended to be subclassed on the Python side (possibly with multiple
|
||
|
/// inheritance) to provide Python level APIs for custom dialects. The base
|
||
|
/// class contains helpers for std types and ops.
|
||
|
class PyDialectHelper {
|
||
|
public:
|
||
|
PyDialectHelper(std::shared_ptr<PyContext> context)
|
||
|
: pyOpBuilder(*context), context(std::move(context)) {}
|
||
|
static void bind(py::module m);
|
||
|
|
||
|
protected:
|
||
|
PyOpBuilder pyOpBuilder;
|
||
|
std::shared_ptr<PyContext> context;
|
||
|
};
|
||
|
|
||
|
} // namespace mlir
|
||
|
|
||
|
#endif // NPCOMP_PYTHON_NATIVE_MLIR_IR_H
|