Add numpy.get_slice op and wire it up to the tracer.

pull/1/head
Stella Laurenzo 2020-05-08 16:04:58 -07:00
parent db0b0ef1b2
commit a91b0bfbe1
10 changed files with 225 additions and 60 deletions

View File

@ -72,6 +72,12 @@ public:
StringAttr getClassName();
unsigned getSlotCount();
ArrayRef<Type> getSlotTypes();
// Shorthand to check whether the SlotObject is of a given className and
// arity.
bool isOfClassArity(StringRef className, int arity) {
return getClassName().getValue() == className && getSlotCount() == arity;
}
};
#include "npcomp/Dialect/Basicpy/BasicpyOpsDialect.h.inc"

View File

@ -73,4 +73,19 @@ def Basicpy_SingletonType : AnyTypeOf<[
Basicpy_EllipsisType
]>;
// A predicate to determine whether a Type is a SlotObject of a given
// className and arity. Does no checking of slot types.
class Basicpy_SlotObjectOfClassArity<string className, int arity> :
And<[
Basicpy_SlotObjectType.predicate,
CPred<
"$_self.cast<::mlir::NPCOMP::Basicpy::SlotObjectType>().isOfClassArity(\""
# className # "\", " # arity # ")">
]>;
// Type representing a 'slice' object, which mirrors the Python built-in
// slice class.
def Basicpy_SliceSlotObjectType :
Type<Basicpy_SlotObjectOfClassArity<"slice", 3>>;
#endif // NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT

View File

@ -10,6 +10,7 @@
#define NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
include "mlir/IR/OpBase.td"
include "npcomp/Dialect/Basicpy/BasicpyDialect.td"
//===----------------------------------------------------------------------===//
// Dialect definition
@ -50,6 +51,25 @@ def Numpy_AnyDtype : DialectType<Numpy_Dialect,
// Type predicates
//===----------------------------------------------------------------------===//
// Any type, at any stage of analysis that can represent a numpy array.
def Numpy_AnyArray : TensorOf<[AnyType]>;
def Numpy_SliceTupleElement : AnyTypeOf<[
// Supports both "Index Arrays" and "Boolean mask index arrays".
Numpy_AnyArray,
// Indicates that an axis should be added (np.newaxis == None).
Basicpy_NoneType,
// Indicates that intervening axes should be preserved.
Basicpy_EllipsisType,
// A discrete numeric index (represented as IndexType so that a proper
// width can be target dependent).
Index,
// A generalized slice object.
Basicpy_SliceSlotObjectType,
], "types that are legal elements of a __getitem__ tuple operating on arrays">;
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT

View File

@ -150,4 +150,39 @@ def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
}];
}
//----------------------------------------------------------------------------//
// Slicing
// See: https://docs.scipy.org/doc/numpy/user/basics.indexing.html
//----------------------------------------------------------------------------//
def Numpy_GetSlice : Numpy_Op<"get_slice", []> {
let summary = "Gets a slice of an array";
let description = [{
This op encapsulates all forms of indexing into an array by taking a
variable number of `slice` arguments, each of which represents a single
entry in a generalized indexing-tuple. Once full type inference has
been performed, there should be sufficient static information to determine
the exact slice semantics solely by the signature of types of the `slice`
arguments.
Note that there is a more general form of this op that is generally
needed for AST extraction that takes a variable length `tuple` instead
of a static list of arguments. It is expected that during type refinement
most such uses should degenerate to this static variant.
Per numpy semantics, many forms of slice return a view instead of a copy,
and determining the exact form requires additional analysis.
}];
let arguments = (ins
Numpy_AnyArray:$a,
Variadic<Numpy_SliceTupleElement>:$slice_elements
);
let results = (outs
Numpy_AnyArray:$result
);
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, $result)
}];
}
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS

View File

