mirror of https://github.com/llvm/torch-mlir
168 lines
5.5 KiB
Python
168 lines
5.5 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
"""General utilities for working with MLIR."""
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
from mlir import ir as _ir
|
|
from npcomp import _cext
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "ImportContext",
]
|
|
|
|
|
|
class ImportContext:
    """Simple container for things that we update while importing.

    This is also where we stash various helpers to work around awkward/missing
    MLIR Python API features.

    The instance carries the MLIR context, the current location (``loc``),
    the module being built (``module``, ``None`` until a caller assigns it),
    a private stack of insertion points, and a handful of types/attributes
    that are created once in ``__init__`` and reused afterwards.
    """

    __slots__ = [
        "context",
        "loc",
        "module",
        "_ip_stack",

        # Cached types.
        "unknown_type",
        "bool_type",
        "bytes_type",
        "ellipsis_type",
        "i1_type",
        "index_type",
        "none_type",
        "str_type",
        "unknown_array_type",
        "unknown_tensor_type",

        # Cached attributes.
        "i1_true",
        "i1_false",
    ]

    def __init__(self, context: Optional[_ir.Context]):
        """Sets up the import state and caches commonly used types.

        Args:
          context: Context to import into. When ``None``, a fresh
            ``_ir.Context`` is created. All dialects are registered on
            whichever context ends up being used.
        """
        # Test identity rather than truthiness so that a caller-supplied
        # context is never discarded because of how it evaluates as a bool.
        self.context = _ir.Context() if context is None else context
        _cext.register_all_dialects(self.context)

        self.loc = _ir.Location.unknown(context=self.context)  # type: _ir.Location
        self.module = None  # type: Optional[_ir.Module]
        self._ip_stack = []

        # Cache some types and attributes.
        with self.context:
            # Types.
            # TODO: Consolidate numpy.any_dtype and basicpy.UnknownType.
            self.unknown_type = _ir.Type.parse("!basicpy.UnknownType")
            self.bool_type = _ir.Type.parse("!basicpy.BoolType")
            self.bytes_type = _ir.Type.parse("!basicpy.BytesType")
            self.ellipsis_type = _ir.Type.parse("!basicpy.EllipsisType")
            self.none_type = _ir.Type.parse("!basicpy.NoneType")
            self.str_type = _ir.Type.parse("!basicpy.StrType")
            self.i1_type = _ir.IntegerType.get_signless(1)
            self.index_type = _ir.IndexType.get()
            self.unknown_tensor_type = _ir.UnrankedTensorType.get(
                self.unknown_type, loc=self.loc)
            self.unknown_array_type = _cext.shaped_to_ndarray_type(
                self.unknown_tensor_type)

            # Attributes.
            self.i1_true = _ir.IntegerAttr.get(self.i1_type, 1)
            self.i1_false = _ir.IntegerAttr.get(self.i1_type, 0)

    def set_file_line_col(self, file: str, line: int, col: int):
        """Replaces the current location with a file/line/column location."""
        self.loc = _ir.Location.file(file, line, col, context=self.context)

    def push_ip(self, new_ip: _ir.InsertionPoint):
        """Makes `new_ip` the current insertion point (top of the stack)."""
        self._ip_stack.append(new_ip)

    def pop_ip(self):
        """Discards the current insertion point.

        Asserts that the stack is non-empty, i.e. that this call is matched
        by an earlier `push_ip`.
        """
        assert self._ip_stack, "Mismatched push_ip/pop_ip: stack is empty on pop"
        self._ip_stack.pop()

    @property
    def ip(self):
        """The insertion point on top of the stack (asserts it exists)."""
        assert self._ip_stack, "InsertionPoint requested but stack is empty"
        return self._ip_stack[-1]

    def insert_before_terminator(self, block: _ir.Block):
        """Pushes an insertion point located at the terminator of `block`."""
        self.push_ip(_ir.InsertionPoint.at_block_terminator(block))

    def insert_end_of_block(self, block: _ir.Block):
        """Pushes an insertion point constructed directly from `block`."""
        self.push_ip(_ir.InsertionPoint(block))

    def FuncOp(self, name: str, func_type: _ir.Type,
               create_entry_block: bool) -> Tuple[_ir.Operation, _ir.Block]:
        """Creates a |func| op.

        TODO: This should really be in the MLIR API.

        Args:
          name: Symbol name stored in the op's `sym_name` attribute.
          func_type: Type stored in the op's `type` attribute. Its `inputs`
            are used as the argument types of the entry block.
          create_entry_block: NOTE(review): accepted but currently not
            consulted; an entry block is always appended. Confirm against
            callers before changing this.

        Returns:
          (operation, entry_block)
        """
        with self.context, self.loc:
            attrs = {
                "type": _ir.TypeAttr.get(func_type),
                "sym_name": _ir.StringAttr.get(name),
            }
            op = _ir.Operation.create("func",
                                      regions=1,
                                      attributes=attrs,
                                      ip=self.ip)
            entry_block = op.regions[0].blocks.append(*func_type.inputs)
            return op, entry_block

    def basicpy_ExecOp(self):
        """Creates a basicpy.exec op.

        The op is created at the current insertion point and location with a
        single region, to which one empty block is appended.

        Returns:
          Insertion point to the body.
        """
        op = _ir.Operation.create("basicpy.exec",
                                  regions=1,
                                  ip=self.ip,
                                  loc=self.loc)
        body_block = op.regions[0].blocks.append()
        return _ir.InsertionPoint(body_block)

    def basicpy_FuncTemplateCallOp(self, result_type, callee_symbol, args,
                                   arg_names):
        """Creates a basicpy.func_template_call op.

        Args:
          result_type: Type of the op's single result.
          callee_symbol: Symbol name wrapped in a flat symbol reference and
            stored as the `callee` attribute.
          args: Operands of the call.
          arg_names: Iterable of strings stored, in order, as the
            `arg_names` array attribute.

        Returns:
          The created operation.
        """
        with self.loc, self.ip:
            arg_name_attrs = [_ir.StringAttr.get(n) for n in arg_names]
            attributes = {
                "callee": _ir.FlatSymbolRefAttr.get(callee_symbol),
                "arg_names": _ir.ArrayAttr.get(arg_name_attrs),
            }
            op = _ir.Operation.create("basicpy.func_template_call",
                                      results=[result_type],
                                      operands=args,
                                      attributes=attributes,
                                      ip=self.ip)
        return op

    def scf_IfOp(self, results, condition: _ir.Value, with_else_region: bool):
        """Creates an SCF if op.

        The op gets one region (then) or two regions (then, else) depending
        on `with_else_region`; each region receives one empty block.

        Args:
          results: Result types of the op.
          condition: Value used as the op's only operand.
          with_else_region: Whether to create the else region as well.

        Returns:
          (if_op, then_ip, else_ip) if with_else_region, otherwise
          (if_op, then_ip)
        """
        op = _ir.Operation.create("scf.if",
                                  results=results,
                                  operands=[condition],
                                  regions=2 if with_else_region else 1,
                                  loc=self.loc,
                                  ip=self.ip)
        then_ip = _ir.InsertionPoint(op.regions[0].blocks.append())
        if not with_else_region:
            return op, then_ip
        else_ip = _ir.InsertionPoint(op.regions[1].blocks.append())
        return op, then_ip, else_ip

    def scf_YieldOp(self, operands):
        """Creates an scf.yield op at the current insertion point.

        Args:
          operands: Values to yield.

        Returns:
          The created operation.
        """
        return _ir.Operation.create("scf.yield",
                                    operands=operands,
                                    loc=self.loc,
                                    ip=self.ip)
|