Add enough python bindings to build functions and ufunc calls.

pull/1/head
Stella Laurenzo 2020-05-01 18:44:06 -07:00
parent ba0c96b51a
commit 78a8e6ec9e
3 changed files with 473 additions and 22 deletions

View File

@ -2,14 +2,72 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from npcomp.native.mlir.ir import *
from npcomp.native.mlir import ir
__all__ = [
"load_builtin_module",
"Types",
]
def load_builtin_module(context = None):
class Ops(ir.Ops):
r"""Dialect ops.
>>> c = ir.MLIRContext()
>>> t = Types(c)
>>> m = c.new_module()
>>> tensor_type = t.tensor(t.f32)
>>> ops = Ops(c)
>>> ops.builder.insert_block_start(m.first_block)
>>> f = ops.func_op("foobar", t.function(
... [tensor_type, tensor_type], [tensor_type]),
... create_entry_block=True)
>>> uf = ops.numpy_ufunc_call("numpy.add", tensor_type,
... *f.first_block.args)
>>> _ = ops.return_op(uf.results)
>>> print(m.to_asm())
<BLANKLINE>
<BLANKLINE>
module {
func @foobar(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = numpy.ufunc_call @numpy.add(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
}
"""
def numpy_ufunc_call(self, callee_symbol, result_type, *args):
"""Creates a numpy.ufunc_call op."""
c = self.context
attrs = c.dictionary_attr({
"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)
})
return self.op("numpy.ufunc_call", [result_type], args, attrs)
class Types(ir.Types):
"""Container/factory for dialect types.
>>> t = Types(ir.MLIRContext())
>>> t.any_dtype
!numpy.any_dtype
>>> t.tensor(t.any_dtype, [1, 2, 3])
tensor<1x2x3x!numpy.any_dtype>
>>> t.tensor(t.any_dtype)
tensor<*x!numpy.any_dtype>
>>> t.tensor(t.any_dtype, [-1, 2])
tensor<?x2x!numpy.any_dtype>
>>> t.tensor(t.f32)
tensor<*xf32>
>>> t.function([t.i32], [t.f32])
(i32) -> f32
"""
def __init__(self, context):
super().__init__(context)
self.any_dtype = context.parse_type("!numpy.any_dtype")
def load_builtin_module(context=None):
"""Loads a module populated with numpy built-ins.
This is not a long-term solution but overcomes some bootstrapping
@ -17,15 +75,17 @@ def load_builtin_module(context = None):
>>> m = load_builtin_module()
>>> op = m.region(0).blocks.front.operations.front
>>> print(op.name)
numpy.generic_ufunc
>>> op.is_registered
True
>>> op.name
'numpy.generic_ufunc'
Args:
context: The MLIRContext to use (None to create a new one).
Returns:
A ModuleOp.
"""
if context is None: context = MLIRContext()
if context is None: context = ir.MLIRContext()
return context.parse_asm(_BUILTIN_MODULE_ASM)

View File

