torch-mlir/examples/torchfx/builder.py

436 lines
16 KiB
Python

"""
Translator from Torch.FX to MLIR.
This file defines a set of classes that take a module
generated by the `torch.fx.experimental.fx_acc.acc_tracer` function
and convert parts of it into an MLIR module.
The expected use for this module is to use the function
`build_module(py_module: torch.fx.GraphModule) -> ir.Module`
to convert the output from the tracer into MLIR using the `torch`
dialect.
This file is licensed under a pytorch-style license
See frontends/pytorch/LICENSE for license information.
"""
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme
from typing import MutableMapping, Mapping, Optional, List, Callable, Any
import abc
from itertools import chain
from torch_mlir import ir
import torch_mlir.dialects.torch as torch_d
from torch_mlir.dialects import builtin, std
import torch.fx
from torch.fx.experimental.fx_acc import acc_ops
from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType
# Maps each `torch.fx.Node` that has already been handled to the
# `ir.Value` that stands for it in the MLIR being built.
Environment = MutableMapping[torch.fx.Node, ir.Value]
class _Builder(abc.ABC):
    """
    Common base class for the MLIR builders in this file.

    A builder reads a `torch.fx.GraphModule` and emits MLIR into an
    existing `ir.Module`. The `ir.Module` is modified in place, so
    running a builder twice on the same `ir.Module` duplicates whatever
    that builder emits. Builders are expected to leave the
    `torch.fx.GraphModule` untouched.

    Intended usage:
      1. Construct the builder.
      2. Call `to_mlir` to obtain the updated `ir.Module`.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.
    mlir_module: ir.Module
        `ir.Module` that receives the MLIR generated by this builder.

    Attributes
    ----------
    py_module: torch.fx.GraphModule
    mlir_module: ir.Module
    context: ir.Context
        Context used by the `mlir_module`.
    module_ip: ir.InsertionPoint
        Insertion point for the body of the `mlir_module`.
    loc: ir.Location
        Location attached to generated operations (currently unknown).
    class_type_name: str
        `__name__` of the type of `py_module`.

    Methods
    -------
    to_mlir() -> ir.Module
        Insert into `mlir_module` the MLIR produced by the builder.
    """

    def __init__(self, py_module: torch.fx.GraphModule,
                 mlir_module: ir.Module):
        ctx = mlir_module.context

        self.py_module = py_module
        self.mlir_module = mlir_module
        self.context = ctx
        # TODO: find a way to get a real location
        self.loc = ir.Location.unknown(ctx)
        self.module_ip = ir.InsertionPoint(mlir_module.body)
        # TODO: is qualified module name necessary?
        self.class_type_name = type(py_module).__name__

    @abc.abstractmethod
    def to_mlir(self) -> ir.Module:
        """
        Emit this builder's MLIR into `mlir_module`.

        Returns
        -------
        ir.Module
            The same `mlir_module`, now containing the new MLIR.
        """