@ -10,6 +10,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
namespace mlir {

View File

@ -6,13 +6,13 @@ from npcomp.dialect import Basicpy
from _npcomp.mlir import ir
__all__ = [
"load_builtin_module",
"DialectHelper",
"load_builtin_module",
"DialectHelper",
]
class DialectHelper(Basicpy.DialectHelper):
r"""Dialect helper.
r"""Dialect helper.
>>> c = ir.MLIRContext()
>>> h = DialectHelper(c)
@ -49,27 +49,33 @@ class DialectHelper(Basicpy.DialectHelper):
tensor<*xf32>
>>> t.function_type([t.i32_type], [t.f32_type])
(i32) -> f32
>>> t.unknown_array_type
tensor<*x!numpy.any_dtype>
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
self.unknown_array_type = self.tensor_type(self.numpy_any_dtype)
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
"""Creates a numpy.ufunc_call op."""
c = self.context
attrs = c.dictionary_attr({
"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)
})
return self.op("numpy.ufunc_call", [result_type], args, attrs)
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
"""Creates a numpy.ufunc_call op."""
c = self.context
attrs = c.dictionary_attr(
{"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)})
return self.op("numpy.ufunc_call", [result_type], args, attrs)
def numpy_narrow_op(self, result_type, operand):
"""Creates a numpy.narrow op."""
return self.op("numpy.narrow", [result_type], [operand])
def numpy_narrow_op(self, result_type, operand):
"""Creates a numpy.narrow op."""
return self.op("numpy.narrow", [result_type], [operand])
def numpy_get_slice_op(self, result_type, array, *slice_elements):
return self.op("numpy.get_slice", [result_type],
[array] + list(slice_elements))
def load_builtin_module(context=None):
"""Loads a module populated with numpy built-ins.
"""Loads a module populated with numpy built-ins.
This is not a long-term solution but overcomes some bootstrapping
issues.
@ -79,15 +85,15 @@ def load_builtin_module(context=None):
>>> op.is_registered
True
>>> op.name
'numpy.generic_ufunc'
'numpy.builtin_ufunc'
Args:
context: The MLIRContext to use (None to create a new one).
Returns:
A ModuleOp.
"""
if context is None: context = ir.MLIRContext()
return context.parse_asm(_BUILTIN_MODULE_ASM)
if context is None: context = ir.MLIRContext()
return context.parse_asm(_BUILTIN_MODULE_ASM)
_BUILTIN_MODULE_ASM = r"""
@ -96,5 +102,5 @@ _BUILTIN_MODULE_ASM = r"""
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
import doctest
doctest.testmod()

View File

@ -58,6 +58,10 @@ class TraceContext:
self._next_id = 1
self.active = False
def _handle_array_getitem(self, array, key):
"""Handles a call to __getitem__ on a traced array."""
raise NotImplementedError("Array slicing not implemented")
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
"""Handles a ufunc invocation involving at least one TracedArray."""
return NotImplemented
@ -144,6 +148,11 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
def __repr__(self):
return "<TracedArray %d>" % self._uid
def __getitem__(self, key):
tc = self._tc
_assert_active(tc)
return tc._handle_array_getitem(self, key)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
tc = self._tc
_assert_active(tc)

View File

@ -95,8 +95,6 @@ class FunctionTracer(TraceContext):
return_operands = []
for py_result, mlir_result_type in zip(py_results, self._f_types):
mlir_result = self.get_traced_array_value(py_result)
if mlir_result is None:
raise TracingError("Unregistered traced array: %r", (py_result,))
# narrow to declared result type.
return_operands.extend(
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
@ -108,7 +106,10 @@ class FunctionTracer(TraceContext):
self._traced_arrays[traced_array] = value
def get_traced_array_value(self, traced_array):
return self._traced_arrays.get(traced_array)
traced_value = self._traced_arrays.get(traced_array)
if traced_value is None:
raise TracingError("Unregistered traced array: %r", (traced_array,))
return traced_value
def _validate(self):
if not all(
@ -150,9 +151,6 @@ class FunctionTracer(TraceContext):
assert tv.type == TraceValueType.NDARRAY, (
"Unsupported TraceValueType: %r" % tv.type)
ssa_value = self.get_traced_array_value(tv.value)
if ssa_value is None:
raise TracingError(
"Required a traced python NDARRAY but not found: %r" % (tv,))
ssa_values.append(ssa_value)
return ssa_values
@ -196,6 +194,48 @@ class FunctionTracer(TraceContext):
invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
return self._emit_invocation(emitter, invocation)
def _emit_slice_value(self, slice_element):
h = self._helper
if slice_element == None:
return h.basicpy_singleton_op(h.basicpy_NoneType).result
elif slice_element == Ellipsis:
return h.basicpy_singleton_op(h.basicpy_EllipsisType).result
elif isinstance(slice_element, int):
return h.constant_op(h.index_type,
h.context.index_attr(slice_element)).result
elif isinstance(slice_element, slice):
return self._emit_slice_object(slice_element)
else:
# Assume array convertible.
raise NotImplementedError(
"TODO: Slicing with generic arrays not yet implemented")
def _emit_slice_object(self, slice_object: slice):
h = self._helper
def emit_index(index):
if index is None:
return h.basicpy_singleton_op(h.basicpy_NoneType).result
else:
return h.constant_op(h.index_type,
h.context.index_attr(int(index))).result
start = emit_index(slice_object.start)
stop = emit_index(slice_object.stop)
step = emit_index(slice_object.step)
return h.basicpy_slot_object_make_op("slice", start, stop, step).result
def _handle_array_getitem(self, array, key):
h = self._helper
array_value = self.get_traced_array_value(array)
# Array slicing is always based on a tuple.
slice_tuple = key if isinstance(key, tuple) else (key,)
# Resolve and emit each slice element.
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
result_value = h.numpy_get_slice_op(
h.unknown_array_type, array_value, *slice_values).result
result_array = TracedArray(self)
self.set_traced_array(result_array, result_value)
return result_array
if __name__ == "__main__":
import doctest

View File

@ -0,0 +1,18 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def slice_array1(a: np.ndarray) -> np.ndarray:
return a[1, 2:10:2, 3:4, ..., :, 0]
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.slice_array1 = slice_array1
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.slice_array1)
print(mb.module.to_asm())

View File

@ -195,60 +195,71 @@ void PyDialectHelper::bind(py::module m) {
pyOperands.end());
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
})
.def("constant_op",
[](PyDialectHelper &self, PyType type, PyAttribute value) {
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
Location loc = UnknownLoc::get(opBuilder.getContext());
return PyOperationRef(
opBuilder.create<ConstantOp>(loc, type.type, value.attr));
})
// Types.
.def_property_readonly("index_type",
[](PyDialectHelper &self) -> PyType {
return IndexType::get(&self.context->context);
})
.def("integer_type",
[](PyDialectHelper &self, unsigned width) {
return PyType(IntegerType::get(width, &self.context->context));
[](PyDialectHelper &self, unsigned width) -> PyType {
return IntegerType::get(width, &self.context->context);
},
py::arg("width") = 32)
.def_property_readonly("i1_type",
[](PyDialectHelper &self) {
return PyType(
IntegerType::get(1, &self.context->context));
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(1,
&self.context->context);
})
.def_property_readonly("i16_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(32,
&self.context->context);
})
.def_property_readonly("i32_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(32,
&self.context->context);
})
.def_property_readonly("i64_type",
[](PyDialectHelper &self) -> PyType {
return IntegerType::get(64,
&self.context->context);
})
.def_property_readonly(
"i16_type",
[](PyDialectHelper &self) {
return PyType(IntegerType::get(32, &self.context->context));
})
.def_property_readonly(
"i32_type",
[](PyDialectHelper &self) {
return PyType(IntegerType::get(32, &self.context->context));
})
.def_property_readonly(
"i64_type",
[](PyDialectHelper &self) {
return PyType(IntegerType::get(64, &self.context->context));
})
.def_property_readonly("f32_type",
[](PyDialectHelper &self) {
return PyType(FloatType::get(
StandardTypes::F32, &self.context->context));
[](PyDialectHelper &self) -> PyType {
return FloatType::get(StandardTypes::F32,
&self.context->context);
})
.def_property_readonly("f64_type",
[](PyDialectHelper &self) {
return PyType(FloatType::get(
StandardTypes::F64, &self.context->context));
[](PyDialectHelper &self) -> PyType {
return FloatType::get(StandardTypes::F64,
&self.context->context);
})
.def("tensor_type",
[](PyDialectHelper &self, PyType elementType,
llvm::Optional<std::vector<int64_t>> shape) {
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
if (!elementType.type) {
throw py::raiseValueError("Null element type");
}
if (shape) {
return PyType(RankedTensorType::get(*shape, elementType.type));
return RankedTensorType::get(*shape, elementType.type);
} else {
return PyType(UnrankedTensorType::get(elementType.type));
return UnrankedTensorType::get(elementType.type);
}
},
py::arg("element_type"),
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
.def("function_type",
[](PyDialectHelper &self, std::vector<PyType> inputs,
std::vector<PyType> results) {
std::vector<PyType> results) -> PyType {
llvm::SmallVector<Type, 4> inputTypes;
llvm::SmallVector<Type, 1> resultTypes;
for (auto input : inputs) {
@ -257,8 +268,8 @@ void PyDialectHelper::bind(py::module m) {
for (auto result : results) {
resultTypes.push_back(result.type);
}
return PyType(FunctionType::get(inputTypes, resultTypes,
&self.context->context));
return FunctionType::get(inputTypes, resultTypes,
&self.context->context);
});
}
@ -326,6 +337,10 @@ void PyContext::bind(py::module m) {
}
return PyType(t);
})
.def("index_attr",
[](PyContext &self, int64_t indexValue) -> PyAttribute {
return IntegerAttr::get(IndexType::get(&self.context), indexValue);
})
.def("string_attr",
[](PyContext &self, const std::string &s) -> PyAttribute {
return StringAttr::get(s, &self.context);