""" Translator from Torch.FX to MLIR. The following defines a set of classes that take a module generated by the `torch.fx.experimental.fx_acc.acc_tracer` function and converts parts of it into an MLIR module. The expected use for this module is to use the function `build_module(py_module: torch.fx.GraphModule) -> ir.Module` to convert the output from the tracer into MLIR using the `torch` dialect. This file is licensed under a pytorch-style license See frontends/pytorch/LICENSE for license information. """ # pylint: disable=no-member, no-name-in-module, invalid-name, fixme from typing import MutableMapping, Mapping, Optional, List, Callable, Any import abc from itertools import chain from torch_mlir import ir import torch_mlir.dialects.torch as torch_d from torch_mlir.dialects import builtin, std import torch.fx from torch.fx.experimental.fx_acc import acc_ops from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType Environment = MutableMapping[torch.fx.Node, ir.Value] class _Builder(abc.ABC): """ Abstract class for an MLIR builder. A builder is an object that takes a torch.fx.GraphModule and an ir.Module, and inserts information into the ir.Module using information from the torch.fx.GraphModule. The builders are expected to modify the ir.Module in place. This means that using a builder on the same ir.Module twice will result in duplicated information in the ir.Module that is returned. The builders should not modify the torch.fx.GraphModule. The expected use of the builders is quite simple. 1. Initialize builder 2. call `to_mlir` method to get the updated ir.Module Parameters ---------- py_module: torch.fx.GraphModule GraphModule produced by the `acc_tracer` from `torch.fx.experimental.fx_acc`. mlir_module: ir.Module `ir.Module` that will be modified to include the MLIR generated by this builder. Attributes ---------- py_module: torch.fx.GraphModule mlir_module: ir.Module context: ir.Context Context used by the `mlir_module`. module_ip: ir.InsertionPoint Insertion point for the body of the `mlir_module`. loc: ir.Location Used to keep track of source code location information. class_type_name: str Qualified name of the class given by the type of `py_module`. Methods ------- to_mlir() -> ir.Module Insert into `mlir_module` the MLIR produced by the builder. """ def __init__(self, py_module: torch.fx.GraphModule, mlir_module: ir.Module): self.py_module = py_module self.mlir_module = mlir_module self.context = mlir_module.context self.module_ip = ir.InsertionPoint(mlir_module.body) # TODO: find a way to get a real location self.loc = ir.Location.unknown(self.context) # TODO: is qualified module name necessary? self.class_type_name = type(py_module).__name__ @abc.abstractmethod def to_mlir(self) -> ir.Module: """ Insert into `mlir_module` the MLIR produced by the builder. Returns ------- ir.Module Modified `mlir_module` with the new MLIR produced. """ class _ClassDeclAndInitBuilder(_Builder): """ Builder for creating a class in MLIR with attributes initialized. This builder performs the following two tasks: 1. Create an MLIR class declaration based on the public attributes of the `py_module` as well as the `forward` method. 2. Create MLIR that initializes each attribute of the declaration. Parameters ---------- py_module: torch.fx.GraphModule GraphModule produced by the `acc_tracer` from `torch.fx.experimental.fx_acc`. mlir_module: ir.Module `ir.Module` that will be modified to include the MLIR generated by this builder. Attributes ---------- class_type_ip : Optional[ir.InsertionPoint] Insertion point for `torch_d.ClassTypeOp`. nn_module_ip : Optional[ir.InsertionPoint] Insertion point for `torch_d.NnModuleOp`. Methods ------- to_mlir() -> ir.Module Insert into `mlir_module` the MLIR produced by the builder. """ def __init__(self, py_module: torch.fx.GraphModule, mlir_module: ir.Module): super().__init__(py_module, mlir_module) self.class_type_ip: Optional[ir.InsertionPoint] = None self.nn_module_ip: Optional[ir.InsertionPoint] = None def to_mlir(self) -> ir.Module: with self.context: class_name_attr = ir.StringAttr.get(self.class_type_name) class_type_op = torch_d.ClassTypeOp(class_name_attr, loc=self.loc, ip=self.module_ip) new_class_block = class_type_op.regions[0].blocks.append() self.class_type_ip = ir.InsertionPoint(new_class_block) module_type = TorchNnModuleType(self.class_type_name ).to_mlir(self.context) nn_module_op = torch_d.NnModuleOp(module_type, loc=self.loc, ip=self.module_ip) new_nn_module_block = nn_module_op.regions[0].blocks.append() self.nn_module_ip = ir.InsertionPoint(new_nn_module_block) self._insert_attr_declarations_and_definitions() self._insert_forward_method_declaration() torch_d.ClassTypeTerminatorOp(loc=self.loc, ip=self.class_type_ip) torch_d.NnModuleTerminatorOp(loc=self.loc, ip=self.nn_module_ip) return self.mlir_module def _insert_attr_declarations_and_definitions(self): # TODO: not sure how good this definition is for unhidden vars unhidden_vars = filter(lambda v: not v[0].startswith('_'), self.py_module.__dict__.items()) # TODO: is anything else needed? There are some hidden # attributes that get added by the torch.jit.script # compilation pipeline, such as: # torch.attr private "_is_full_backward_hook" # that are not being added here attrs = chain(unhidden_vars, self.py_module.named_parameters()) for name, value in attrs: type_attr: Optional[ir.TypeAttr] = None operand: Optional[ir.OpResult] = None # TODO: this should be meta-programmable if isinstance(value, bool): with self.context: bool_type = PythonType(bool).to_mlir(self.context) type_attr = ir.TypeAttr.get(bool_type) bool_attr = ir.BoolAttr.get(value) operand = torch_d.ConstantBoolOp( bool_type, bool_attr, loc=self.loc, ip=self.module_ip).result else: err = f'Unsupported attribute type: {type(value)}' raise NotImplementedError(err) assert type_attr is not None and operand is not None, \ 'Each clause must specify a value for`type_attr` and `operand`' with self.context: name_attr = ir.StringAttr.get(name) # TODO: don't hardcode `private` field in `AttrOp` torch_d.AttrOp(name_attr, type_attr, True, loc=self.loc, ip=self.class_type_ip) torch_d.SlotOp(name_attr, operand, loc=self.loc, ip=self.nn_module_ip) def _insert_forward_method_declaration(self): if not hasattr(self.py_module, 'forward'): return with self.context: method_name = 'forward' name_attr = ir.StringAttr.get(method_name) qualified_name = f'{self.class_type_name}.{method_name}' # TODO: is there a nice python binding for this? function_attr = ir.Attribute.parse(f'@{qualified_name}') # TODO: don't hardcode `private` field in `AttrOp` torch_d.MethodOp(name_attr, function_attr, False, loc=self.loc, ip=self.class_type_ip) class _ForwardFunctionBuilderError(Exception): def __init__(self, value: str): super().__init__() self.value = value def __str__(self) -> str: return self.value class _ForwardFunctionBuilder(_Builder): """ Builder for converting the forward method into MLIR. This builder transverses the `torch.fx.Graph` of the `py_module`, and translates the operations into MLIR. Parameters ---------- py_module: torch.fx.GraphModule GraphModule produced by the `acc_tracer` from `torch.fx.experimental.fx_acc`. mlir_module: ir.Module `ir.Module` that will be modified to include the MLIR generated by this builder. Attributes ---------- func_ip : Optional[ir.InsertionPoint] Insertion point for `torch_d.FuncOp` representing the forward method. env : Environment Used to keep track of the `ir.Value` corresponding to each `torch.fx.Node` that has already been handled. Methods ------- to_mlir() -> ir.Module Insert into `mlir_module` the MLIR produced by the builder. """ def __init__(self, py_module: torch.fx.GraphModule, mlir_module: ir.Module): super().__init__(py_module, mlir_module) self.func_ip: Optional[ir.InsertionPoint] = None self.env: Environment = {} def to_mlir(self) -> ir.Module: tensor_type = TorchTensorType().to_mlir(self.context) module_type = TorchNnModuleType(self.class_type_name ).to_mlir(self.context) # TODO: currently I am assuming that forward always returns a tensor func_type = ([module_type] + self._get_arg_types(), [tensor_type]) with self.context: # TODO: Don't hardcode method name # TODO: is visibility always private? func_op = builtin.FuncOp(f'{self.class_type_name}.forward', func_type, visibility='private', loc=self.loc, ip=self.module_ip) func_op.add_entry_block() self.func_ip = ir.InsertionPoint(func_op.entry_block) self._initialize_environment(func_op.entry_block.arguments) for node in self.py_module.graph.nodes: if node.op == 'call_function': result = self._insert_function_call(node) self.env[node] = result elif node.op == 'output': std.ReturnOp([self.env[node_arg] for node_arg in node.args], loc=self.loc, ip=self.func_ip) elif node.op == 'placeholder': continue elif node.op == 'call_module': err = f'Unsupported node.op type: {node.op}' raise NotImplementedError(err) elif node.op == 'get_attr': err = f'Unsupported node.op type: {node.op}' raise NotImplementedError(err) else: err = f'Unsupported node.op type: {node.op}' raise NotImplementedError(err) return self.mlir_module def _initialize_environment(self, arg_list: ir.BlockArgumentList) -> None: placeholders = filter(lambda node: node.op == 'placeholder', self.py_module.graph.nodes) self_type = TorchNnModuleType(self.class_type_name ).to_mlir(self.context) non_self_args = filter(lambda arg: arg.type != self_type, arg_list) self.env.update(zip(placeholders, non_self_args)) def _get_arg_types(self) -> List[ir.Type]: operands = filter(lambda node: node.op == 'placeholder', self.py_module.graph.nodes) types = [] for operand in operands: type_ = operand.kwargs.get('torch_mlir_type') types.append(type_.to_mlir(self.context)) return types def _insert_function_call(self, f_node: torch.fx.Node) -> ir.OpResult: assert f_node.op == 'call_function' args: MutableMapping[str, ir.Value] = {} for name, arg_node in f_node.kwargs.items(): if isinstance(arg_node, torch.fx.Node): args[name] = self.env[arg_node] if isinstance(f_node.target, str): err = f'f_node.targe = {f_node.target} must be of type \ Callable[..., Any], not str. Make sure the torch.fx.Graph has been \ normalized to using torch.fx.experimental.fx_acc.acc_ops' raise _ForwardFunctionBuilderError(err) handler = ACC_OP_HANDLERS.get(f_node.target) if handler is not None: return handler(self, args) raise NotImplementedError(f'Unsupported function: {f_node.target}') _AccOpHandler = Callable[[_ForwardFunctionBuilder, Mapping[str, ir.Value]], ir.OpResult] _AccOpHandlerTable = MutableMapping[Callable[..., Any], _AccOpHandler] ACC_OP_HANDLERS: _AccOpHandlerTable = {} def _add_handler(table: _AccOpHandlerTable, acc_op: Callable[..., Any]): def decorator(f: _AccOpHandler): table[acc_op] = f return f return decorator # TODO: these handlers should be meta-programmed @_add_handler(ACC_OP_HANDLERS, acc_ops.sigmoid) def _sigmoid_handler(func_builder: _ForwardFunctionBuilder, args: Mapping[str, ir.Value]) -> ir.OpResult: input_arg = args.get('input') assert input_arg is not None, 'A call to this handler must include \ an argument named `input`' tensor_type = TorchTensorType().to_mlir(func_builder.context) result = torch_d.AtenSigmoidOp(tensor_type, input_arg, loc=func_builder.loc, ip=func_builder.func_ip).result return result @_add_handler(ACC_OP_HANDLERS, acc_ops.tanh) def _tanh_handler(func_builder: _ForwardFunctionBuilder, args: Mapping[str, ir.Value]) -> ir.OpResult: input_arg = args.get('input') assert input_arg is not None, 'A call to this handler must include \ an argument named `input`' tensor_type = TorchTensorType().to_mlir(func_builder.context) result = torch_d.AtenTanhOp(tensor_type, input_arg, loc=func_builder.loc, ip=func_builder.func_ip).result return result @_add_handler(ACC_OP_HANDLERS, acc_ops.add) def _add_handler(func_builder: _ForwardFunctionBuilder, args: Mapping[str, ir.Value]) -> ir.OpResult: input_arg = args.get('input') other_arg = args.get('other') assert input_arg is not None and other_arg is not None, \ 'A call to this handler must include an argument named `input` \ and an argument named `other`' tensor_type = TorchTensorType().to_mlir(func_builder.context) int_type = PythonType(int).to_mlir(func_builder.context) int_attr = ir.IntegerAttr.get(int_type, 1) alpha_arg = torch_d.ConstantIntOp(int_type, int_attr, loc=func_builder.loc, ip=func_builder.func_ip).result result = torch_d.AtenAddTensorOp(tensor_type, input_arg, other_arg, alpha_arg, loc=func_builder.loc, ip=func_builder.func_ip).result return result def build_module(py_module: torch.fx.GraphModule) -> ir.Module: """ Translate input module into an MLIR module in the `torch` dialect. Parameters ---------- py_module: torch.fx.GraphModule GraphModule produced by the `acc_tracer` from `torch.fx.experimental.fx_acc`. Returns ------- ir.Module Translation of the input module into an MLIR module """ with ir.Context(): loc = ir.Location.unknown() empty_mlir_module = ir.Module.create(loc) torch_d.register_dialect(empty_mlir_module.context) mlir_module = _ClassDeclAndInitBuilder(py_module, empty_mlir_module).to_mlir() return _ForwardFunctionBuilder(py_module, mlir_module).to_mlir()