mirror of https://github.com/llvm/torch-mlir

Merge ir.Ops and ir.Types into ir.DialectHelper.

This will aid in managing hierarchies of custom dialect helpers.

parent aa9ffc3a11
commit 4ebf972503
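As a sketch of the subclassing pattern this enables (the dialect name, class, and type below are hypothetical illustrations, not part of this change):

from npcomp.native.mlir import ir

class CustomDialectHelper(ir.DialectHelper):
  """Hypothetical helper for a custom dialect."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Hypothetical custom type; assumes the context can parse it.
    self.custom_any_dtype = self.context.parse_type("!custom.any_dtype")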
@@ -6,25 +6,24 @@ from npcomp.native.mlir import ir

 __all__ = [
     "load_builtin_module",
-    "Types",
+    "DialectHelper",
 ]


-class Ops(ir.Ops):
-  r"""Dialect ops.
+class DialectHelper(ir.DialectHelper):
+  r"""Dialect helper.

   >>> c = ir.MLIRContext()
-  >>> t = Types(c)
+  >>> h = DialectHelper(c)
   >>> m = c.new_module()
-  >>> tensor_type = t.tensor(t.f32)
-  >>> ops = Ops(c)
-  >>> ops.builder.insert_block_start(m.first_block)
-  >>> f = ops.func_op("foobar", t.function(
+  >>> tensor_type = h.tensor_type(h.f32_type)
+  >>> h.builder.insert_block_start(m.first_block)
+  >>> f = h.func_op("foobar", h.function_type(
   ...   [tensor_type, tensor_type], [tensor_type]),
   ...   create_entry_block=True)
-  >>> uf = ops.numpy_ufunc_call_op("numpy.add", tensor_type,
+  >>> uf = h.numpy_ufunc_call_op("numpy.add", tensor_type,
   ...   *f.first_block.args)
-  >>> _ = ops.return_op(uf.results)
+  >>> _ = h.return_op(uf.results)
   >>> print(m.to_asm())
   <BLANKLINE>
   <BLANKLINE>
@@ -34,7 +33,27 @@ class Ops(ir.Ops):
     return %0 : tensor<*xf32>
   }
   }

+  Types:
+  >>> t = DialectHelper(ir.MLIRContext())
+  >>> t.numpy_any_dtype
+  !numpy.any_dtype
+  >>> t.tensor_type(t.numpy_any_dtype, [1, 2, 3])
+  tensor<1x2x3x!numpy.any_dtype>
+  >>> t.tensor_type(t.numpy_any_dtype)
+  tensor<*x!numpy.any_dtype>
+  >>> t.tensor_type(t.numpy_any_dtype, [-1, 2])
+  tensor<?x2x!numpy.any_dtype>
+  >>> t.tensor_type(t.f32_type)
+  tensor<*xf32>
+  >>> t.function_type([t.i32_type], [t.f32_type])
+  (i32) -> f32
+
   """
+
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
+
   def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
     """Creates a numpy.ufunc_call op."""
     c = self.context
@@ -43,34 +62,11 @@ class Ops(ir.Ops):
     })
     return self.op("numpy.ufunc_call", [result_type], args, attrs)

-  def numpy_narrow(self, result_type, operand):
+  def numpy_narrow_op(self, result_type, operand):
     """Creates a numpy.narrow op."""
     return self.op("numpy.narrow", [result_type], [operand])


-class Types(ir.Types):
-  """Container/factory for dialect types.
-
-  >>> t = Types(ir.MLIRContext())
-  >>> t.numpy_any_dtype
-  !numpy.any_dtype
-  >>> t.tensor(t.numpy_any_dtype, [1, 2, 3])
-  tensor<1x2x3x!numpy.any_dtype>
-  >>> t.tensor(t.numpy_any_dtype)
-  tensor<*x!numpy.any_dtype>
-  >>> t.tensor(t.numpy_any_dtype, [-1, 2])
-  tensor<?x2x!numpy.any_dtype>
-  >>> t.tensor(t.f32)
-  tensor<*xf32>
-  >>> t.function([t.i32], [t.f32])
-  (i32) -> f32
-
-  """
-  def __init__(self, context):
-    super().__init__(context)
-    self.numpy_any_dtype = context.parse_type("!numpy.any_dtype")
-
-
 def load_builtin_module(context=None):
   """Loads a module populated with numpy built-ins.
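The merged helper also composes across dialects; a minimal sketch, assuming two hypothetical dialect helpers (the foo and bar names and types are illustrative only):

from npcomp.native.mlir import ir

class FooDialectHelper(ir.DialectHelper):
  """Hypothetical helper contributing a foo type."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.foo_any = self.context.parse_type("!foo.any")

class BarDialectHelper(ir.DialectHelper):
  """Hypothetical helper contributing a bar type."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.bar_ref = self.context.parse_type("!bar.ref")

class FooBarDialectHelper(FooDialectHelper, BarDialectHelper):
  """Cooperative super().__init__ calls initialize both mixins."""

Because every __init__ forwards through super(), Python's MRO runs both mixin initializers before reaching the native base class, which is the kind of helper hierarchy this change is aimed at.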
@@ -120,22 +120,29 @@ private:
 // Python only classes
 //===----------------------------------------------------------------------===//

-class PyOps {
+/// Helper for creating (possibly dialect specific) IR objects. This class
+/// is intended to be subclassed on the Python side (possibly with multiple
+/// inheritance) to provide Python level APIs for custom dialects. The base
+/// class contains helpers for std types and ops.
+class PyDialectHelper {
 public:
-  PyOps(std::shared_ptr<PyContext> context)
-      : pyOpBuilder(*context), context(std::move(context)) {}
+  PyDialectHelper(std::shared_ptr<PyContext> context)
+      : pyOpBuilder(*context), context(std::move(context)) {
+  }
   static void bind(py::module m) {
-    py::class_<PyOps>(m, "Ops")
+    py::class_<PyDialectHelper>(m, "DialectHelper")
         .def(py::init<std::shared_ptr<PyContext>>())
-        .def_property_readonly(
-            "builder",
-            [](PyOps &self) -> PyBaseOpBuilder & { return self.pyOpBuilder; })
-        .def_property_readonly("context",
-                               [](PyOps &self) -> std::shared_ptr<PyContext> {
-                                 return self.context;
-                               })
+        .def_property_readonly("builder",
+                               [](PyDialectHelper &self) -> PyBaseOpBuilder & {
+                                 return self.pyOpBuilder;
+                               })
+        .def_property_readonly(
+            "context",
+            [](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
+              return self.context;
+            })
         .def("op",
-             [](PyOps &self, const std::string &opNameStr,
+             [](PyDialectHelper &self, const std::string &opNameStr,
                 std::vector<PyType> pyResultTypes,
                 std::vector<PyValue> pyOperands,
                 llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
@@ -163,7 +170,7 @@ public:
              py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
              py::arg("attrs") = llvm::Optional<PyAttribute>())
         .def("func_op",
-             [](PyOps &self, const std::string &name, PyType type,
+             [](PyDialectHelper &self, const std::string &name, PyType type,
                 bool createEntryBlock) {
                auto functionType = type.type.dyn_cast_or_null<FunctionType>();
                if (!functionType) {
@@ -188,64 +195,55 @@ public:
              R"(Creates a new `func` op, optionally creating an entry block.
              If an entry block is created, the builder will be positioned
              to its start.)")
-        .def("return_op", [](PyOps &self, std::vector<PyValue> pyOperands) {
-          OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
-          Location loc = UnknownLoc::get(opBuilder.getContext());
-          SmallVector<Value, 4> operands(pyOperands.begin(), pyOperands.end());
-          return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
-        });
-  }
-  PyOpBuilder pyOpBuilder;
-  std::shared_ptr<PyContext> context;
-};
-
-class PyTypes {
-public:
-  PyTypes(std::shared_ptr<PyContext> context) : context(std::move(context)) {}
-  static void bind(py::module m) {
-    py::class_<PyTypes>(m, "Types")
-        .def(py::init<std::shared_ptr<PyContext>>())
-        .def_property_readonly("context",
-                               [](PyTypes &self) { return self.context; })
-        .def("integer",
-             [](PyTypes &self, unsigned width) {
+        .def("return_op",
+             [](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
+               OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
+               Location loc = UnknownLoc::get(opBuilder.getContext());
+               SmallVector<Value, 4> operands(pyOperands.begin(),
+                                              pyOperands.end());
+               return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
+             })
+        // Types.
+        .def("integer_type",
+             [](PyDialectHelper &self, unsigned width) {
               return PyType(IntegerType::get(width, &self.context->context));
             },
             py::arg("width") = 32)
        .def_property_readonly(
-            "i1",
-            [](PyTypes &self) {
+            "i1_type",
+            [](PyDialectHelper &self) {
              return PyType(IntegerType::get(1, &self.context->context));
            })
        .def_property_readonly(
-            "i16",
-            [](PyTypes &self) {
+            "i16_type",
+            [](PyDialectHelper &self) {
              return PyType(IntegerType::get(16, &self.context->context));
            })
        .def_property_readonly(
-            "i32",
-            [](PyTypes &self) {
+            "i32_type",
+            [](PyDialectHelper &self) {
              return PyType(IntegerType::get(32, &self.context->context));
            })
        .def_property_readonly(
-            "i64",
-            [](PyTypes &self) {
+            "i64_type",
+            [](PyDialectHelper &self) {
              return PyType(IntegerType::get(64, &self.context->context));
            })
        .def_property_readonly(
-            "f32",
-            [](PyTypes &self) {
+            "f32_type",
+            [](PyDialectHelper &self) {
              return PyType(
                  FloatType::get(StandardTypes::F32, &self.context->context));
            })
        .def_property_readonly(
-            "f64",
-            [](PyTypes &self) {
+            "f64_type",
+            [](PyDialectHelper &self) {
              return PyType(
                  FloatType::get(StandardTypes::F64, &self.context->context));
            })
-        .def("tensor",
-             [](PyTypes &self, PyType elementType,
+        .def("tensor_type",
+             [](PyDialectHelper &self, PyType elementType,
                llvm::Optional<std::vector<int64_t>> shape) {
              if (!elementType.type) {
                throw py::raiseValueError("Null element type");
@@ -258,22 +256,22 @@ public:
             },
             py::arg("element_type"),
             py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
-        .def("function", [](PyTypes &self, std::vector<PyType> inputs,
-                            std::vector<PyType> results) {
-          llvm::SmallVector<Type, 4> inputTypes;
-          llvm::SmallVector<Type, 1> resultTypes;
-          for (auto input : inputs) {
-            inputTypes.push_back(input.type);
-          }
-          for (auto result : results) {
-            resultTypes.push_back(result.type);
-          }
-          return PyType(FunctionType::get(inputTypes, resultTypes,
-                                          &self.context->context));
-        });
+        .def("function_type",
+             [](PyDialectHelper &self, std::vector<PyType> inputs,
+                std::vector<PyType> results) {
+               llvm::SmallVector<Type, 4> inputTypes;
+               llvm::SmallVector<Type, 1> resultTypes;
+               for (auto input : inputs) {
+                 inputTypes.push_back(input.type);
+               }
+               for (auto result : results) {
+                 resultTypes.push_back(result.type);
+               }
+               return PyType(FunctionType::get(inputTypes, resultTypes,
+                                               &self.context->context));
+             });
   }

 private:
+  PyOpBuilder pyOpBuilder;
   std::shared_ptr<PyContext> context;
 };
@@ -285,8 +283,7 @@ void defineMlirIrModule(py::module m) {
   m.doc() = "Python bindings for constructs in the mlir/IR library";

   // Python only types.
-  PyOps::bind(m);
-  PyTypes::bind(m);
+  PyDialectHelper::bind(m);

   // Utility types.
   PyBlockList::bind(m, "BlockList");
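Beyond the named helpers, the generic `op` binding above can emit any registered operation by name. A minimal sketch from the Python side, assuming `h` is a DialectHelper whose builder is positioned in a function body and `operand` is an existing SSA value (both placeholders for illustration):

result_type = h.tensor_type(h.f32_type)
narrow = h.op("numpy.narrow", [result_type], [operand])
return_operands = list(narrow.results)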
@@ -36,7 +36,8 @@ class TraceInvocation(


 class EmissionRequest(
-    namedtuple("EmissionRequest", ["input_ssa_values", "ops", "types", "extra"],
+    namedtuple("EmissionRequest",
+               ["input_ssa_values", "dialect_helper", "extra"],
                defaults=(None,))):
   """Represents the result of processing inputs from an invocation.
@@ -47,8 +48,7 @@ class EmissionRequest(
   blackbox mechanism to transfer un-tracked state from an invocation to
   emission.

-  The `ops` and `types` fields correspond to mlir.ir.Ops and mlir.ir.Types
-  instances respectively.
+  The `dialect_helper` field corresponds to an mlir.ir.DialectHelper instance.
   """
   __slots__ = ()
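A minimal sketch of constructing the new-style request, assuming `h` is a DialectHelper and `input_ssa_values` has already been resolved (placeholders for illustration); `extra` keeps its default of None:

request = EmissionRequest(input_ssa_values=input_ssa_values,
                          dialect_helper=h)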
@@ -166,9 +166,10 @@ class GenericCallUfuncEmitter(FuncEmitter):
     return py_results[0]

   def emit(self, request: EmissionRequest):
-    op_result_type = request.types.tensor(request.types.numpy_any_dtype)
-    call_op = request.ops.numpy_ufunc_call_op(self._ufunc_name, op_result_type,
-                                              *request.input_ssa_values)
+    h = request.dialect_helper
+    op_result_type = h.tensor_type(h.numpy_any_dtype)
+    call_op = h.numpy_ufunc_call_op(self._ufunc_name, op_result_type,
+                                    *request.input_ssa_values)
     return call_op.results
@@ -210,9 +211,10 @@ class GenericArrayFuncEmitter(FuncEmitter):
     return tuple(py_results)

   def emit(self, request: EmissionRequest):
-    op_result_types = [request.types.tensor(request.types.numpy_any_dtype)
+    h = request.dialect_helper
+    op_result_types = [h.tensor_type(h.numpy_any_dtype)
                       ] * self._nresults
-    op = request.ops.op(self._op_name, op_result_types,
+    op = h.op(self._op_name, op_result_types,
               request.input_ssa_values)
     return op.results
@@ -23,8 +23,7 @@ class ModuleBuilder:
     # TODO: Instead of bootstrapping a large module, populate imports
     # dynamically.
     self.module = Numpy.load_builtin_module(self.context)
-    self.ops = Numpy.Ops(self.context)
-    self.types = Numpy.Types(self.context)
+    self.helper = Numpy.DialectHelper(self.context)
     self.emitters = (emitter_registry
                      if emitter_registry else EmitterRegistry.create_default())
@@ -46,13 +45,12 @@ class FunctionTracer(TraceContext):
       "_args_array_params",
       "_f",
       "_f_types",
+      "_helper",
       "_mlir_m",
       "_mlir_c",
       "_python_args",
-      "_ops",
       "_result_array_params",
       "_traced_arrays",
-      "_types",
   ]

   def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
@@ -65,8 +63,7 @@ class FunctionTracer(TraceContext):
     # Alias some parent members for convenience.
     self._mlir_m = module_builder.module
     self._mlir_c = module_builder.context
-    self._ops = module_builder.ops
-    self._types = module_builder.types
+    self._helper = module_builder.helper

     # Extract ArrayParams for all args and results.
     self._args_array_params = [
@@ -86,7 +83,7 @@ class FunctionTracer(TraceContext):
     # TODO: More sophisticated signature merging
     # TODO: Multiple results
     # TODO: Error reporting
-    ops = self._ops
+    h = self._helper
     py_results = (self.epf.pyfunc(*self._python_args),)
     if len(py_results) != len(self._f_types):
       raise TracingError("Traced function returned != %d results: %r" % (
@@ -102,8 +99,8 @@ class FunctionTracer(TraceContext):
         raise TracingError("Unregistered traced array: %r", (py_result,))
       # narrow to declared result type.
       return_operands.extend(
-          ops.numpy_narrow(mlir_result_type, mlir_result).results)
-    ops.return_op(return_operands)
+          h.numpy_narrow_op(mlir_result_type, mlir_result).results)
+    h.return_op(return_operands)

   def set_traced_array(self, traced_array, value):
     """Sets the current SSA value for a traced_array."""
@@ -124,8 +121,7 @@ class FunctionTracer(TraceContext):
   def _create_mlir_function(self):
     mlir_c = self._mlir_c
     mlir_m = self._mlir_m
-    ops = self._ops
-    types = self._types
+    h = self._helper
     epf = self.epf
     f_args = [
         mlir_c.parse_type(ap.mlir_tensor_type_asm)
@@ -134,9 +130,9 @@ class FunctionTracer(TraceContext):
     f_types = [
         mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
     ]
-    ops.builder.insert_before_terminator(mlir_m.first_block)
-    f_type = types.function(f_args, f_types)
-    f = ops.func_op(epf.__name__, f_type, create_entry_block=True)
+    h.builder.insert_before_terminator(mlir_m.first_block)
+    f_type = h.function_type(f_args, f_types)
+    f = h.func_op(epf.__name__, f_type, create_entry_block=True)
     return f, f_types

   def _create_trace_roots(self):
@@ -179,8 +175,7 @@ class FunctionTracer(TraceContext):
     tv_map = emitter.map_invocation(invocation)
     input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
     request = EmissionRequest(input_ssa_values,
-                              ops=self._ops,
-                              types=self._types,
+                              dialect_helper=self._helper,
                               extra=tv_map.extra)
     result_ssa_values = emitter.emit(request)
     py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,