class _ClassDeclAndInitBuilder(_Builder):
    """
    Builder for creating a class in MLIR with attributes initialized.

    This builder performs the following two tasks:
      1. Create an MLIR class declaration based on the public
         attributes of the `py_module` as well as the `forward` method.
      2. Create MLIR that initializes each attribute of the declaration.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.
    mlir_module: ir.Module
        `ir.Module` that will be modified to include the
        MLIR generated by this builder.

    Attributes
    ----------
    class_type_ip : Optional[ir.InsertionPoint]
        Insertion point for `torch_d.ClassTypeOp`. `None` until
        `to_mlir` has been called.
    nn_module_ip : Optional[ir.InsertionPoint]
        Insertion point for `torch_d.NnModuleOp`. `None` until
        `to_mlir` has been called.

    Methods
    -------
    to_mlir() -> ir.Module
        Insert into `mlir_module` the MLIR produced by the builder.
    """

    def __init__(self, py_module: torch.fx.GraphModule,
                 mlir_module: ir.Module):
        super().__init__(py_module, mlir_module)
        self.class_type_ip: Optional[ir.InsertionPoint] = None
        self.nn_module_ip: Optional[ir.InsertionPoint] = None

    def to_mlir(self) -> ir.Module:
        """
        Emit the class declaration and the module initialization.

        A `torch_d.ClassTypeOp` and a `torch_d.NnModuleOp` are created at
        the module insertion point, filled with one attribute/slot pair
        per supported attribute plus the `forward` method declaration,
        and finally closed with their terminator ops.

        Returns
        -------
        ir.Module
            Modified `mlir_module` with the new MLIR produced.

        Raises
        ------
        NotImplementedError
            If an attribute of `py_module` has an unsupported type.
        """
        with self.context:
            class_name_attr = ir.StringAttr.get(self.class_type_name)
            class_type_op = torch_d.ClassTypeOp(class_name_attr,
                                                loc=self.loc,
                                                ip=self.module_ip)
            new_class_block = class_type_op.regions[0].blocks.append()
            self.class_type_ip = ir.InsertionPoint(new_class_block)

            module_type = TorchNnModuleType(self.class_type_name
                                            ).to_mlir(self.context)
            nn_module_op = torch_d.NnModuleOp(module_type,
                                              loc=self.loc,
                                              ip=self.module_ip)
            new_nn_module_block = nn_module_op.regions[0].blocks.append()
            self.nn_module_ip = ir.InsertionPoint(new_nn_module_block)

            self._insert_attr_declarations_and_definitions()
            self._insert_forward_method_declaration()

            # The terminators go last so that every attr/slot/method op
            # emitted above ends up before them in its block.
            torch_d.ClassTypeTerminatorOp(loc=self.loc,
                                          ip=self.class_type_ip)
            torch_d.NnModuleTerminatorOp(loc=self.loc,
                                         ip=self.nn_module_ip)
            return self.mlir_module

    def _insert_attr_declarations_and_definitions(self):
        """
        Declare and initialize every attribute of `py_module`.

        The attributes considered are the entries of
        `py_module.__dict__` whose name does not start with `_`,
        followed by `py_module.named_parameters()`. Only `bool` values
        are supported at the moment.

        Raises
        ------
        NotImplementedError
            If an attribute value is not a `bool`.
        """
        # TODO: not sure how good this definition is for unhidden vars
        unhidden_vars = ((name, value)
                         for name, value in self.py_module.__dict__.items()
                         if not name.startswith('_'))

        # TODO: is anything else needed? There are some hidden
        # attributes that get added by the torch.jit.script
        # compilation pipeline, such as:
        #     torch.attr private "_is_full_backward_hook"
        # that are not being added here
        attrs = chain(unhidden_vars, self.py_module.named_parameters())

        for name, value in attrs:
            type_attr: Optional[ir.TypeAttr] = None
            operand: Optional[ir.OpResult] = None

            # TODO: this should be meta-programmable
            if isinstance(value, bool):
                with self.context:
                    bool_type = PythonType(bool).to_mlir(self.context)
                    type_attr = ir.TypeAttr.get(bool_type)
                    bool_attr = ir.BoolAttr.get(value)
                    operand = torch_d.ConstantBoolOp(
                        bool_type,
                        bool_attr,
                        loc=self.loc,
                        ip=self.module_ip).result
            else:
                # Name the offending attribute so the failure can be
                # traced back to the Python module.
                err = (f'Unsupported attribute type: {type(value)} '
                       f'(attribute `{name}`)')
                raise NotImplementedError(err)

            # Guards against a future clause above forgetting to set
            # one of the two values.
            assert type_attr is not None and operand is not None, \
                'Each clause must specify a value for `type_attr` ' \
                'and `operand`'

            with self.context:
                name_attr = ir.StringAttr.get(name)
                # TODO: don't hardcode `private` field in `AttrOp`
                torch_d.AttrOp(name_attr, type_attr, True,
                               loc=self.loc, ip=self.class_type_ip)
                torch_d.SlotOp(name_attr, operand, loc=self.loc,
                               ip=self.nn_module_ip)

    def _insert_forward_method_declaration(self):
        """
        Declare the `forward` method inside the class declaration.

        Does nothing when `py_module` has no `forward` attribute. The
        method refers to the function symbol
        `@<class_type_name>.forward`.
        """
        if not hasattr(self.py_module, 'forward'):
            return

        with self.context:
            method_name = 'forward'
            name_attr = ir.StringAttr.get(method_name)
            qualified_name = f'{self.class_type_name}.{method_name}'
            # TODO: is there a nice python binding for this?
            function_attr = ir.Attribute.parse(f'@{qualified_name}')
            # TODO: don't hardcode the third argument of `MethodOp`
            # (presumably its `private` flag -- confirm against the op
            # definition)
            torch_d.MethodOp(name_attr, function_attr, False,
                             loc=self.loc, ip=self.class_type_ip)
