mirror of https://github.com/llvm/torch-mlir
Add partial evaluator for explicit numpy ufuncs.
* This enables emission of "numpy.add(a, b)" and several dozen others. * Will deprecate original ufunc infra in a follow-on.pull/1/head
parent
1024c508f8
commit
7ca292ade5
|
@ -47,8 +47,8 @@ class NdArrayType
|
|||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
|
||||
static NdArrayType get(Type optionalDtype, MLIRContext *context);
|
||||
Type getOptionalDtype();
|
||||
static NdArrayType get(Type optionalDtype);
|
||||
Type getDtype();
|
||||
};
|
||||
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"
|
||||
|
|
|
@ -80,11 +80,35 @@ def Numpy_CopyToTensorOp : Numpy_Op<"copy_to_tensor", []> {
|
|||
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def Numpy_BuiltinUfuncCallOp : Numpy_Op<"builtin_ufunc_call"> {
|
||||
let summary = "A __call__ operation on a named/builtin ufunc";
|
||||
let description = [{
|
||||
Simple ufunc call semantics for builtin ufuncs with none of the advanced
|
||||
arguments specified.
|
||||
|
||||
Note that without the `out=` parameter, ufunc call operations (unlike
|
||||
others like `at`) are defined purely in the value domain and do not alias.
|
||||
As such, they operate on tensors, not ndarray.
|
||||
}];
|
||||
let arguments = (ins
|
||||
StrAttr:$qualified_name,
|
||||
Variadic<Numpy_AnyTensor>:$inputs
|
||||
);
|
||||
let results = (outs
|
||||
Numpy_AnyTensor:$output
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
`<` $qualified_name `>` `(` operands `)` attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
|
||||
let summary = "References a built-in universal function";
|
||||
let description = [{
|
||||
This module-level op binds by name to a fully-qualified numpy built-in
|
||||
ufunc (i.e. "numpy.add") and carries metadata associated with it.
|
||||
|
||||
Deprecated: Remove once using new builtin_ufunc_call.
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -93,6 +117,7 @@ def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
|||
Symbol]> {
|
||||
let summary = "Defines a ufunc in terms of overloaded element-wise functions";
|
||||
let description = [{
|
||||
Deprecated: Remove once using new builtin_ufunc_call.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
|
@ -108,6 +133,7 @@ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
|||
let summary = "Return a value from a generic_ufunc";
|
||||
let description = [{
|
||||
Must terminate the basic block of a generic_ufunc overload.
|
||||
Deprecated: Remove once using new builtin_ufunc_call.
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$operands
|
||||
|
@ -122,6 +148,7 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
|
|||
Invokes a ufunc with the given arguments. This variant models the __call__
|
||||
behavior of a python ufunc except that it does not model the `out`
|
||||
parameter, which indicates an in-place update.
|
||||
Deprecated: Remove once using new builtin_ufunc_call.
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$ufunc_ref,
|
||||
|
|
|
@ -6,6 +6,7 @@ add_mlir_dialect_library(NPCOMPNumpyDialect
|
|||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy
|
||||
|
||||
DEPENDS
|
||||
NPCOMPBasicpyDialect
|
||||
MLIRNumpyOpsIncGen
|
||||
)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -33,7 +34,7 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
|||
// Parse:
|
||||
// ndarray<?>
|
||||
// ndarray<i32>
|
||||
Type dtype;
|
||||
Type dtype = Basicpy::UnknownType::get(getContext());
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
if (failed(parser.parseOptionalQuestion())) {
|
||||
|
@ -43,7 +44,7 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
|||
}
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
return NdArrayType::get(dtype, getContext());
|
||||
return NdArrayType::get(dtype);
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
||||
|
@ -56,10 +57,11 @@ void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
os << "any_dtype";
|
||||
return;
|
||||
case NumpyTypes::NdArray: {
|
||||
auto unknownType = Basicpy::UnknownType::get(getContext());
|
||||
auto ndarray = type.cast<NdArrayType>();
|
||||
auto dtype = ndarray.getOptionalDtype();
|
||||
auto dtype = ndarray.getDtype();
|
||||
os << "ndarray<";
|
||||
if (dtype)
|
||||
if (dtype != unknownType)
|
||||
os.printType(dtype);
|
||||
else
|
||||
os << "?";
|
||||
|
@ -100,8 +102,9 @@ struct NdArrayTypeStorage : public TypeStorage {
|
|||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
NdArrayType NdArrayType::get(Type optionalDtype, MLIRContext *context) {
|
||||
return Base::get(context, NumpyTypes::NdArray, optionalDtype);
|
||||
NdArrayType NdArrayType::get(Type dtype) {
|
||||
assert(dtype && "dtype cannot be null");
|
||||
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype);
|
||||
}
|
||||
|
||||
Type NdArrayType::getOptionalDtype() { return getImpl()->optionalDtype; }
|
||||
Type NdArrayType::getDtype() { return getImpl()->optionalDtype; }
|
||||
|
|
|
@ -25,6 +25,9 @@ public:
|
|||
py::class_<BasicpyDialectHelper, PyDialectHelper>(m, "BasicpyDialectHelper")
|
||||
.def(py::init<PyContext &, PyOpBuilder &>(), py::keep_alive<1, 2>(),
|
||||
py::keep_alive<1, 3>())
|
||||
// ---------------------------------------------------------------------
|
||||
// Basicpy dialect
|
||||
// ---------------------------------------------------------------------
|
||||
.def_property_readonly("basicpy_BoolType",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::BoolType::get(
|
||||
|
@ -105,6 +108,26 @@ public:
|
|||
loc, resultType, slotObject, indexAttr);
|
||||
return op.getOperation();
|
||||
})
|
||||
// ---------------------------------------------------------------------
|
||||
// Numpy dialect
|
||||
// ---------------------------------------------------------------------
|
||||
.def("numpy_copy_to_tensor_op",
|
||||
[](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef {
|
||||
auto sourceType =
|
||||
source.value.getType().dyn_cast<Numpy::NdArrayType>();
|
||||
if (!sourceType) {
|
||||
source.value.dump();
|
||||
throw py::raiseValueError("expected ndarray type for "
|
||||
"numpy_copy_to_tensor_op");
|
||||
}
|
||||
auto dtype = sourceType.getDtype();
|
||||
auto tensorType = UnrankedTensorType::get(dtype);
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
auto op = opBuilder.create<Numpy::CopyToTensorOp>(
|
||||
loc, tensorType, source.value);
|
||||
return op.getOperation();
|
||||
})
|
||||
.def("numpy_create_array_from_tensor_op",
|
||||
[](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef {
|
||||
auto sourceType = source.value.getType().dyn_cast<TensorType>();
|
||||
|
@ -113,8 +136,7 @@ public:
|
|||
"numpy_create_array_from_tensor_op");
|
||||
}
|
||||
auto dtype = sourceType.getElementType();
|
||||
auto ndarrayType =
|
||||
Numpy::NdArrayType::get(dtype, self.getContext());
|
||||
auto ndarrayType = Numpy::NdArrayType::get(dtype);
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = self.pyOpBuilder.getCurrentLoc();
|
||||
auto op = opBuilder.create<Numpy::CreateArrayFromTensorOp>(
|
||||
|
@ -123,7 +145,7 @@ public:
|
|||
})
|
||||
.def("numpy_NdArrayType",
|
||||
[](BasicpyDialectHelper &self, PyType dtype) -> PyType {
|
||||
return Numpy::NdArrayType::get(dtype.type, self.getContext());
|
||||
return Numpy::NdArrayType::get(dtype.type);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
import numpy as np
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
|
||||
(2, 1)))
|
||||
|
||||
a = np.asarray([1.0, 2.0])
|
||||
b = np.asarray([3.0, 4.0])
|
||||
|
||||
|
||||
# Test the basic flow of invoking a ufunc call with constants captured from
|
||||
# a global using explicit function syntax (np.add(a, b)).
|
||||
# CHECK-LABEL: func @global_add
|
||||
@import_global
|
||||
def global_add():
|
||||
# CHECK-DAG: %[[CST_A_TENSOR:.*]] = constant dense<[1.000000e+00, 2.000000e+00]>
|
||||
# CHECK-DAG: %[[CST_B_TENSOR:.*]] = constant dense<[3.000000e+00, 4.000000e+00]>
|
||||
# CHECK-DAG: %[[A_ARRAY:.*]] = numpy.create_array_from_tensor %[[CST_A_TENSOR]]
|
||||
# CHECK-DAG: %[[B_ARRAY:.*]] = numpy.create_array_from_tensor %[[CST_B_TENSOR]]
|
||||
# CHECK-DAG: %[[A:.*]] = numpy.copy_to_tensor %[[A_ARRAY]]
|
||||
# CHECK-DAG: %[[B:.*]] = numpy.copy_to_tensor %[[B_ARRAY]]
|
||||
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
|
||||
return np.add(a, b)
|
|
@ -2,4 +2,5 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .builtin_ops import *
|
||||
from .value_coder import *
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Configures evaluation support for numpy builtin ops."""
|
||||
|
||||
from typing import Callable, Iterator, Sequence, Tuple
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
from ... import logging
|
||||
from ...interfaces import *
|
||||
from ...partial_eval_base import *
|
||||
|
||||
__all__ = [
|
||||
"get_ufuncs_from_module",
|
||||
"bind_ufuncs",
|
||||
]
|
||||
|
||||
################################################################################
|
||||
# Ufunc evaluation
|
||||
################################################################################
|
||||
|
||||
|
||||
def _default_ufunc_predicate(ufunc: np.ufunc) -> bool:
|
||||
"""Filters ufuncs based on ability to evaluate them."""
|
||||
# Support up to 2 input, 1 output functions.
|
||||
if ufunc.nin > 2 or ufunc.nout != 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_ufuncs_from_module(
|
||||
*,
|
||||
module=np,
|
||||
prefix: str = "numpy.",
|
||||
predicate: Callable[[np.ufunc], bool] = _default_ufunc_predicate,
|
||||
) -> Iterator[Tuple[str, np.ufunc]]:
|
||||
"""Iterates over all ufuncs in a module.
|
||||
|
||||
Yields:
|
||||
Tuple of (prefixed_name, ufunc).
|
||||
"""
|
||||
ufunc_class = np.ufunc
|
||||
for local_name in dir(module):
|
||||
value = getattr(module, local_name)
|
||||
if isinstance(value, ufunc_class):
|
||||
if not predicate(value):
|
||||
logging.debug("Skipped ufunc: {}{} ({})", prefix, local_name, value)
|
||||
else:
|
||||
yield (prefix + local_name), value
|
||||
|
||||
|
||||
def bind_ufuncs(ufuncs: Iterator[Tuple[str, np.ufunc]],
|
||||
pe_hook: MappedPartialEvalHook):
|
||||
"""Binds a set of ufuncs to a partial eval hook."""
|
||||
for qualified_name, ufunc in ufuncs:
|
||||
pe_hook.bind_action(functools.partial(BuiltinUfuncLiveValueRef,
|
||||
qualified_name, ufunc),
|
||||
for_ref=ufunc)
|
||||
|
||||
|
||||
class BuiltinUfuncLiveValueRef(LiveValueRef):
|
||||
"""A partial evaluation that emits IR for invoking a ufunc."""
|
||||
__slots__ = ["_qualified_name", "_ufunc"]
|
||||
|
||||
def __init__(self, qualified_name: str, ufunc: np.ufunc, live_value):
|
||||
super().__init__(live_value)
|
||||
self._qualified_name = qualified_name
|
||||
self._ufunc = ufunc
|
||||
|
||||
def resolve_call(self, env: Environment, args: Sequence[ir.Value],
|
||||
keywords: Sequence[str]) -> PartialEvalResult:
|
||||
if keywords:
|
||||
return PartialEvalResult.error_message(
|
||||
"ufunc call does not currently support keyword args")
|
||||
if len(args) != self._ufunc.nin:
|
||||
return PartialEvalResult.error_message(
|
||||
"ufunc {} expected {} inputs but got {}".format(
|
||||
self._qualified_name, self._ufunc.nin, len(args)))
|
||||
ir_h = env.ir_h
|
||||
# Because a ufunc call is defined in terms of tensors and, at this stage,
|
||||
# all "public" values are ndarray, do appropriate conversions.
|
||||
tensor_args = [ir_h.numpy_copy_to_tensor_op(arg).result for arg in args]
|
||||
result_type = ir_h.numpy_unknown_tensor_type
|
||||
tensor_result = ir_h.numpy_builtin_ufunc_call_op(
|
||||
*tensor_args,
|
||||
qualified_name=self._qualified_name,
|
||||
result_type=result_type).result
|
||||
array_result = ir_h.numpy_create_array_from_tensor_op(tensor_result).result
|
||||
return PartialEvalResult.yields_ir_value(array_result)
|
|
@ -153,6 +153,9 @@ class PartialEvalType(Enum):
|
|||
class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
|
||||
"""Encapsulates the result of a partial evaluation."""
|
||||
|
||||
def as_partial_eval_result(self) -> "PartialEvalResult":
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def not_evaluated() -> "PartialEvalResult":
|
||||
return PartialEvalResult(PartialEvalType.NOT_EVALUATED, NotImplemented)
|
||||
|
@ -195,19 +198,22 @@ class LiveValueRef:
|
|||
super().__init__()
|
||||
self.live_value = live_value
|
||||
|
||||
def as_partial_eval_result(self) -> PartialEvalResult:
|
||||
return PartialEvalResult.yields_live_value(self)
|
||||
|
||||
def resolve_getattr(self, env: "Environment",
|
||||
attr_name: str) -> PartialEvalResult:
|
||||
"""Gets a named attribute from the live value."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
def resolve_call(self, env: "Environment", args: Sequence[ir.Value],
|
||||
keywords: Sequence[str]) -> PartialEvalResult:
|
||||
"""Resolves a function call given 'args' and 'keywords'."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def __repr__(self):
|
||||
return "MacroValueRef({}, {})".format(self.__class__.__name__,
|
||||
self.live_value)
|
||||
return "LiveValueRef({}, {})".format(self.__class__.__name__,
|
||||
self.live_value)
|
||||
|
||||
|
||||
class PartialEvalHook:
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Partial evaluation helpers and support for built-in and common scenarios."""
|
||||
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from .interfaces import *
|
||||
from .py_value_utils import *
|
||||
from . import logging
|
||||
|
@ -76,7 +78,8 @@ class MappedPartialEvalHook(PartialEvalHook):
|
|||
|
||||
An action can be one of
|
||||
- A `lambda python_value: PartialEvalResult...`
|
||||
- A PartialEvalResult to directly return
|
||||
- An object that supports as_partial_eval_result() (either a
|
||||
PartialEvalResult or LiveValueRef qualify).
|
||||
- None to indicate that the python value should be processed directly
|
||||
"""
|
||||
__slots__ = [
|
||||
|
@ -87,30 +90,37 @@ class MappedPartialEvalHook(PartialEvalHook):
|
|||
super().__init__()
|
||||
self._value_map = PyValueMap()
|
||||
|
||||
def __repr__(self):
|
||||
return "MappedPartialEvalHook({})".format(self._value_map)
|
||||
|
||||
def partial_evaluate(self, py_value) -> PartialEvalResult:
|
||||
"""Performs partial evaluation on a python value."""
|
||||
binding = self._value_map.lookup(py_value)
|
||||
if binding is None:
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: Passthrough", py_value)
|
||||
logging.debug("LOOKUP: {}", py_value)
|
||||
action = self._value_map.lookup(py_value)
|
||||
if action is None:
|
||||
# Passthrough.
|
||||
return PartialEvalResult.yields_live_value(LiveValueRef(py_value))
|
||||
if isinstance(binding, PartialEvalResult):
|
||||
return binding
|
||||
# Attempt to call.
|
||||
try:
|
||||
binding = binding(py_value)
|
||||
assert isinstance(binding, PartialEvalResult), (
|
||||
"Expected PartialEvalResult but got {}".format(binding))
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding)
|
||||
return binding
|
||||
result = action(py_value).as_partial_eval_result()
|
||||
assert isinstance(result, PartialEvalResult), (
|
||||
"Expected PartialEvalResult but got {}".format(result))
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, result)
|
||||
return result
|
||||
except:
|
||||
return PartialEvalResult.error()
|
||||
|
||||
def _bind_action(self,
|
||||
action,
|
||||
*,
|
||||
for_ref=_Unspec,
|
||||
for_type=_Unspec,
|
||||
for_predicate=_Unspec):
|
||||
def bind_action(self,
|
||||
action: Union[PartialEvalResult, LiveValueRef,
|
||||
Callable[[Any], PartialEvalResult]],
|
||||
*,
|
||||
for_ref=_Unspec,
|
||||
for_type=_Unspec,
|
||||
for_predicate=_Unspec):
|
||||
if hasattr(action, "as_partial_eval_result"):
|
||||
# Registers a casting action.
|
||||
action = lambda pv: pv.as_partial_eval_result()
|
||||
|
||||
if for_ref is not _Unspec:
|
||||
self._value_map.bind_reference(for_ref, action)
|
||||
elif for_type is not _Unspec:
|
||||
|
@ -123,12 +133,12 @@ class MappedPartialEvalHook(PartialEvalHook):
|
|||
|
||||
def enable_getattr(self, **kwargs):
|
||||
"""Enables partial evaluation of getattr."""
|
||||
self._bind_action(
|
||||
self.bind_action(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
ResolveAttrLiveValueRef(pv)), **kwargs)
|
||||
|
||||
def enable_template_call(self, callee_name, **kwargs):
|
||||
""""Enables a global template call."""
|
||||
self._bind_action(
|
||||
self.bind_action(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
TemplateCallLiveValueRef(callee_name, pv)), **kwargs)
|
||||
|
|
|
@ -88,6 +88,19 @@ class PyValueMap:
|
|||
self._fallback_filters = list() # of: list[(lambda v, Any)]
|
||||
self._validator = validator
|
||||
|
||||
def __repr__(self):
|
||||
lines = ["refs={"]
|
||||
for ref, binding in self._reference_map.items():
|
||||
lines.append(" {}: {}".format(ref.referrent, binding))
|
||||
lines.append("}, types={")
|
||||
for t, binding in self._type_filters:
|
||||
lines.append(" {}: {}".format(t, binding))
|
||||
lines.append("}, filters={")
|
||||
for f, binding in self._fallback_filters:
|
||||
lines.append(" {}: {}".format(f, binding))
|
||||
lines.append("}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def bind_reference(self, match_value, binding):
|
||||
assert self._validator(binding), "Illegal binding"
|
||||
self._reference_map[HashableReference.create(match_value)] = binding
|
||||
|
|
|
@ -37,6 +37,12 @@ def create_test_config(target_factory: TargetFactory = GenericTarget64):
|
|||
])
|
||||
pe_hook = build_default_partial_eval_hook()
|
||||
|
||||
# Populate numpy partial evaluators.
|
||||
npc.bind_ufuncs(npc.get_ufuncs_from_module(), pe_hook)
|
||||
|
||||
if logging.debug_enabled:
|
||||
logging.debug("Partial eval mapping: {}", pe_hook)
|
||||
|
||||
return Configuration(target_factory=target_factory,
|
||||
value_coder=value_coder,
|
||||
partial_eval_hook=pe_hook)
|
||||
|
|
|
@ -50,29 +50,39 @@ class DialectHelper(Basicpy.DialectHelper):
|
|||
>>> c = ir.MLIRContext()
|
||||
>>> t = DialectHelper(c, ir.OpBuilder(c))
|
||||
>>> t.numpy_any_dtype
|
||||
!numpy.any_dtype
|
||||
!basicpy.UnknownType
|
||||
>>> t.tensor_type(t.numpy_any_dtype, [1, 2, 3])
|
||||
tensor<1x2x3x!numpy.any_dtype>
|
||||
tensor<1x2x3x!basicpy.UnknownType>
|
||||
>>> t.tensor_type(t.numpy_any_dtype)
|
||||
tensor<*x!numpy.any_dtype>
|
||||
tensor<*x!basicpy.UnknownType>
|
||||
>>> t.tensor_type(t.numpy_any_dtype, [-1, 2])
|
||||
tensor<?x2x!numpy.any_dtype>
|
||||
tensor<?x2x!basicpy.UnknownType>
|
||||
>>> t.tensor_type(t.f32_type)
|
||||
tensor<*xf32>
|
||||
>>> t.function_type([t.i32_type], [t.f32_type])
|
||||
(i32) -> f32
|
||||
>>> t.unknown_array_type
|
||||
tensor<*x!numpy.any_dtype>
|
||||
>>> t.unknown_tensor_type
|
||||
tensor<*x!basicpy.UnknownType>
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def numpy_any_dtype(self):
|
||||
return self.context.parse_type("!numpy.any_dtype")
|
||||
return self.basicpy_UnknownType
|
||||
|
||||
@property
|
||||
def numpy_unknown_tensor_type(self):
|
||||
return self.tensor_type(self.basicpy_UnknownType)
|
||||
|
||||
@property
|
||||
def unknown_array_type(self):
|
||||
return self.tensor_type(self.numpy_any_dtype)
|
||||
return self.numpy_NdArrayType(self.basicpy_UnknownType)
|
||||
|
||||
def numpy_builtin_ufunc_call_op(self, *args, qualified_name, result_type):
|
||||
"""Creates a numpy.builtin_ufunc_call op."""
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({"qualified_name": c.string_attr(qualified_name)})
|
||||
return self.op("numpy.builtin_ufunc_call", [result_type], args, attrs)
|
||||
|
||||
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
|
||||
"""Creates a numpy.ufunc_call op."""
|
||||
|
|
Loading…
Reference in New Issue