mirror of https://github.com/llvm/torch-mlir
Add numpy.get_slice op and wire it up to the tracer.
parent
db0b0ef1b2
commit
a91b0bfbe1
|
@ -72,6 +72,12 @@ public:
|
|||
StringAttr getClassName();
|
||||
unsigned getSlotCount();
|
||||
ArrayRef<Type> getSlotTypes();
|
||||
|
||||
// Shorthand to check whether the SlotObject is of a given className and
|
||||
// arity.
|
||||
bool isOfClassArity(StringRef className, int arity) {
|
||||
return getClassName().getValue() == className && getSlotCount() == arity;
|
||||
}
|
||||
};
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyOpsDialect.h.inc"
|
||||
|
|
|
@ -73,4 +73,19 @@ def Basicpy_SingletonType : AnyTypeOf<[
|
|||
Basicpy_EllipsisType
|
||||
]>;
|
||||
|
||||
// A predicate to determine whether a Type is a SlotObject of a given
|
||||
// className and arity. Does no checking of slot types.
|
||||
class Basicpy_SlotObjectOfClassArity<string className, int arity> :
|
||||
And<[
|
||||
Basicpy_SlotObjectType.predicate,
|
||||
CPred<
|
||||
"$_self.cast<::mlir::NPCOMP::Basicpy::SlotObjectType>().isOfClassArity(\""
|
||||
# className # "\", " # arity # ")">
|
||||
]>;
|
||||
|
||||
// Type representing a 'slice' object, which mirrors the Python built-in
|
||||
// slice class.
|
||||
def Basicpy_SliceSlotObjectType :
|
||||
Type<Basicpy_SlotObjectOfClassArity<"slice", 3>>;
|
||||
|
||||
#endif // NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "npcomp/Dialect/Basicpy/BasicpyDialect.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect definition
|
||||
|
@ -50,6 +51,25 @@ def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
|||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Any type, at any stage of analysis that can represent a numpy array.
|
||||
def Numpy_AnyArray : TensorOf<[AnyType]>;
|
||||
|
||||
def Numpy_SliceTupleElement : AnyTypeOf<[
|
||||
// Supports both "Index Arrays" and "Boolean mask index arrays".
|
||||
Numpy_AnyArray,
|
||||
|
||||
// Indicates that an axis should be added (np.newaxis == None).
|
||||
Basicpy_NoneType,
|
||||
|
||||
// Indicates that intervening axes should be preserved.
|
||||
Basicpy_EllipsisType,
|
||||
|
||||
// A discrete numeric index (represented as IndexType so that a proper
|
||||
// width can be target dependent).
|
||||
Index,
|
||||
|
||||
// A generalized slice object.
|
||||
Basicpy_SliceSlotObjectType,
|
||||
], "types that are legal elements of a __getitem__ tuple operating on arrays">;
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||
|
|
|
@ -150,4 +150,39 @@ def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Slicing
|
||||
// See: https://docs.scipy.org/doc/numpy/user/basics.indexing.html
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def Numpy_GetSlice : Numpy_Op<"get_slice", []> {
|
||||
let summary = "Gets a slice of an array";
|
||||
let description = [{
|
||||
This op encapsulates all forms of indexing into an array by taking a
|
||||
variable number of `slice` arguments, each of which represents a single
|
||||
entry in a generalized indexing-tuple. Once full type inference has
|
||||
been performed, there should be sufficient static information to determine
|
||||
the exact slice semantics solely by the signature of types of the `slice`
|
||||
arguments.
|
||||
|
||||
Note that there is a more general form of this op that is generally
|
||||
needed for AST extraction that takes a variable length `tuple` instead
|
||||
of a static list of arguments. It is expected that during type refinement
|
||||
most such uses should degenerate to this static variant.
|
||||
|
||||
Per numpy semantics, many forms of slice return a view instead of a copy,
|
||||
and determining the exact form requires additional analysis.
|
||||
}];
|
||||
let arguments = (ins
|
||||
Numpy_AnyArray:$a,
|
||||
Variadic<Numpy_SliceTupleElement>:$slice_elements
|
||||
);
|
||||
let results = (outs
|
||||
Numpy_AnyArray:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` functional-type(operands, $result)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -49,24 +49,30 @@ class DialectHelper(Basicpy.DialectHelper):
|
|||
tensor<*xf32>
|
||||
>>> t.function_type([t.i32_type], [t.f32_type])
|
||||
(i32) -> f32
|
||||
>>> t.unknown_array_type
|
||||
tensor<*x!numpy.any_dtype>
|
||||
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
|
||||
self.unknown_array_type = self.tensor_type(self.numpy_any_dtype)
|
||||
|
||||
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
|
||||
"""Creates a numpy.ufunc_call op."""
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({
|
||||
"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)
|
||||
})
|
||||
attrs = c.dictionary_attr(
|
||||
{"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)})
|
||||
return self.op("numpy.ufunc_call", [result_type], args, attrs)
|
||||
|
||||
def numpy_narrow_op(self, result_type, operand):
|
||||
"""Creates a numpy.narrow op."""
|
||||
return self.op("numpy.narrow", [result_type], [operand])
|
||||
|
||||
def numpy_get_slice_op(self, result_type, array, *slice_elements):
|
||||
return self.op("numpy.get_slice", [result_type],
|
||||
[array] + list(slice_elements))
|
||||
|
||||
|
||||
def load_builtin_module(context=None):
|
||||
"""Loads a module populated with numpy built-ins.
|
||||
|
@ -79,7 +85,7 @@ def load_builtin_module(context=None):
|
|||
>>> op.is_registered
|
||||
True
|
||||
>>> op.name
|
||||
'numpy.generic_ufunc'
|
||||
'numpy.builtin_ufunc'
|
||||
|
||||
Args:
|
||||
context: The MLIRContext to use (None to create a new one).
|
||||
|
|
|
@ -58,6 +58,10 @@ class TraceContext:
|
|||
self._next_id = 1
|
||||
self.active = False
|
||||
|
||||
def _handle_array_getitem(self, array, key):
|
||||
"""Handles a call to __getitem__ on a traced array."""
|
||||
raise NotImplementedError("Array slicing not implemented")
|
||||
|
||||
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
||||
"""Handles a ufunc invocation involving at least one TracedArray."""
|
||||
return NotImplemented
|
||||
|
@ -144,6 +148,11 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
|||
def __repr__(self):
|
||||
return "<TracedArray %d>" % self._uid
|
||||
|
||||
def __getitem__(self, key):
|
||||
tc = self._tc
|
||||
_assert_active(tc)
|
||||
return tc._handle_array_getitem(self, key)
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
||||
tc = self._tc
|
||||
_assert_active(tc)
|
||||
|
|
|
@ -95,8 +95,6 @@ class FunctionTracer(TraceContext):
|
|||
return_operands = []
|
||||
for py_result, mlir_result_type in zip(py_results, self._f_types):
|
||||
mlir_result = self.get_traced_array_value(py_result)
|
||||
if mlir_result is None:
|
||||
raise TracingError("Unregistered traced array: %r", (py_result,))
|
||||
# narrow to declared result type.
|
||||
return_operands.extend(
|
||||
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
|
||||
|
@ -108,7 +106,10 @@ class FunctionTracer(TraceContext):
|
|||
self._traced_arrays[traced_array] = value
|
||||
|
||||
def get_traced_array_value(self, traced_array):
|
||||
return self._traced_arrays.get(traced_array)
|
||||
traced_value = self._traced_arrays.get(traced_array)
|
||||
if traced_value is None:
|
||||
raise TracingError("Unregistered traced array: %r", (traced_array,))
|
||||
return traced_value
|
||||
|
||||
def _validate(self):
|
||||
if not all(
|
||||
|
@ -150,9 +151,6 @@ class FunctionTracer(TraceContext):
|
|||
assert tv.type == TraceValueType.NDARRAY, (
|
||||
"Unsupported TraceValueType: %r" % tv.type)
|
||||
ssa_value = self.get_traced_array_value(tv.value)
|
||||
if ssa_value is None:
|
||||
raise TracingError(
|
||||
"Required a traced python NDARRAY but not found: %r" % (tv,))
|
||||
ssa_values.append(ssa_value)
|
||||
return ssa_values
|
||||
|
||||
|
@ -196,6 +194,48 @@ class FunctionTracer(TraceContext):
|
|||
invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
|
||||
return self._emit_invocation(emitter, invocation)
|
||||
|
||||
def _emit_slice_value(self, slice_element):
|
||||
h = self._helper
|
||||
if slice_element == None:
|
||||
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||
elif slice_element == Ellipsis:
|
||||
return h.basicpy_singleton_op(h.basicpy_EllipsisType).result
|
||||
elif isinstance(slice_element, int):
|
||||
return h.constant_op(h.index_type,
|
||||
h.context.index_attr(slice_element)).result
|
||||
elif isinstance(slice_element, slice):
|
||||
return self._emit_slice_object(slice_element)
|
||||
else:
|
||||
# Assume array convertible.
|
||||
raise NotImplementedError(
|
||||
"TODO: Slicing with generic arrays not yet implemented")
|
||||
|
||||
def _emit_slice_object(self, slice_object: slice):
|
||||
h = self._helper
|
||||
def emit_index(index):
|
||||
if index is None:
|
||||
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||
else:
|
||||
return h.constant_op(h.index_type,
|
||||
h.context.index_attr(int(index))).result
|
||||
start = emit_index(slice_object.start)
|
||||
stop = emit_index(slice_object.stop)
|
||||
step = emit_index(slice_object.step)
|
||||
return h.basicpy_slot_object_make_op("slice", start, stop, step).result
|
||||
|
||||
def _handle_array_getitem(self, array, key):
|
||||
h = self._helper
|
||||
array_value = self.get_traced_array_value(array)
|
||||
# Array slicing is always based on a tuple.
|
||||
slice_tuple = key if isinstance(key, tuple) else (key,)
|
||||
# Resolve and emit each slice element.
|
||||
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
||||
result_value = h.numpy_get_slice_op(
|
||||
h.unknown_array_type, array_value, *slice_values).result
|
||||
result_array = TracedArray(self)
|
||||
self.set_traced_array(result_array, result_value)
|
||||
return result_array
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
def slice_array1(a: np.ndarray) -> np.ndarray:
|
||||
return a[1, 2:10:2, 3:4, ..., :, 0]
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.slice_array1 = slice_array1
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.slice_array1)
|
||||
print(mb.module.to_asm())
|
|
@ -195,60 +195,71 @@ void PyDialectHelper::bind(py::module m) {
|
|||
pyOperands.end());
|
||||
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
||||
})
|
||||
.def("constant_op",
|
||||
[](PyDialectHelper &self, PyType type, PyAttribute value) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||
return PyOperationRef(
|
||||
opBuilder.create<ConstantOp>(loc, type.type, value.attr));
|
||||
})
|
||||
|
||||
// Types.
|
||||
.def_property_readonly("index_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IndexType::get(&self.context->context);
|
||||
})
|
||||
.def("integer_type",
|
||||
[](PyDialectHelper &self, unsigned width) {
|
||||
return PyType(IntegerType::get(width, &self.context->context));
|
||||
[](PyDialectHelper &self, unsigned width) -> PyType {
|
||||
return IntegerType::get(width, &self.context->context);
|
||||
},
|
||||
py::arg("width") = 32)
|
||||
.def_property_readonly("i1_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(
|
||||
IntegerType::get(1, &self.context->context));
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(1,
|
||||
&self.context->context);
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i16_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(IntegerType::get(32, &self.context->context));
|
||||
.def_property_readonly("i16_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(32,
|
||||
&self.context->context);
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i32_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(IntegerType::get(32, &self.context->context));
|
||||
.def_property_readonly("i32_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(32,
|
||||
&self.context->context);
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i64_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(IntegerType::get(64, &self.context->context));
|
||||
.def_property_readonly("i64_type",
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return IntegerType::get(64,
|
||||
&self.context->context);
|
||||
})
|
||||
.def_property_readonly("f32_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(FloatType::get(
|
||||
StandardTypes::F32, &self.context->context));
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return FloatType::get(StandardTypes::F32,
|
||||
&self.context->context);
|
||||
})
|
||||
.def_property_readonly("f64_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(FloatType::get(
|
||||
StandardTypes::F64, &self.context->context));
|
||||
[](PyDialectHelper &self) -> PyType {
|
||||
return FloatType::get(StandardTypes::F64,
|
||||
&self.context->context);
|
||||
})
|
||||
.def("tensor_type",
|
||||
[](PyDialectHelper &self, PyType elementType,
|
||||
llvm::Optional<std::vector<int64_t>> shape) {
|
||||
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
||||
if (!elementType.type) {
|
||||
throw py::raiseValueError("Null element type");
|
||||
}
|
||||
if (shape) {
|
||||
return PyType(RankedTensorType::get(*shape, elementType.type));
|
||||
return RankedTensorType::get(*shape, elementType.type);
|
||||
} else {
|
||||
return PyType(UnrankedTensorType::get(elementType.type));
|
||||
return UnrankedTensorType::get(elementType.type);
|
||||
}
|
||||
},
|
||||
py::arg("element_type"),
|
||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||
.def("function_type",
|
||||
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
||||
std::vector<PyType> results) {
|
||||
std::vector<PyType> results) -> PyType {
|
||||
llvm::SmallVector<Type, 4> inputTypes;
|
||||
llvm::SmallVector<Type, 1> resultTypes;
|
||||
for (auto input : inputs) {
|
||||
|
@ -257,8 +268,8 @@ void PyDialectHelper::bind(py::module m) {
|
|||
for (auto result : results) {
|
||||
resultTypes.push_back(result.type);
|
||||
}
|
||||
return PyType(FunctionType::get(inputTypes, resultTypes,
|
||||
&self.context->context));
|
||||
return FunctionType::get(inputTypes, resultTypes,
|
||||
&self.context->context);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -326,6 +337,10 @@ void PyContext::bind(py::module m) {
|
|||
}
|
||||
return PyType(t);
|
||||
})
|
||||
.def("index_attr",
|
||||
[](PyContext &self, int64_t indexValue) -> PyAttribute {
|
||||
return IntegerAttr::get(IndexType::get(&self.context), indexValue);
|
||||
})
|
||||
.def("string_attr",
|
||||
[](PyContext &self, const std::string &s) -> PyAttribute {
|
||||
return StringAttr::get(s, &self.context);
|
||||
|
|
Loading…
Reference in New Issue