class _ForwardFunctionBuilderError(Exception):
    """
    Error raised while translating the forward method into MLIR.

    Parameters
    ----------
    value: str
        Human readable description of the problem. It is returned by
        `str()` and is also available as the `value` attribute.
    """

    def __init__(self, value: str):
        # Forward the message to `Exception` so that `args` and `repr()`
        # carry it too, instead of being empty.
        super().__init__(value)
        self.value = value

    def __str__(self) -> str:
        return self.value
class _ForwardFunctionBuilder(_Builder):
    """
    Builder for converting the forward method into MLIR.

    This builder traverses the `torch.fx.Graph` of the
    `py_module`, and translates the operations into MLIR.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.
    mlir_module: ir.Module
        `ir.Module` that will be modified to include the
        MLIR generated by this builder.

    Attributes
    ----------
    func_ip : Optional[ir.InsertionPoint]
        Insertion point for the body of the function representing the
        forward method. `None` until `to_mlir` has been called.
    env : Environment
        Used to keep track of the `ir.Value` corresponding to each
        `torch.fx.Node` that has already been handled.

    Methods
    -------
    to_mlir() -> ir.Module
        Insert into `mlir_module` the MLIR produced by the builder.
    """

    def __init__(self, py_module: torch.fx.GraphModule,
                 mlir_module: ir.Module):
        super().__init__(py_module, mlir_module)
        self.func_ip: Optional[ir.InsertionPoint] = None
        self.env: Environment = {}

    def to_mlir(self) -> ir.Module:
        """
        Emit the function `<class_type_name>.forward` into `mlir_module`.

        The function takes the module itself followed by one argument
        per placeholder node, and returns a single tensor. Each node of
        the graph is then translated in order.

        Returns
        -------
        ir.Module
            Modified `mlir_module` with the new MLIR produced.

        Raises
        ------
        NotImplementedError
            If the graph contains a node whose `op` is not
            `call_function`, `output` or `placeholder`, or a
            `call_function` node whose target has no handler.
        _ForwardFunctionBuilderError
            If a placeholder lacks its `torch_mlir_type` annotation or
            a `call_function` target is a `str`.
        """
        tensor_type = TorchTensorType().to_mlir(self.context)
        module_type = TorchNnModuleType(self.class_type_name
                                        ).to_mlir(self.context)
        # TODO: currently I am assuming that forward always returns a tensor
        func_type = ([module_type] + self._get_arg_types(), [tensor_type])

        with self.context:
            # TODO: Don't hardcode method name
            # TODO: is visibility always private?
            func_op = builtin.FuncOp(f'{self.class_type_name}.forward',
                                     func_type, visibility='private',
                                     loc=self.loc, ip=self.module_ip)
            func_op.add_entry_block()
            self.func_ip = ir.InsertionPoint(func_op.entry_block)
            self._initialize_environment(func_op.entry_block.arguments)

            for node in self.py_module.graph.nodes:
                if node.op == 'call_function':
                    self.env[node] = self._insert_function_call(node)
                elif node.op == 'output':
                    std.ReturnOp(
                        [self.env[node_arg] for node_arg in node.args],
                        loc=self.loc, ip=self.func_ip)
                elif node.op == 'placeholder':
                    # Already bound to a block argument by
                    # `_initialize_environment`.
                    continue
                else:
                    # Covers `call_module`, `get_attr` and anything else.
                    err = f'Unsupported node.op type: {node.op}'
                    raise NotImplementedError(err)

            return self.mlir_module

    def _initialize_environment(self, arg_list: ir.BlockArgumentList) -> None:
        """
        Bind every placeholder node to its entry block argument.

        Arguments whose type equals the module type are skipped (that
        is the `self` argument); the remaining ones are paired, in
        order, with the placeholder nodes of the graph.
        """
        placeholders = (node for node in self.py_module.graph.nodes
                        if node.op == 'placeholder')
        self_type = TorchNnModuleType(self.class_type_name
                                      ).to_mlir(self.context)
        non_self_args = (arg for arg in arg_list if arg.type != self_type)
        self.env.update(zip(placeholders, non_self_args))

    def _get_arg_types(self) -> List[ir.Type]:
        """
        Return the MLIR type of each placeholder node, in graph order.

        The type is read from the `torch_mlir_type` entry of the
        placeholder's `kwargs`.

        Raises
        ------
        _ForwardFunctionBuilderError
            If a placeholder has no `torch_mlir_type` entry.
        """
        types = []
        for operand in self.py_module.graph.nodes:
            if operand.op != 'placeholder':
                continue
            type_ = operand.kwargs.get('torch_mlir_type')
            if type_ is None:
                # Fail with a clear message instead of an AttributeError
                # on `None.to_mlir`.
                err = (f'Placeholder `{operand.name}` has no '
                       '`torch_mlir_type` entry in its kwargs')
                raise _ForwardFunctionBuilderError(err)
            types.append(type_.to_mlir(self.context))
        return types

    def _insert_function_call(self, f_node: torch.fx.Node) -> ir.OpResult:
        """
        Translate a `call_function` node through its registered handler.

        Only keyword arguments that are `torch.fx.Node`s are forwarded
        to the handler, already resolved to their `ir.Value`.

        Raises
        ------
        _ForwardFunctionBuilderError
            If `f_node.target` is a `str` rather than a callable.
        NotImplementedError
            If no handler is registered for `f_node.target`.
        """
        assert f_node.op == 'call_function'

        args: MutableMapping[str, ir.Value] = {
            name: self.env[arg_node]
            for name, arg_node in f_node.kwargs.items()
            if isinstance(arg_node, torch.fx.Node)
        }

        if isinstance(f_node.target, str):
            err = (f'f_node.target = {f_node.target} must be of type '
                   'Callable[..., Any], not str. Make sure the '
                   'torch.fx.Graph has been normalized to use '
                   'torch.fx.experimental.fx_acc.acc_ops')
            raise _ForwardFunctionBuilderError(err)

        handler = ACC_OP_HANDLERS.get(f_node.target)
        if handler is None:
            raise NotImplementedError(
                f'Unsupported function: {f_node.target}')
        return handler(self, args)
