mirror of https://github.com/llvm/torch-mlir
Merge ir.Ops and ir.Types into ir.DialectHelper.
This will aid in managing hierarchies of custom dialect helpers.pull/1/head
parent
aa9ffc3a11
commit
4ebf972503
|
@ -6,25 +6,24 @@ from npcomp.native.mlir import ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_builtin_module",
|
"load_builtin_module",
|
||||||
"Types",
|
"DialectHelper",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Ops(ir.Ops):
|
class DialectHelper(ir.DialectHelper):
|
||||||
r"""Dialect ops.
|
r"""Dialect helper.
|
||||||
|
|
||||||
>>> c = ir.MLIRContext()
|
>>> c = ir.MLIRContext()
|
||||||
>>> t = Types(c)
|
>>> h = DialectHelper(c)
|
||||||
>>> m = c.new_module()
|
>>> m = c.new_module()
|
||||||
>>> tensor_type = t.tensor(t.f32)
|
>>> tensor_type = h.tensor_type(h.f32_type)
|
||||||
>>> ops = Ops(c)
|
>>> h.builder.insert_block_start(m.first_block)
|
||||||
>>> ops.builder.insert_block_start(m.first_block)
|
>>> f = h.func_op("foobar", h.function_type(
|
||||||
>>> f = ops.func_op("foobar", t.function(
|
|
||||||
... [tensor_type, tensor_type], [tensor_type]),
|
... [tensor_type, tensor_type], [tensor_type]),
|
||||||
... create_entry_block=True)
|
... create_entry_block=True)
|
||||||
>>> uf = ops.numpy_ufunc_call_op("numpy.add", tensor_type,
|
>>> uf = h.numpy_ufunc_call_op("numpy.add", tensor_type,
|
||||||
... *f.first_block.args)
|
... *f.first_block.args)
|
||||||
>>> _ = ops.return_op(uf.results)
|
>>> _ = h.return_op(uf.results)
|
||||||
>>> print(m.to_asm())
|
>>> print(m.to_asm())
|
||||||
<BLANKLINE>
|
<BLANKLINE>
|
||||||
<BLANKLINE>
|
<BLANKLINE>
|
||||||
|
@ -34,7 +33,27 @@ class Ops(ir.Ops):
|
||||||
return %0 : tensor<*xf32>
|
return %0 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Types:
|
||||||
|
>>> t = DialectHelper(ir.MLIRContext())
|
||||||
|
>>> t.numpy_any_dtype
|
||||||
|
!numpy.any_dtype
|
||||||
|
>>> t.tensor_type(t.numpy_any_dtype, [1, 2, 3])
|
||||||
|
tensor<1x2x3x!numpy.any_dtype>
|
||||||
|
>>> t.tensor_type(t.numpy_any_dtype)
|
||||||
|
tensor<*x!numpy.any_dtype>
|
||||||
|
>>> t.tensor_type(t.numpy_any_dtype, [-1, 2])
|
||||||
|
tensor<?x2x!numpy.any_dtype>
|
||||||
|
>>> t.tensor_type(t.f32_type)
|
||||||
|
tensor<*xf32>
|
||||||
|
>>> t.function_type([t.i32_type], [t.f32_type])
|
||||||
|
(i32) -> f32
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.numpy_any_dtype = self.context.parse_type("!numpy.any_dtype")
|
||||||
|
|
||||||
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
|
def numpy_ufunc_call_op(self, callee_symbol, result_type, *args):
|
||||||
"""Creates a numpy.ufunc_call op."""
|
"""Creates a numpy.ufunc_call op."""
|
||||||
c = self.context
|
c = self.context
|
||||||
|
@ -43,34 +62,11 @@ class Ops(ir.Ops):
|
||||||
})
|
})
|
||||||
return self.op("numpy.ufunc_call", [result_type], args, attrs)
|
return self.op("numpy.ufunc_call", [result_type], args, attrs)
|
||||||
|
|
||||||
def numpy_narrow(self, result_type, operand):
|
def numpy_narrow_op(self, result_type, operand):
|
||||||
"""Creates a numpy.narrow op."""
|
"""Creates a numpy.narrow op."""
|
||||||
return self.op("numpy.narrow", [result_type], [operand])
|
return self.op("numpy.narrow", [result_type], [operand])
|
||||||
|
|
||||||
|
|
||||||
class Types(ir.Types):
|
|
||||||
"""Container/factory for dialect types.
|
|
||||||
|
|
||||||
>>> t = Types(ir.MLIRContext())
|
|
||||||
>>> t.numpy_any_dtype
|
|
||||||
!numpy.any_dtype
|
|
||||||
>>> t.tensor(t.numpy_any_dtype, [1, 2, 3])
|
|
||||||
tensor<1x2x3x!numpy.any_dtype>
|
|
||||||
>>> t.tensor(t.numpy_any_dtype)
|
|
||||||
tensor<*x!numpy.any_dtype>
|
|
||||||
>>> t.tensor(t.numpy_any_dtype, [-1, 2])
|
|
||||||
tensor<?x2x!numpy.any_dtype>
|
|
||||||
>>> t.tensor(t.f32)
|
|
||||||
tensor<*xf32>
|
|
||||||
>>> t.function([t.i32], [t.f32])
|
|
||||||
(i32) -> f32
|
|
||||||
|
|
||||||
"""
|
|
||||||
def __init__(self, context):
|
|
||||||
super().__init__(context)
|
|
||||||
self.numpy_any_dtype = context.parse_type("!numpy.any_dtype")
|
|
||||||
|
|
||||||
|
|
||||||
def load_builtin_module(context=None):
|
def load_builtin_module(context=None):
|
||||||
"""Loads a module populated with numpy built-ins.
|
"""Loads a module populated with numpy built-ins.
|
||||||
|
|
||||||
|
|
|
@ -120,22 +120,29 @@ private:
|
||||||
// Python only classes
|
// Python only classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class PyOps {
|
/// Helper for creating (possibly dialect specific) IR objects. This class
|
||||||
|
/// is intended to be subclassed on the Python side (possibly with multiple
|
||||||
|
/// inheritance) to provide Python level APIs for custom dialects. The base
|
||||||
|
/// class contains helpers for std types and ops.
|
||||||
|
class PyDialectHelper {
|
||||||
public:
|
public:
|
||||||
PyOps(std::shared_ptr<PyContext> context)
|
PyDialectHelper(std::shared_ptr<PyContext> context)
|
||||||
: pyOpBuilder(*context), context(std::move(context)) {}
|
: pyOpBuilder(*context), context(std::move(context)) {
|
||||||
|
}
|
||||||
static void bind(py::module m) {
|
static void bind(py::module m) {
|
||||||
py::class_<PyOps>(m, "Ops")
|
py::class_<PyDialectHelper>(m, "DialectHelper")
|
||||||
.def(py::init<std::shared_ptr<PyContext>>())
|
.def(py::init<std::shared_ptr<PyContext>>())
|
||||||
.def_property_readonly(
|
.def_property_readonly("builder",
|
||||||
"builder",
|
[](PyDialectHelper &self) -> PyBaseOpBuilder & {
|
||||||
[](PyOps &self) -> PyBaseOpBuilder & { return self.pyOpBuilder; })
|
return self.pyOpBuilder;
|
||||||
.def_property_readonly("context",
|
|
||||||
[](PyOps &self) -> std::shared_ptr<PyContext> {
|
|
||||||
return self.context;
|
|
||||||
})
|
})
|
||||||
|
.def_property_readonly(
|
||||||
|
"context",
|
||||||
|
[](PyDialectHelper &self) -> std::shared_ptr<PyContext> {
|
||||||
|
return self.context;
|
||||||
|
})
|
||||||
.def("op",
|
.def("op",
|
||||||
[](PyOps &self, const std::string &opNameStr,
|
[](PyDialectHelper &self, const std::string &opNameStr,
|
||||||
std::vector<PyType> pyResultTypes,
|
std::vector<PyType> pyResultTypes,
|
||||||
std::vector<PyValue> pyOperands,
|
std::vector<PyValue> pyOperands,
|
||||||
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
||||||
|
@ -163,7 +170,7 @@ public:
|
||||||
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
||||||
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
||||||
.def("func_op",
|
.def("func_op",
|
||||||
[](PyOps &self, const std::string &name, PyType type,
|
[](PyDialectHelper &self, const std::string &name, PyType type,
|
||||||
bool createEntryBlock) {
|
bool createEntryBlock) {
|
||||||
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
||||||
if (!functionType) {
|
if (!functionType) {
|
||||||
|
@ -188,64 +195,55 @@ public:
|
||||||
R"(Creates a new `func` op, optionally creating an entry block.
|
R"(Creates a new `func` op, optionally creating an entry block.
|
||||||
If an entry block is created, the builder will be positioned
|
If an entry block is created, the builder will be positioned
|
||||||
to its start.)")
|
to its start.)")
|
||||||
.def("return_op", [](PyOps &self, std::vector<PyValue> pyOperands) {
|
.def("return_op",
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
[](PyDialectHelper &self, std::vector<PyValue> pyOperands) {
|
||||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
SmallVector<Value, 4> operands(pyOperands.begin(), pyOperands.end());
|
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||||
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||||
});
|
pyOperands.end());
|
||||||
}
|
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
||||||
PyOpBuilder pyOpBuilder;
|
})
|
||||||
std::shared_ptr<PyContext> context;
|
|
||||||
};
|
|
||||||
|
|
||||||
class PyTypes {
|
// Types.
|
||||||
public:
|
.def("integer_type",
|
||||||
PyTypes(std::shared_ptr<PyContext> context) : context(std::move(context)) {}
|
[](PyDialectHelper &self, unsigned width) {
|
||||||
static void bind(py::module m) {
|
|
||||||
py::class_<PyTypes>(m, "Types")
|
|
||||||
.def(py::init<std::shared_ptr<PyContext>>())
|
|
||||||
.def_property_readonly("context",
|
|
||||||
[](PyTypes &self) { return self.context; })
|
|
||||||
.def("integer",
|
|
||||||
[](PyTypes &self, unsigned width) {
|
|
||||||
return PyType(IntegerType::get(width, &self.context->context));
|
return PyType(IntegerType::get(width, &self.context->context));
|
||||||
},
|
},
|
||||||
py::arg("width") = 32)
|
py::arg("width") = 32)
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"i1",
|
"i1_type",
|
||||||
[](PyTypes &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(IntegerType::get(1, &self.context->context));
|
return PyType(IntegerType::get(1, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"i16",
|
"i16_type",
|
||||||
[](PyTypes &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(IntegerType::get(32, &self.context->context));
|
return PyType(IntegerType::get(32, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"i32",
|
"i32_type",
|
||||||
[](PyTypes &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(IntegerType::get(32, &self.context->context));
|
return PyType(IntegerType::get(32, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"i64",
|
"i64_type",
|
||||||
[](PyTypes &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(IntegerType::get(64, &self.context->context));
|
return PyType(IntegerType::get(64, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"f32",
|
"f32_type",
|
||||||
[](PyTypes &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(
|
return PyType(
|
||||||
FloatType::get(StandardTypes::F32, &self.context->context));
|
FloatType::get(StandardTypes::F32, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"f64",
|
"f64_type",
|
||||||
[](PyTypes &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(
|
return PyType(
|
||||||
FloatType::get(StandardTypes::F64, &self.context->context));
|
FloatType::get(StandardTypes::F64, &self.context->context));
|
||||||
})
|
})
|
||||||
.def("tensor",
|
.def("tensor_type",
|
||||||
[](PyTypes &self, PyType elementType,
|
[](PyDialectHelper &self, PyType elementType,
|
||||||
llvm::Optional<std::vector<int64_t>> shape) {
|
llvm::Optional<std::vector<int64_t>> shape) {
|
||||||
if (!elementType.type) {
|
if (!elementType.type) {
|
||||||
throw py::raiseValueError("Null element type");
|
throw py::raiseValueError("Null element type");
|
||||||
|
@ -258,22 +256,22 @@ public:
|
||||||
},
|
},
|
||||||
py::arg("element_type"),
|
py::arg("element_type"),
|
||||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||||
.def("function", [](PyTypes &self, std::vector<PyType> inputs,
|
.def("function_type",
|
||||||
std::vector<PyType> results) {
|
[](PyDialectHelper &self, std::vector<PyType> inputs,
|
||||||
llvm::SmallVector<Type, 4> inputTypes;
|
std::vector<PyType> results) {
|
||||||
llvm::SmallVector<Type, 1> resultTypes;
|
llvm::SmallVector<Type, 4> inputTypes;
|
||||||
for (auto input : inputs) {
|
llvm::SmallVector<Type, 1> resultTypes;
|
||||||
inputTypes.push_back(input.type);
|
for (auto input : inputs) {
|
||||||
}
|
inputTypes.push_back(input.type);
|
||||||
for (auto result : results) {
|
}
|
||||||
resultTypes.push_back(result.type);
|
for (auto result : results) {
|
||||||
}
|
resultTypes.push_back(result.type);
|
||||||
return PyType(FunctionType::get(inputTypes, resultTypes,
|
}
|
||||||
&self.context->context));
|
return PyType(FunctionType::get(inputTypes, resultTypes,
|
||||||
});
|
&self.context->context));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
PyOpBuilder pyOpBuilder;
|
||||||
private:
|
|
||||||
std::shared_ptr<PyContext> context;
|
std::shared_ptr<PyContext> context;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -285,8 +283,7 @@ void defineMlirIrModule(py::module m) {
|
||||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||||
|
|
||||||
// Python only types.
|
// Python only types.
|
||||||
PyOps::bind(m);
|
PyDialectHelper::bind(m);
|
||||||
PyTypes::bind(m);
|
|
||||||
|
|
||||||
// Utility types.
|
// Utility types.
|
||||||
PyBlockList::bind(m, "BlockList");
|
PyBlockList::bind(m, "BlockList");
|
||||||
|
|
|
@ -36,7 +36,8 @@ class TraceInvocation(
|
||||||
|
|
||||||
|
|
||||||
class EmissionRequest(
|
class EmissionRequest(
|
||||||
namedtuple("EmissionRequest", ["input_ssa_values", "ops", "types", "extra"],
|
namedtuple("EmissionRequest",
|
||||||
|
["input_ssa_values", "dialect_helper", "extra"],
|
||||||
defaults=(None,))):
|
defaults=(None,))):
|
||||||
"""Represents the result of processing inputs from an invocation.
|
"""Represents the result of processing inputs from an invocation.
|
||||||
|
|
||||||
|
@ -47,8 +48,7 @@ class EmissionRequest(
|
||||||
blackbox mechanism to transfer un-tracked state from an invocation to
|
blackbox mechanism to transfer un-tracked state from an invocation to
|
||||||
emission.
|
emission.
|
||||||
|
|
||||||
The `ops` and `types` fields correspond to mlir.ir.Ops and mlir.ir.Types
|
The `dialect_helper` fields correspond to mlir.ir.DialectHelper.
|
||||||
instances respectively.
|
|
||||||
"""
|
"""
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
|
@ -166,9 +166,10 @@ class GenericCallUfuncEmitter(FuncEmitter):
|
||||||
return py_results[0]
|
return py_results[0]
|
||||||
|
|
||||||
def emit(self, request: EmissionRequest):
|
def emit(self, request: EmissionRequest):
|
||||||
op_result_type = request.types.tensor(request.types.numpy_any_dtype)
|
h = request.dialect_helper
|
||||||
call_op = request.ops.numpy_ufunc_call_op(self._ufunc_name, op_result_type,
|
op_result_type = h.tensor_type(h.numpy_any_dtype)
|
||||||
*request.input_ssa_values)
|
call_op = h.numpy_ufunc_call_op(self._ufunc_name, op_result_type,
|
||||||
|
*request.input_ssa_values)
|
||||||
return call_op.results
|
return call_op.results
|
||||||
|
|
||||||
|
|
||||||
|
@ -210,9 +211,10 @@ class GenericArrayFuncEmitter(FuncEmitter):
|
||||||
return tuple(py_results)
|
return tuple(py_results)
|
||||||
|
|
||||||
def emit(self, request: EmissionRequest):
|
def emit(self, request: EmissionRequest):
|
||||||
op_result_types = [request.types.tensor(request.types.numpy_any_dtype)
|
h = request.dialect_helper
|
||||||
|
op_result_types = [h.tensor_type(h.numpy_any_dtype)
|
||||||
] * self._nresults
|
] * self._nresults
|
||||||
op = request.ops.op(self._op_name, op_result_types,
|
op = h.op(self._op_name, op_result_types,
|
||||||
request.input_ssa_values)
|
request.input_ssa_values)
|
||||||
return op.results
|
return op.results
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,7 @@ class ModuleBuilder:
|
||||||
# TODO: Instead of bootstrapping a large module, populate imports
|
# TODO: Instead of bootstrapping a large module, populate imports
|
||||||
# dynamically.
|
# dynamically.
|
||||||
self.module = Numpy.load_builtin_module(self.context)
|
self.module = Numpy.load_builtin_module(self.context)
|
||||||
self.ops = Numpy.Ops(self.context)
|
self.helper = Numpy.DialectHelper(self.context)
|
||||||
self.types = Numpy.Types(self.context)
|
|
||||||
self.emitters = (emitter_registry
|
self.emitters = (emitter_registry
|
||||||
if emitter_registry else EmitterRegistry.create_default())
|
if emitter_registry else EmitterRegistry.create_default())
|
||||||
|
|
||||||
|
@ -46,13 +45,12 @@ class FunctionTracer(TraceContext):
|
||||||
"_args_array_params",
|
"_args_array_params",
|
||||||
"_f",
|
"_f",
|
||||||
"_f_types",
|
"_f_types",
|
||||||
|
"_helper",
|
||||||
"_mlir_m",
|
"_mlir_m",
|
||||||
"_mlir_c",
|
"_mlir_c",
|
||||||
"_python_args",
|
"_python_args",
|
||||||
"_ops",
|
|
||||||
"_result_array_params",
|
"_result_array_params",
|
||||||
"_traced_arrays",
|
"_traced_arrays",
|
||||||
"_types",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
||||||
|
@ -65,8 +63,7 @@ class FunctionTracer(TraceContext):
|
||||||
# Alias some parent members for convenience.
|
# Alias some parent members for convenience.
|
||||||
self._mlir_m = module_builder.module
|
self._mlir_m = module_builder.module
|
||||||
self._mlir_c = module_builder.context
|
self._mlir_c = module_builder.context
|
||||||
self._ops = module_builder.ops
|
self._helper = module_builder.helper
|
||||||
self._types = module_builder.types
|
|
||||||
|
|
||||||
# Extract ArrayParams for all args and results.
|
# Extract ArrayParams for all args and results.
|
||||||
self._args_array_params = [
|
self._args_array_params = [
|
||||||
|
@ -86,7 +83,7 @@ class FunctionTracer(TraceContext):
|
||||||
# TODO: More sophisticated signature merging
|
# TODO: More sophisticated signature merging
|
||||||
# TODO: Multiple results
|
# TODO: Multiple results
|
||||||
# TODO: Error reporting
|
# TODO: Error reporting
|
||||||
ops = self._ops
|
h = self._helper
|
||||||
py_results = (self.epf.pyfunc(*self._python_args),)
|
py_results = (self.epf.pyfunc(*self._python_args),)
|
||||||
if len(py_results) != len(self._f_types):
|
if len(py_results) != len(self._f_types):
|
||||||
raise TracingError("Traced function returned != %d results: %r" % (
|
raise TracingError("Traced function returned != %d results: %r" % (
|
||||||
|
@ -102,8 +99,8 @@ class FunctionTracer(TraceContext):
|
||||||
raise TracingError("Unregistered traced array: %r", (py_result,))
|
raise TracingError("Unregistered traced array: %r", (py_result,))
|
||||||
# narrow to declared result type.
|
# narrow to declared result type.
|
||||||
return_operands.extend(
|
return_operands.extend(
|
||||||
ops.numpy_narrow(mlir_result_type, mlir_result).results)
|
h.numpy_narrow_op(mlir_result_type, mlir_result).results)
|
||||||
ops.return_op(return_operands)
|
h.return_op(return_operands)
|
||||||
|
|
||||||
def set_traced_array(self, traced_array, value):
|
def set_traced_array(self, traced_array, value):
|
||||||
"""Sets the current SSA value for a traced_array."""
|
"""Sets the current SSA value for a traced_array."""
|
||||||
|
@ -124,8 +121,7 @@ class FunctionTracer(TraceContext):
|
||||||
def _create_mlir_function(self):
|
def _create_mlir_function(self):
|
||||||
mlir_c = self._mlir_c
|
mlir_c = self._mlir_c
|
||||||
mlir_m = self._mlir_m
|
mlir_m = self._mlir_m
|
||||||
ops = self._ops
|
h = self._helper
|
||||||
types = self._types
|
|
||||||
epf = self.epf
|
epf = self.epf
|
||||||
f_args = [
|
f_args = [
|
||||||
mlir_c.parse_type(ap.mlir_tensor_type_asm)
|
mlir_c.parse_type(ap.mlir_tensor_type_asm)
|
||||||
|
@ -134,9 +130,9 @@ class FunctionTracer(TraceContext):
|
||||||
f_types = [
|
f_types = [
|
||||||
mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
|
mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
|
||||||
]
|
]
|
||||||
ops.builder.insert_before_terminator(mlir_m.first_block)
|
h.builder.insert_before_terminator(mlir_m.first_block)
|
||||||
f_type = types.function(f_args, f_types)
|
f_type = h.function_type(f_args, f_types)
|
||||||
f = ops.func_op(epf.__name__, f_type, create_entry_block=True)
|
f = h.func_op(epf.__name__, f_type, create_entry_block=True)
|
||||||
return f, f_types
|
return f, f_types
|
||||||
|
|
||||||
def _create_trace_roots(self):
|
def _create_trace_roots(self):
|
||||||
|
@ -179,8 +175,7 @@ class FunctionTracer(TraceContext):
|
||||||
tv_map = emitter.map_invocation(invocation)
|
tv_map = emitter.map_invocation(invocation)
|
||||||
input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
|
input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
|
||||||
request = EmissionRequest(input_ssa_values,
|
request = EmissionRequest(input_ssa_values,
|
||||||
ops=self._ops,
|
dialect_helper=self._helper,
|
||||||
types=self._types,
|
|
||||||
extra=tv_map.extra)
|
extra=tv_map.extra)
|
||||||
result_ssa_values = emitter.emit(request)
|
result_ssa_values = emitter.emit(request)
|
||||||
py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,
|
py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,
|
||||||
|
|
Loading…
Reference in New Issue