mirror of https://github.com/llvm/torch-mlir
Add numpy.get_slice op and wire it up to the tracer.
parent
db0b0ef1b2
commit
a91b0bfbe1
|
@ -72,6 +72,12 @@ public:
|
||||||
StringAttr getClassName();
|
StringAttr getClassName();
|
||||||
unsigned getSlotCount();
|
unsigned getSlotCount();
|
||||||
ArrayRef<Type> getSlotTypes();
|
ArrayRef<Type> getSlotTypes();
|
||||||
|
|
||||||
|
// Shorthand to check whether the SlotObject is of a given className and
|
||||||
|
// arity.
|
||||||
|
bool isOfClassArity(StringRef className, int arity) {
|
||||||
|
return getClassName().getValue() == className && getSlotCount() == arity;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#include "npcomp/Dialect/Basicpy/BasicpyOpsDialect.h.inc"
|
#include "npcomp/Dialect/Basicpy/BasicpyOpsDialect.h.inc"
|
||||||
|
|
|
@ -73,4 +73,19 @@ def Basicpy_SingletonType : AnyTypeOf<[
|
||||||
Basicpy_EllipsisType
|
Basicpy_EllipsisType
|
||||||
]>;
|
]>;
|
||||||
|
|
||||||
|
// A predicate to determine whether a Type is a SlotObject of a given
|
||||||
|
// className and arity. Does no checking of slot types.
|
||||||
|
class Basicpy_SlotObjectOfClassArity<string className, int arity> :
|
||||||
|
And<[
|
||||||
|
Basicpy_SlotObjectType.predicate,
|
||||||
|
CPred<
|
||||||
|
"$_self.cast<::mlir::NPCOMP::Basicpy::SlotObjectType>().isOfClassArity(\""
|
||||||
|
# className # "\", " # arity # ")">
|
||||||
|
]>;
|
||||||
|
|
||||||
|
// Type representing a 'slice' object, which mirrors the Python built-in
|
||||||
|
// slice class.
|
||||||
|
def Basicpy_SliceSlotObjectType :
|
||||||
|
Type<Basicpy_SlotObjectOfClassArity<"slice", 3>>;
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT
|
#endif // NPCOMP_DIALECT_BASICPY_BASICPY_DIALECT
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#define NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
#define NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "npcomp/Dialect/Basicpy/BasicpyDialect.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Dialect definition
|
// Dialect definition
|
||||||
|
@ -50,6 +51,25 @@ def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
||||||
// Type predicates
|
// Type predicates
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Any type, at any stage of analysis that can represent a numpy array.
|
||||||
def Numpy_AnyArray : TensorOf<[AnyType]>;
|
def Numpy_AnyArray : TensorOf<[AnyType]>;
|
||||||
|
|
||||||
|
def Numpy_SliceTupleElement : AnyTypeOf<[
|
||||||
|
// Supports both "Index Arrays" and "Boolean mask index arrays".
|
||||||
|
Numpy_AnyArray,
|
||||||
|
|
||||||
|
// Indicates that an axis should be added (np.newaxis == None).
|
||||||
|
Basicpy_NoneType,
|
||||||
|
|
||||||
|
// Indicates that intervening axes should be preserved.
|
||||||
|
Basicpy_EllipsisType,
|
||||||
|
|
||||||
|
// A discrete numeric index (represented as IndexType so that a proper
|
||||||
|
// width can be target dependent).
|
||||||
|
Index,
|
||||||
|
|
||||||
|
// A generalized slice object.
|
||||||
|
Basicpy_SliceSlotObjectType,
|
||||||
|
], "types that are legal elements of a __getitem__ tuple operating on arrays">;
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||||
|
|
|
@ -150,4 +150,39 @@ def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
// Slicing
|
||||||
|
// See: https://docs.scipy.org/doc/numpy/user/basics.indexing.html
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
|
||||||
|
def Numpy_GetSlice : Numpy_Op<"get_slice", []> {
|
||||||
|
let summary = "Gets a slice of an array";
|
||||||
|
let description = [{
|
||||||
|
This op encapsulates all forms of indexing into an array by taking a
|
||||||
|
variable number of `slice` arguments, each of which represents a single
|
||||||
|
entry in a generalized indexing-tuple. Once full type inference has
|
||||||
|
been performed, there should be sufficient static information to determine
|
||||||
|
the exact slice semantics solely by the signature of types of the `slice`
|
||||||
|
arguments.
|
||||||
|
|
||||||
|
Note that there is a more general form of this op that is generally
|
||||||
|
needed for AST extraction that takes a variable length `tuple` instead
|
||||||
|
of a static list of arguments. It is expected that during type refinement
|
||||||
|
most such uses should degenerate to this static variant.
|
||||||
|
|
||||||
|
Per numpy semantics, many forms of slice return a view instead of a copy,
|
||||||
|
and determining the exact form requires additional analysis.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Numpy_AnyArray:$a,
|
||||||
|
Variadic<Numpy_SliceTupleElement>:$slice_elements
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Numpy_AnyArray:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
operands attr-dict `:` functional-type(operands, $result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/FunctionImplementation.h"
|
#include "mlir/IR/FunctionImplementation.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -6,13 +6,13 @@ from npcomp.dialect import Basicpy
|
||||||
from _npcomp.mlir import ir
|
from _npcomp.mlir import ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_builtin_module",
|
"load_builtin_module",
|
||||||
"DialectHelper",
|
"DialectHelper",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class DialectHelper(Basicpy.DialectHelper):
|
class DialectHelper(Basicpy.DialectHelper):
|
||||||
r"""Dialect helper.
|
r"""Dialect helper.
|
||||||
|
|
||||||
>>> c = ir.MLIRContext()
|
>>> c = ir.MLIRContext()
|
||||||
>>> h = DialectHelper(c)
|
>>> h = DialectHelper(c)
|
||||||
|
@ -49,27 +49,33 @@ class DialectHelper(Basicpy.DialectHelper):
|
||||||
tensor<*xf32>
|
tensor<*xf32>
|
||||||
>>> t.function_type([t.i32_type], [t.f32_type])
|
>>> t.function_type([t.i32_type], [t.f32_type])
|
||||||
(i32) -> f32
|
(i32) -> f32
|
||||||
|
>>> t.unknown_array_type
|
||||||
|
tensor<*x!numpy.any_dtype>
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
|
self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
|
||||||
|
self.unknown_array_type = self.tensor_type(self.numpy_any_dtype)
|
||||||
|
|
||||||
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
|
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
|
||||||
"""Creates a numpy.ufunc_call op."""
|
"""Creates a numpy.ufunc_call op."""
|
||||||
c = self.context
|
c = self.context
|
||||||
attrs = c.dictionary_attr({
|
attrs = c.dictionary_attr(
|
||||||
"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)
|
{"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)})
|
||||||
})
|
return self.op("numpy.ufunc_call", [result_type], args, attrs)
|
||||||
return self.op("numpy.ufunc_call", [result_type], args, attrs)
|
|
||||||
|
|
||||||
def numpy_narrow_op(self, result_type, operand):
|
def numpy_narrow_op(self, result_type, operand):
|
||||||
"""Creates a numpy.narrow op."""
|
"""Creates a numpy.narrow op."""
|
||||||
return self.op("numpy.narrow", [result_type], [operand])
|
return self.op("numpy.narrow", [result_type], [operand])
|
||||||
|
|
||||||
|
def numpy_get_slice_op(self, result_type, array, *slice_elements):
|
||||||
|
return self.op("numpy.get_slice", [result_type],
|
||||||
|
[array] + list(slice_elements))
|
||||||
|
|
||||||
|
|
||||||
def load_builtin_module(context=None):
|
def load_builtin_module(context=None):
|
||||||
"""Loads a module populated with numpy built-ins.
|
"""Loads a module populated with numpy built-ins.
|
||||||
|
|
||||||
This is not a long-term solution but overcomes some bootstrapping
|
This is not a long-term solution but overcomes some bootstrapping
|
||||||
issues.
|
issues.
|
||||||
|
@ -79,15 +85,15 @@ def load_builtin_module(context=None):
|
||||||
>>> op.is_registered
|
>>> op.is_registered
|
||||||
True
|
True
|
||||||
>>> op.name
|
>>> op.name
|
||||||
'numpy.generic_ufunc'
|
'numpy.builtin_ufunc'
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: The MLIRContext to use (None to create a new one).
|
context: The MLIRContext to use (None to create a new one).
|
||||||
Returns:
|
Returns:
|
||||||
A ModuleOp.
|
A ModuleOp.
|
||||||
"""
|
"""
|
||||||
if context is None: context = ir.MLIRContext()
|
if context is None: context = ir.MLIRContext()
|
||||||
return context.parse_asm(_BUILTIN_MODULE_ASM)
|
return context.parse_asm(_BUILTIN_MODULE_ASM)
|
||||||
|
|
||||||
|
|
||||||
_BUILTIN_MODULE_ASM = r"""
|
_BUILTIN_MODULE_ASM = r"""
|
||||||
|
@ -96,5 +102,5 @@ _BUILTIN_MODULE_ASM = r"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import doctest
|
import doctest
|
||||||
doctest.testmod()
|
doctest.testmod()
|
||||||
|
|
|
@ -58,6 +58,10 @@ class TraceContext:
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
self.active = False
|
self.active = False
|
||||||
|
|
||||||
|
def _handle_array_getitem(self, array, key):
|
||||||
|
"""Handles a call to __getitem__ on a traced array."""
|
||||||
|
raise NotImplementedError("Array slicing not implemented")
|
||||||
|
|
||||||
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
||||||
"""Handles a ufunc invocation involving at least one TracedArray."""
|
"""Handles a ufunc invocation involving at least one TracedArray."""
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
@ -144,6 +148,11 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<TracedArray %d>" % self._uid
|
return "<TracedArray %d>" % self._uid
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
tc = self._tc
|
||||||
|
_assert_active(tc)
|
||||||
|
return tc._handle_array_getitem(self, key)
|
||||||
|
|
||||||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
||||||
tc = self._tc
|
tc = self._tc
|
||||||
_assert_active(tc)
|
_assert_active(tc)
|
||||||
|
|
|
@ -95,8 +95,6 @@ class FunctionTracer(TraceContext):
|
||||||
return_operands = []
|
return_operands = []
|
||||||
for py_result, mlir_result_type in zip(py_results, self._f_types):
|
for py_result, mlir_result_type in zip(py_results, self._f_types):
|
||||||
mlir_result = self.get_traced_array_value(py_result)
|
mlir_result = self.get_traced_array_value(py_result)
|
||||||
if mlir_result is None:
|
|
||||||
raise TracingError("Unregistered traced array: %r", (py_result,))
|
|
||||||
# narrow to declared result type.
|
# narrow to declared result type.
|
||||||
return_operands.extend(
|
return_operands.extend(
|
||||||
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
|
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
|
||||||
|
@ -108,7 +106,10 @@ class FunctionTracer(TraceContext):
|
||||||
self._traced_arrays[traced_array] = value
|
self._traced_arrays[traced_array] = value
|
||||||
|
|
||||||
def get_traced_array_value(self, traced_array):
|
def get_traced_array_value(self, traced_array):
|
||||||
return self._traced_arrays.get(traced_array)
|
traced_value = self._traced_arrays.get(traced_array)
|
||||||
|
if traced_value is None:
|
||||||
|
raise TracingError("Unregistered traced array: %r", (traced_array,))
|
||||||
|
return traced_value
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
if not all(
|
if not all(
|
||||||
|
@ -150,9 +151,6 @@ class FunctionTracer(TraceContext):
|
||||||
assert tv.type == TraceValueType.NDARRAY, (
|
assert tv.type == TraceValueType.NDARRAY, (
|
||||||
"Unsupported TraceValueType: %r" % tv.type)
|
"Unsupported TraceValueType: %r" % tv.type)
|
||||||
ssa_value = self.get_traced_array_value(tv.value)
|
ssa_value = self.get_traced_array_value(tv.value)
|
||||||
if ssa_value is None:
|
|
||||||
raise TracingError(
|
|
||||||
"Required a traced python NDARRAY but not found: %r" % (tv,))
|
|
||||||
ssa_values.append(ssa_value)
|
ssa_values.append(ssa_value)
|
||||||
return ssa_values
|
return ssa_values
|
||||||
|
|
||||||
|
@ -196,6 +194,48 @@ class FunctionTracer(TraceContext):
|
||||||
invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
|
invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
|
||||||
return self._emit_invocation(emitter, invocation)
|
return self._emit_invocation(emitter, invocation)
|
||||||
|
|
||||||
|
def _emit_slice_value(self, slice_element):
|
||||||
|
h = self._helper
|
||||||
|
if slice_element == None:
|
||||||
|
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||||
|
elif slice_element == Ellipsis:
|
||||||
|
return h.basicpy_singleton_op(h.basicpy_EllipsisType).result
|
||||||
|
elif isinstance(slice_element, int):
|
||||||
|
return h.constant_op(h.index_type,
|
||||||
|
h.context.index_attr(slice_element)).result
|
||||||
|
elif isinstance(slice_element, slice):
|
||||||
|
return self._emit_slice_object(slice_element)
|
||||||
|
else:
|
||||||
|
# Assume array convertible.
|
||||||
|
raise NotImplementedError(
|
||||||
|
"TODO: Slicing with generic arrays not yet implemented")
|
||||||
|
|
||||||
|
def _emit_slice_object(self, slice_object: slice):
|
||||||
|
h = self._helper
|
||||||
|
def emit_index(index):
|
||||||
|
if index is None:
|
||||||
|
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||||
|
else:
|
||||||
|
return h.constant_op(h.index_type,
|
||||||
|
h.context.index_attr(int(index))).result
|
||||||
|
start = emit_index(slice_object.start)
|
||||||
|
stop = emit_index(slice_object.stop)
|
||||||
|
step = emit_index(slice_object.step)
|
||||||
|
return h.basicpy_slot_object_make_op("slice", start, stop, step).result
|
||||||
|
|
||||||
|
def _handle_array_getitem(self, array, key):
|
||||||
|
h = self._helper
|
||||||
|
array_value = self.get_traced_array_value(array)
|
||||||
|
# Array slicing is always based on a tuple.
|
||||||
|
slice_tuple = key if isinstance(key, tuple) else (key,)
|
||||||
|
# Resolve and emit each slice element.
|
||||||
|
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
||||||
|
result_value = h.numpy_get_slice_op(
|
||||||
|
h.unknown_array_type, array_value, *slice_values).result
|
||||||
|
result_array = TracedArray(self)
|
||||||
|
self.set_traced_array(result_array, result_value)
|
||||||
|
return result_array
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import doctest
|
import doctest
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import npcomp as npc
|
||||||
|
from npcomp.types import *
|
||||||
|
|
||||||
|
def slice_array1(a: np.ndarray) -> np.ndarray:
|
||||||
|
return a[1, 2:10:2, 3:4, ..., :, 0]
|
||||||
|
|
||||||
|
# TODO: Implement subclassing and deriving constraints by run
|
||||||
|
exp = npc.Exporter()
|
||||||
|
exp.slice_array1 = slice_array1
|
||||||
|
|
||||||
|
mb = npc.tracing.ModuleBuilder()
|
||||||
|
mb.trace(exp.slice_array1)
|
||||||
|
print(mb.module.to_asm())
|
|
@ -195,60 +195,71 @@ void PyDialectHelper::bind(py::module m) {
|
||||||
pyOperands.end());
|
pyOperands.end());
|
||||||
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
||||||
})
|
})
|
||||||
|
.def("constant_op",
|
||||||
|
[](PyDialectHelper &self, PyType type, PyAttribute value) {
|
||||||
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
|
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||||
|
return PyOperationRef(
|
||||||
|
opBuilder.create<ConstantOp>(loc, type.type, value.attr));
|
||||||
|
})
|
||||||
|
|
||||||
// Types.
|
// Types.
|
||||||
|
.def_property_readonly("index_type",
|
||||||
|
[](PyDialectHelper &self) -> PyType {
|
||||||
|
return IndexType::get(&self.context->context);
|
||||||
|
})
|
||||||
.def("integer_type",
|
.def("integer_type",
|
||||||
[](PyDialectHelper &self, unsigned width) {
|
[](PyDialectHelper &self, unsigned width) -> PyType {
|
||||||
return PyType(IntegerType::get(width, &self.context->context));
|
return IntegerType::get(width, &self.context->context);
|
||||||
},
|
},
|
||||||
py::arg("width") = 32)
|
py::arg("width") = 32)
|
||||||
.def_property_readonly("i1_type",
|
.def_property_readonly("i1_type",
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) -> PyType {
|
||||||
return PyType(
|
return IntegerType::get(1,
|
||||||
IntegerType::get(1, &self.context->context));
|
&self.context->context);
|
||||||
|
})
|
||||||
|
.def_property_readonly("i16_type",
|
||||||
|
[](PyDialectHelper &self) -> PyType {
|
||||||
|
return IntegerType::get(32,
|
||||||
|
&self.context->context);
|
||||||
|
})
|
||||||
|
.def_property_readonly("i32_type",
|
||||||
|
[](PyDialectHelper &self) -> PyType {
|
||||||
|
return IntegerType::get(32,
|
||||||
|
&self.context->context);
|
||||||
|
})
|
||||||
|
.def_property_readonly("i64_type",
|
||||||
|
[](PyDialectHelper &self) -> PyType {
|
||||||
|
return IntegerType::get(64,
|
||||||
|
&self.context->context);
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
|
||||||
"i16_type",
|
|
||||||
[](PyDialectHelper &self) {
|
|
||||||
return PyType(IntegerType::get(32, &self.context->context));
|
|
||||||
})
|
|
||||||
.def_property_readonly(
|
|
||||||
"i32_type",
|
|
||||||
[](PyDialectHelper &self) {
|
|
||||||
return PyType(IntegerType::get(32, &self.context->context));
|
|
||||||
})
|
|
||||||
.def_property_readonly(
|
|
||||||
"i64_type",
|
|
||||||
[](PyDialectHelper &self) {
|
|
||||||
return PyType(IntegerType::get(64, &self.context->context));
|
|
||||||
})
|
|
||||||
.def_property_readonly("f32_type",
|
.def_property_readonly("f32_type",
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) -> PyType {
|
||||||
return PyType(FloatType::get(
|
return FloatType::get(StandardTypes::F32,
|
||||||
StandardTypes::F32, &self.context->context));
|
&self.context->context);
|
||||||
})
|
})
|
||||||
.def_property_readonly("f64_type",
|
.def_property_readonly("f64_type",
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) -> PyType {
|
||||||
return PyType(FloatType::get(
|
return FloatType::get(StandardTypes::F64,
|
||||||
StandardTypes::F64, &self.context->context));
|
&self.context->context);
|
||||||
})
|
})
|
||||||
.def("tensor_type",
|
.def("tensor_type",
|
||||||
[](PyDialectHelper &self, PyType elementType,
|
[](PyDialectHelper &self, PyType elementType,
|
||||||
llvm::Optional<std::vector<int64_t>> shape) {
|
llvm::Optional<std::vector<int64_t>> shape) -> PyType {
|
||||||
if (!elementType.type) {
|
if (!elementType.type) {
|
||||||
throw py::raiseValueError("Null element type");
|
throw py::raiseValueError("Null element type");
|
||||||
}
|
}
|
||||||
if (shape) {
|
if (shape) {
|
||||||
return PyType(RankedTensorType::get(*shape, elementType.type));
|
return RankedTensorType::get(*shape, elementType.type);
|
||||||
} else {
|
} else {
|
||||||
return PyType(UnrankedTensorType::get(elementType.type));
|
return UnrankedTensorType::get(elementType.type);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
py::arg("element_type"),
|
py::arg("element_type"),
|
||||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||||
.def("function_type",
|
.def("function_type",
|
||||||
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
||||||
std::vector<PyType> results) {
|
std::vector<PyType> results) -> PyType {
|
||||||
llvm::SmallVector<Type, 4> inputTypes;
|
llvm::SmallVector<Type, 4> inputTypes;
|
||||||
llvm::SmallVector<Type, 1> resultTypes;
|
llvm::SmallVector<Type, 1> resultTypes;
|
||||||
for (auto input : inputs) {
|
for (auto input : inputs) {
|
||||||
|
@ -257,8 +268,8 @@ void PyDialectHelper::bind(py::module m) {
|
||||||
for (auto result : results) {
|
for (auto result : results) {
|
||||||
resultTypes.push_back(result.type);
|
resultTypes.push_back(result.type);
|
||||||
}
|
}
|
||||||
return PyType(FunctionType::get(inputTypes, resultTypes,
|
return FunctionType::get(inputTypes, resultTypes,
|
||||||
&self.context->context));
|
&self.context->context);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,6 +337,10 @@ void PyContext::bind(py::module m) {
|
||||||
}
|
}
|
||||||
return PyType(t);
|
return PyType(t);
|
||||||
})
|
})
|
||||||
|
.def("index_attr",
|
||||||
|
[](PyContext &self, int64_t indexValue) -> PyAttribute {
|
||||||
|
return IntegerAttr::get(IndexType::get(&self.context), indexValue);
|
||||||
|
})
|
||||||
.def("string_attr",
|
.def("string_attr",
|
||||||
[](PyContext &self, const std::string &s) -> PyAttribute {
|
[](PyContext &self, const std::string &s) -> PyAttribute {
|
||||||
return StringAttr::get(s, &self.context);
|
return StringAttr::get(s, &self.context);
|
||||||
|
|
Loading…
Reference in New Issue