# Signature of a handler: it receives the forward function builder and
# the keyword arguments of the call (already resolved to `ir.Value`s),
# and returns the result of the operation it inserted.
_AccOpHandler = Callable[[_ForwardFunctionBuilder, Mapping[str, ir.Value]],
ir.OpResult]
# Type of a table that maps a function target to its handler.
_AccOpHandlerTable = MutableMapping[Callable[..., Any], _AccOpHandler]
# Handlers looked up by `_ForwardFunctionBuilder._insert_function_call`
# using the target of each `call_function` node.
ACC_OP_HANDLERS: _AccOpHandlerTable = {}
def _add_handler(table: _AccOpHandlerTable, acc_op: Callable[..., Any]):
    """
    Build a decorator that registers a handler for `acc_op` in `table`.

    The decorated function is stored in `table` under the key `acc_op`
    and handed back unchanged.
    """
    def register(handler: _AccOpHandler):
        table[acc_op] = handler
        return handler

    return register
# TODO: these handlers should be meta-programmed
@_add_handler(ACC_OP_HANDLERS, acc_ops.sigmoid)
def _sigmoid_handler(func_builder: _ForwardFunctionBuilder,
                     args: Mapping[str, ir.Value]) -> ir.OpResult:
    """
    Insert a `torch_d.AtenSigmoidOp` applied to the `input` argument.

    The op is created at `func_builder.func_ip` and its result, typed
    as a torch tensor, is returned.
    """
    operand = args.get('input')
    assert operand is not None, ('A call to this handler must include '
                                 'an argument named `input`')

    result_type = TorchTensorType().to_mlir(func_builder.context)
    sigmoid_op = torch_d.AtenSigmoidOp(result_type,
                                       operand,
                                       loc=func_builder.loc,
                                       ip=func_builder.func_ip)
    return sigmoid_op.result
