torch-mlir/python/npcomp/compiler/utils/mlir_utils.py

168 lines
5.5 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""General utilities for working with MLIR."""
from typing import Optional, Tuple
from mlir import ir as _ir
from npcomp import _cext
# Public API of this module: only the import-state container is exported.
__all__ = ["ImportContext"]
class ImportContext:
    """Simple container for things that we update while importing.

    This is also where we stash various helpers to work around awkward/missing
    MLIR Python API features.

    Attributes:
      context: The MLIR context used for everything created through this
        object.
      loc: The current location; op-building helpers attach it to new ops.
      module: Initialized to None. No method of this class assigns it.
      unknown_type, bool_type, bytes_type, ellipsis_type, none_type, str_type:
        Cached types parsed from the corresponding ``!basicpy.*`` type names.
      i1_type, index_type: Cached signless 1-bit integer and index types.
      unknown_tensor_type: Unranked tensor type with ``unknown_type`` elements.
      unknown_array_type: Result of ``_cext.shaped_to_ndarray_type`` applied
        to ``unknown_tensor_type``.
      i1_true, i1_false: Cached ``i1_type`` integer attributes 1 and 0.

    Insertion points are kept on a private stack manipulated with
    ``push_ip``/``pop_ip``; the ``ip`` property exposes the top of the stack.
    """

    __slots__ = [
        "context",
        "loc",
        "module",
        "_ip_stack",
        # Cached types.
        "unknown_type",
        "bool_type",
        "bytes_type",
        "ellipsis_type",
        "i1_type",
        "index_type",
        "none_type",
        "str_type",
        "unknown_array_type",
        "unknown_tensor_type",
        # Cached attributes.
        "i1_true",
        "i1_false",
    ]

    def __init__(self, context: "Optional[_ir.Context]" = None):
        """Sets up the context, an unknown location and the cached types.

        Args:
          context: Context to use. When None, a new ``_ir.Context`` is
            created. Dialects are registered on the context either way.
        """
        # Test against None explicitly: the decision to allocate a new context
        # must not depend on the truthiness of a caller-supplied object.
        self.context = _ir.Context() if context is None else context
        _cext.register_all_dialects(self.context)
        self.loc = _ir.Location.unknown(
            context=self.context)  # type: _ir.Location
        self.module = None  # type: Optional[_ir.Module]
        self._ip_stack = []

        # Cache some types and attributes. Most calls below take no explicit
        # context argument, so they are made with self.context entered.
        with self.context:
            # Types.
            # TODO: Consolidate numpy.any_dtype and basicpy.UnknownType.
            self.unknown_type = _ir.Type.parse("!basicpy.UnknownType")
            self.bool_type = _ir.Type.parse("!basicpy.BoolType")
            self.bytes_type = _ir.Type.parse("!basicpy.BytesType")
            self.ellipsis_type = _ir.Type.parse("!basicpy.EllipsisType")
            self.none_type = _ir.Type.parse("!basicpy.NoneType")
            self.str_type = _ir.Type.parse("!basicpy.StrType")
            self.i1_type = _ir.IntegerType.get_signless(1)
            self.index_type = _ir.IndexType.get()
            self.unknown_tensor_type = _ir.UnrankedTensorType.get(
                self.unknown_type, loc=self.loc)
            self.unknown_array_type = _cext.shaped_to_ndarray_type(
                self.unknown_tensor_type)

            # Attributes.
            self.i1_true = _ir.IntegerAttr.get(self.i1_type, 1)
            self.i1_false = _ir.IntegerAttr.get(self.i1_type, 0)

    def set_file_line_col(self, file: str, line: int, col: int):
        """Replaces ``self.loc`` with a file/line/column location."""
        self.loc = _ir.Location.file(file, line, col, context=self.context)

    def push_ip(self, new_ip: "_ir.InsertionPoint"):
        """Makes ``new_ip`` the current insertion point."""
        self._ip_stack.append(new_ip)

    def pop_ip(self):
        """Discards the current insertion point.

        Raises:
          AssertionError: If the stack is already empty.
        """
        assert self._ip_stack, "Mismatched push_ip/pop_ip: stack is empty on pop"
        del self._ip_stack[-1]

    @property
    def ip(self):
        """The most recently pushed insertion point.

        Raises:
          AssertionError: If no insertion point has been pushed.
        """
        assert self._ip_stack, "InsertionPoint requested but stack is empty"
        return self._ip_stack[-1]

    def insert_before_terminator(self, block: "_ir.Block"):
        """Pushes ``InsertionPoint.at_block_terminator(block)``."""
        self.push_ip(_ir.InsertionPoint.at_block_terminator(block))

    def insert_end_of_block(self, block: "_ir.Block"):
        """Pushes an insertion point constructed directly from ``block``."""
        self.push_ip(_ir.InsertionPoint(block))

    def FuncOp(
        self, name: str, func_type: "_ir.Type", create_entry_block: bool
    ) -> "Tuple[_ir.Operation, Optional[_ir.Block]]":
        """Creates a |func| op.

        TODO: This should really be in the MLIR API.

        Args:
          name: Value for the op's ``sym_name`` attribute.
          func_type: Value for the op's ``type`` attribute. Its ``inputs`` are
            used as the entry block argument types.
          create_entry_block: Whether to append an entry block to the body
            region. When False the body region is left empty.

        Returns:
          (operation, entry_block); entry_block is None when
          create_entry_block is False.
        """
        with self.context, self.loc:
            attrs = {
                "type": _ir.TypeAttr.get(func_type),
                "sym_name": _ir.StringAttr.get(name),
            }
            op = _ir.Operation.create("func",
                                      regions=1,
                                      attributes=attrs,
                                      ip=self.ip)
            # Honor the flag: previously a block was appended unconditionally,
            # so a caller could never obtain an op with an empty body region.
            if not create_entry_block:
                return op, None
            body_region = op.regions[0]
            entry_block = body_region.blocks.append(*func_type.inputs)
        return op, entry_block

    def basicpy_ExecOp(self):
        """Creates a basicpy.exec op.

        The op is created at the current insertion point with the current
        location, and one block is appended to its single region.

        Returns:
          Insertion point to the body.
        """
        op = _ir.Operation.create("basicpy.exec",
                                  regions=1,
                                  ip=self.ip,
                                  loc=self.loc)
        body_block = op.regions[0].blocks.append()
        return _ir.InsertionPoint(body_block)

    def basicpy_FuncTemplateCallOp(self, result_type, callee_symbol, args,
                                   arg_names):
        """Creates a basicpy.func_template_call op.

        Args:
          result_type: Type of the op's single result.
          callee_symbol: Symbol name stored in the ``callee`` attribute.
          args: Operands of the op.
          arg_names: Iterable of names; each is stored as a string attribute
            in the ``arg_names`` array attribute.

        Returns:
          The created operation.
        """
        with self.loc, self.ip:
            attributes = {
                "callee":
                    _ir.FlatSymbolRefAttr.get(callee_symbol),
                "arg_names":
                    _ir.ArrayAttr.get(
                        [_ir.StringAttr.get(n) for n in arg_names]),
            }
            op = _ir.Operation.create("basicpy.func_template_call",
                                      results=[result_type],
                                      operands=args,
                                      attributes=attributes,
                                      ip=self.ip)
        return op

    def scf_IfOp(self, results, condition: "_ir.Value",
                 with_else_region: bool):
        """Creates an SCF if op.

        One block is appended to each region that is created.

        Args:
          results: Result types of the op.
          condition: The single operand of the op.
          with_else_region: Whether to create a second (else) region.

        Returns:
          (if_op, then_ip, else_ip) if with_else_region, otherwise
          (if_op, then_ip)
        """
        op = _ir.Operation.create("scf.if",
                                  results=results,
                                  operands=[condition],
                                  regions=2 if with_else_region else 1,
                                  loc=self.loc,
                                  ip=self.ip)
        then_block = op.regions[0].blocks.append()
        if not with_else_region:
            return op, _ir.InsertionPoint(then_block)
        else_block = op.regions[1].blocks.append()
        return (op, _ir.InsertionPoint(then_block),
                _ir.InsertionPoint(else_block))

    def scf_YieldOp(self, operands):
        """Creates an scf.yield op with ``operands`` at the current ip/loc."""
        return _ir.Operation.create("scf.yield",
                                    operands=operands,
                                    loc=self.loc,
                                    ip=self.ip)