@ -8,8 +8,12 @@
#include "mlir_ir.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@ -112,12 +116,184 @@ private:
mlir::DiagnosticEngine::HandlerID handler_id;
};
//===----------------------------------------------------------------------===//
// Python only classes
//===----------------------------------------------------------------------===//
class PyOps {
public:
PyOps(std::shared_ptr<PyContext> context)
: pyOpBuilder(*context), context(std::move(context)) {}
static void bind(py::module m) {
py::class_<PyOps>(m, "Ops")
.def(py::init<std::shared_ptr<PyContext>>())
.def_property_readonly(
"builder",
[](PyOps &self) -> PyBaseOpBuilder & { return self.pyOpBuilder; })
.def_property_readonly("context",
[](PyOps &self) -> std::shared_ptr<PyContext> {
return self.context;
})
.def("op",
[](PyOps &self, const std::string &opNameStr,
std::vector<PyType> pyResultTypes,
std::vector<PyValue> pyOperands,
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
Location loc = UnknownLoc::get(opBuilder.getContext());
OperationName opName(opNameStr, opBuilder.getContext());
SmallVector<Type, 4> types(pyResultTypes.begin(),
pyResultTypes.end());
SmallVector<Value, 4> operands(pyOperands.begin(),
pyOperands.end());
NamedAttributeList attrList;
if (attrs) {
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
if (!dictAttrs) {
throw py::raiseValueError(
"Expected `attrs` to be a DictionaryAttr");
}
attrList = NamedAttributeList(dictAttrs);
}
Operation *op =
Operation::create(loc, opName, types, operands, attrList);
opBuilder.insert(op);
return op;
},
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
py::arg("attrs") = llvm::Optional<PyAttribute>())
.def("func_op",
[](PyOps &self, const std::string &name, PyType type,
bool createEntryBlock) {
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
if (!functionType) {
throw py::raiseValueError("Illegal function type");
}
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = UnknownLoc::get(opBuilder.getContext());
// TODO: Add function and arg/result attributes.
FuncOp op = opBuilder.create<FuncOp>(
loc, StringRef(name), functionType,
/*attrs=*/ArrayRef<NamedAttribute>());
if (createEntryBlock) {
Block *entryBlock = new Block();
entryBlock->addArguments(functionType.getInputs());
op.getBody().push_back(entryBlock);
opBuilder.setInsertionPointToStart(entryBlock);
}
return PyOperationRef(op);
},
py::arg("name"), py::arg("type"),
py::arg("create_entry_block") = false,
R"(Creates a new `func` op, optionally creating an entry block.
If an entry block is created, the builder will be positioned
to its start.)")
.def("return_op", [](PyOps &self, std::vector<PyValue> pyOperands) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = UnknownLoc::get(opBuilder.getContext());
SmallVector<Value, 4> operands(pyOperands.begin(), pyOperands.end());
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
});
}
PyOpBuilder pyOpBuilder;
std::shared_ptr<PyContext> context;
};
class PyTypes {
public:
PyTypes(std::shared_ptr<PyContext> context) : context(std::move(context)) {}
static void bind(py::module m) {
py::class_<PyTypes>(m, "Types")
.def(py::init<std::shared_ptr<PyContext>>())
.def_property_readonly("context",
[](PyTypes &self) { return self.context; })
.def("integer",
[](PyTypes &self, unsigned width) {
return PyType(IntegerType::get(width, &self.context->context));
},
py::arg("width") = 32)
.def_property_readonly(
"i1",
[](PyTypes &self) {
return PyType(IntegerType::get(1, &self.context->context));
})
.def_property_readonly(
"i16",
[](PyTypes &self) {
return PyType(IntegerType::get(32, &self.context->context));
})
.def_property_readonly(
"i32",
[](PyTypes &self) {
return PyType(IntegerType::get(32, &self.context->context));
})
.def_property_readonly(
"i64",
[](PyTypes &self) {
return PyType(IntegerType::get(64, &self.context->context));
})
.def_property_readonly(
"f32",
[](PyTypes &self) {
return PyType(
FloatType::get(StandardTypes::F32, &self.context->context));
})
.def_property_readonly(
"f64",
[](PyTypes &self) {
return PyType(
FloatType::get(StandardTypes::F64, &self.context->context));
})
.def("tensor",
[](PyTypes &self, PyType elementType,
llvm::Optional<std::vector<int64_t>> shape) {
if (!elementType.type) {
throw py::raiseValueError("Null element type");
}
if (shape) {
return PyType(RankedTensorType::get(*shape, elementType.type));
} else {
return PyType(UnrankedTensorType::get(elementType.type));
}
},
py::arg("element_type"),
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
.def("function", [](PyTypes &self, std::vector<PyType> inputs,
std::vector<PyType> results) {
llvm::SmallVector<Type, 4> inputTypes;
llvm::SmallVector<Type, 1> resultTypes;
for (auto input : inputs) {
inputTypes.push_back(input.type);
}
for (auto result : results) {
resultTypes.push_back(result.type);
}
return PyType(FunctionType::get(inputTypes, resultTypes,
&self.context->context));
});
}
private:
std::shared_ptr<PyContext> context;
};
//===----------------------------------------------------------------------===//
// Module initialization
//===----------------------------------------------------------------------===//
void defineMlirIrModule(py::module m) {
m.doc() = "Python bindings for constructs in the mlir/IR library";
// Python only types.
PyOps::bind(m);
PyTypes::bind(m);
// Utility types.
PyBlockList::bind(m, "BlockList");
PyOperationList::bind(m, "OperationList");
// Wrapper types.
PyAttribute::bind(m);
PyBaseOperation::bind(m);
PyBaseOpBuilder::bind(m);
PyBlockRef::bind(m);
@ -126,6 +302,9 @@ void defineMlirIrModule(py::module m) {
PyOperationRef::bind(m);
PyOpBuilder::bind(m);
PyRegionRef::bind(m);
PySymbolTable::bind(m);
PyType::bind(m);
PyValue::bind(m);
}
//===----------------------------------------------------------------------===//
@ -139,10 +318,58 @@ void PyContext::bind(py::module m) {
return std::make_shared<PyContext>();
}))
.def("new_module",
[&](PyContext &context) -> PyModuleOp {
return PyModuleOp(context.shared_from_this(), {});
[&](PyContext &self) -> PyModuleOp {
Location loc = UnknownLoc::get(&self.context);
auto m = ModuleOp::create(loc);
return PyModuleOp(self.shared_from_this(), m);
})
.def("parse_asm", &PyContext::parseAsm);
.def("parse_asm", &PyContext::parseAsm)
.def("new_builder",
[](PyContext &self) {
// Note: we collapse the Builder and OpBuilder into one because
// there is little reason to expose the inheritance hierarchy to
// Python.
return PyOpBuilder(self);
})
// Salient functions from Builder.
.def("parse_type",
[](PyContext &self, const std::string &asmText) {
Type t = parseType(asmText, &self.context);
if (!t) {
std::string message = "Unable to parse MLIR type: ";
message.append(asmText);
throw py::raiseValueError(message);
}
return PyType(t);
})
.def("string_attr",
[](PyContext &self, const std::string &s) -> PyAttribute {
return StringAttr::get(s, &self.context);
})
.def("bytes_attr",
[](PyContext &self, py::bytes bytes) -> PyAttribute {
char *buffer;
ssize_t length;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bytes.ptr(), &buffer,
&length)) {
throw py::raiseValueError("Cannot extract bytes");
}
return StringAttr::get(StringRef(buffer, length), &self.context);
})
.def("flat_symbol_ref_attr",
[](PyContext &self, const std::string &s) -> PyAttribute {
return FlatSymbolRefAttr::get(s, &self.context);
})
.def("dictionary_attr", [](PyContext &self, py::dict d) -> PyAttribute {
SmallVector<NamedAttribute, 4> attrs;
for (auto &it : d) {
auto key = it.first.cast<std::string>();
auto value = it.second.cast<PyAttribute>();
auto keyIdent = Identifier::get(key, &self.context);
attrs.emplace_back(keyIdent, value.attr);
}
return DictionaryAttr::get(attrs, &self.context);
});
}
PyModuleOp PyContext::parseAsm(const std::string &asm_text) {
@ -183,13 +410,33 @@ void PyBaseOperation::bind(py::module m) {
[](PyBaseOperation &self) {
return self.getOperation()->getNumRegions();
})
.def("region", [](PyBaseOperation &self, int index) {
auto *op = self.getOperation();
if (index < 0 || index >= op->getNumRegions()) {
throw py::raisePyError(PyExc_IndexError,
"Region index out of bounds");
.def_property_readonly("results",
[](PyBaseOperation &self) {
auto *op = self.getOperation();
std::vector<PyValue> results(op->result_begin(),
op->result_end());
return results;
})
.def("region",
[](PyBaseOperation &self, int index) {
auto *op = self.getOperation();
if (index < 0 || index >= op->getNumRegions()) {
throw py::raisePyError(PyExc_IndexError,
"Region index out of bounds");
}
return PyRegionRef(op->getRegion(index));
})
.def_property_readonly("first_block", [](PyBaseOperation &self) {
Operation *op = self.getOperation();
assert(op);
if (op->getNumRegions() == 0) {
throw py::raiseValueError("Op has no regions");
}
return PyRegionRef(op->getRegion(index));
auto &region = op->getRegion(0);
if (region.empty()) {
throw py::raiseValueError("Op has no blocks");
}
return PyBlockRef(region.front());
});
}
@ -211,6 +458,8 @@ Operation *PyOperationRef::getOperation() { return operation; }
PyModuleOp::~PyModuleOp() = default;
void PyModuleOp::bind(py::module m) {
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
.def_property_readonly("context",
[](PyModuleOp &self) { return self.context; })
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
}
@ -313,6 +562,25 @@ void printLocation(Location loc, raw_ostream &out) {
}
}
//===----------------------------------------------------------------------===//
// PySymbolTable
//===----------------------------------------------------------------------===//
void PySymbolTable::bind(py::module m) {
py::class_<PySymbolTable>(m, "SymbolTable")
.def_property_readonly_static("symbol_attr_name",
[](const py::object &) {
auto sr =
SymbolTable::getSymbolAttrName();
return py::str(sr.data(), sr.size());
})
.def_property_readonly_static(
"visibility_attr_name", [](const py::object &) {
auto sr = SymbolTable::getVisibilityAttrName();
return py::str(sr.data(), sr.size());
});
}
//===----------------------------------------------------------------------===//
// DiagnosticCapture
//===----------------------------------------------------------------------===//
@ -364,8 +632,14 @@ DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
void PyBlockRef::bind(py::module m) {
py::class_<PyBlockRef>(m, "BlockRef")
.def_property_readonly("operations", [](PyBlockRef &self) {
return PyOperationList(self.block.getOperations());
.def_property_readonly("operations",
[](PyBlockRef &self) {
return PyOperationList(
self.block.getOperations());
})
.def_property_readonly("args", [](PyBlockRef &self) {
return std::vector<PyValue>(self.block.args_begin(),
self.block.args_end());
});
}
@ -380,13 +654,62 @@ void PyRegionRef::bind(py::module m) {
});
}
//===----------------------------------------------------------------------===//
// PyType
//===----------------------------------------------------------------------===//
void PyType::bind(py::module m) {
py::class_<PyType>(m, "Type").def("__repr__",
[](PyType &self) -> std::string {
if (!self.type)
return "<undefined type>";
std::string res;
llvm::raw_string_ostream os(res);
self.type.print(os);
return res;
});
}
//===----------------------------------------------------------------------===//
// PyValue
//===----------------------------------------------------------------------===//
void PyValue::bind(py::module m) {
py::class_<PyValue>(m, "Value").def("__repr__", [](PyValue &self) {
std::string res;
llvm::raw_string_ostream os(res);
os << self.value;
return res;
});
}
//===----------------------------------------------------------------------===//
// PyAttribute
//===----------------------------------------------------------------------===//
void PyAttribute::bind(py::module m) {
py::class_<PyAttribute>(m, "Attribute")
.def("__repr__", [](PyAttribute &self) {
std::string res;
llvm::raw_string_ostream os(res);
os << self.attr;
return res;
});
}
//===----------------------------------------------------------------------===//
// OpBuilder implementations
//===----------------------------------------------------------------------===//
PyBaseOpBuilder::~PyBaseOpBuilder() = default;
PyOpBuilder::~PyOpBuilder() = default;
OpBuilder &PyOpBuilder::getBuilder() { return builder; }
OpBuilder &PyOpBuilder::getBuilder(bool requirePosition) {
if (!builder.getBlock()) {
throw py::raisePyError(PyExc_IndexError, "Insertion point not set");
}
return builder;
}
void PyBaseOpBuilder::bind(py::module m) {
py::class_<PyBaseOpBuilder>(m, "BaseOpBuilder");
@ -394,7 +717,40 @@ void PyBaseOpBuilder::bind(py::module m) {
void PyOpBuilder::bind(py::module m) {
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
.def(py::init<PyContext &>());
.def(py::init<PyContext &>())
.def("clear_insertion_point",
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
.def("insert_op_before",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPoint(op);
},
"Sets the insertion point to just before the specified op.")
.def("insert_op_after",
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
Operation *op = pyOp.getOperation();
self.builder.setInsertionPointAfter(op);
},
"Sets the insertion point to just after the specified op.")
.def("insert_block_start",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToStart(&block.block);
},
"Sets the insertion point to the start of the block.")
.def("insert_block_end",
[](PyOpBuilder &self, PyBlockRef block) {
self.builder.setInsertionPointToEnd(&block.block);
},
"Sets the insertion point to the end of the block.")
.def("insert_before_terminator",
[](PyOpBuilder &self, PyBlockRef block) {
auto *terminator = block.block.getTerminator();
if (!terminator) {
throw py::raiseValueError("Block has no terminator");
}
self.builder.setInsertionPoint(terminator);
},
"Sets the insertion point to just before the block terminator.");
}
} // namespace mlir

View File

@ -17,6 +17,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
@ -50,7 +51,9 @@ struct PyBaseOperation {
/// Wrapper around Module, capturing a PyContext reference.
struct PyModuleOp : PyBaseOperation {
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
: context(context), moduleOp(moduleOp) {}
: context(context), moduleOp(moduleOp) {
assert(moduleOp);
}
~PyModuleOp();
static void bind(py::module m);
Operation *getOperation() override;
@ -63,7 +66,9 @@ struct PyModuleOp : PyBaseOperation {
/// Wrapper around an Operation*.
struct PyOperationRef : PyBaseOperation {
PyOperationRef(Operation *operation) : operation(operation) {}
PyOperationRef(Operation *operation) : operation(operation) {
assert(operation);
}
PyOperationRef(Operation &operation) : operation(&operation) {}
~PyOperationRef();
static void bind(py::module m);
@ -72,6 +77,28 @@ struct PyOperationRef : PyBaseOperation {
Operation *operation;
};
/// Wrapper around SymbolTable.
struct PySymbolTable {
PySymbolTable(SymbolTable &symbolTable) : symbolTable(symbolTable) {}
static void bind(py::module m);
SymbolTable &symbolTable;
};
/// Wrapper around Value.
struct PyValue {
PyValue(Value value) : value(value) { assert(value); }
static void bind(py::module m);
operator Value() { return value; }
Value value;
};
/// Wrapper around Attribute.
struct PyAttribute {
PyAttribute(Attribute attr) : attr(attr) { assert(attr); }
static void bind(py::module m);
Attribute attr;
};
/// Wrapper around MLIRContext.
struct PyContext : std::enable_shared_from_this<PyContext> {
static void bind(py::module m);
@ -93,6 +120,14 @@ struct PyRegionRef {
Region &region;
};
struct PyType {
PyType() = default;
PyType(Type type) : type(type) {}
static void bind(py::module m);
operator Type() { return type; }
Type type;
};
/// Wrapper around an OpBuilder reference.
/// This class is inherently dangerous because it does not track ownership
/// of IR objects that it may be operating on and incorrect usage can cause
@ -103,7 +138,7 @@ class PyBaseOpBuilder {
public:
virtual ~PyBaseOpBuilder();
static void bind(py::module m);
virtual OpBuilder &getBuilder() = 0;
virtual OpBuilder &getBuilder(bool requirePosition = false) = 0;
};
/// Wrapper around an instance of an OpBuilder.
@ -112,7 +147,7 @@ public:
PyOpBuilder(PyContext &context) : builder(&context.context) {}
~PyOpBuilder() override;
static void bind(py::module m);
OpBuilder &getBuilder() override;
OpBuilder &getBuilder(bool requirePosition = false) override;
private:
OpBuilder builder;