@_add_handler(ACC_OP_HANDLERS, acc_ops.tanh)
def _tanh_handler(func_builder: _ForwardFunctionBuilder,
                  args: Mapping[str, ir.Value]) -> ir.OpResult:
    """
    Insert a `torch_d.AtenTanhOp` applied to the `input` argument.

    The op is created at `func_builder.func_ip` and its result, typed
    as a torch tensor, is returned.
    """
    operand = args.get('input')
    assert operand is not None, ('A call to this handler must include '
                                 'an argument named `input`')

    result_type = TorchTensorType().to_mlir(func_builder.context)
    tanh_op = torch_d.AtenTanhOp(result_type,
                                 operand,
                                 loc=func_builder.loc,
                                 ip=func_builder.func_ip)
    return tanh_op.result
# NOTE: this function must not be called `_add_handler`: that would rebind
# the decorator factory of the same name and break any registration that
# comes after it. The handler is only reached through `ACC_OP_HANDLERS`.
@_add_handler(ACC_OP_HANDLERS, acc_ops.add)
def _add_tensor_handler(func_builder: _ForwardFunctionBuilder,
                        args: Mapping[str, ir.Value]) -> ir.OpResult:
    """
    Insert a `torch_d.AtenAddTensorOp` adding `input` and `other`.

    A constant integer `1` is inserted first and passed as the third
    operand of the addition (named `alpha_arg` below). Both ops are
    created at `func_builder.func_ip`; the result of the addition, typed
    as a torch tensor, is returned.
    """
    input_arg = args.get('input')
    other_arg = args.get('other')
    assert input_arg is not None and other_arg is not None, \
        'A call to this handler must include an argument named `input` ' \
        'and an argument named `other`'

    tensor_type = TorchTensorType().to_mlir(func_builder.context)
    int_type = PythonType(int).to_mlir(func_builder.context)
    int_attr = ir.IntegerAttr.get(int_type, 1)
    alpha_arg = torch_d.ConstantIntOp(int_type,
                                      int_attr,
                                      loc=func_builder.loc,
                                      ip=func_builder.func_ip).result
    result = torch_d.AtenAddTensorOp(tensor_type,
                                     input_arg,
                                     other_arg,
                                     alpha_arg,
                                     loc=func_builder.loc,
                                     ip=func_builder.func_ip).result
    return result
def build_module(py_module: torch.fx.GraphModule) -> ir.Module:
    """
    Convert a traced module into an MLIR module in the `torch` dialect.

    An empty `ir.Module` is created in a fresh context, the `torch`
    dialect is registered on that context, and the two builders are run
    in sequence: first the class declaration and initialization, then
    the forward function.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.

    Returns
    -------
    ir.Module
        Translation of the input module into an MLIR module
    """
    with ir.Context():
        unknown_loc = ir.Location.unknown()
        module = ir.Module.create(unknown_loc)
        torch_d.register_dialect(module.context)

        class_builder = _ClassDeclAndInitBuilder(py_module, module)
        module_with_class = class_builder.to_mlir()

        forward_builder = _ForwardFunctionBuilder(py_module,
                                                  module_with_class)
        return forward_builder.to_mlir()