mirror of https://github.com/llvm/torch-mlir
436 lines
16 KiB
Python
436 lines
16 KiB
Python
"""
Translator from Torch.FX to MLIR.

The following defines a set of classes that take a module
generated by the `torch.fx.experimental.fx_acc.acc_tracer` function
and convert parts of it into an MLIR module.

The expected use for this module is to use the function
`build_module(py_module: torch.fx.GraphModule) -> ir.Module`
to convert the output from the tracer into MLIR using the `torch`
dialect.

This file is licensed under a pytorch-style license.
See frontends/pytorch/LICENSE for license information.
"""
|
|
|
|
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme
|
|
|
|
from typing import MutableMapping, Mapping, Optional, List, Callable, Any
|
|
import abc
|
|
from itertools import chain
|
|
|
|
from torch_mlir import ir
|
|
import torch_mlir.dialects.torch as torch_d
|
|
from torch_mlir.dialects import builtin, std
|
|
|
|
import torch.fx
|
|
from torch.fx.experimental.fx_acc import acc_ops
|
|
|
|
from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType
|
|
|
|
# Maps each `torch.fx.Node` that has already been translated to the
# `ir.Value` holding its result in the MLIR being built.
Environment = MutableMapping[torch.fx.Node, ir.Value]
|
|
|
|
|
|
class _Builder(abc.ABC):
    """
    Base class for objects that emit MLIR from a torch.fx.GraphModule.

    A builder reads a `torch.fx.GraphModule` and writes the corresponding
    MLIR into an existing `ir.Module`. The `ir.Module` is mutated in place,
    so running a builder twice on the same module duplicates its output.
    The `torch.fx.GraphModule` itself must be left untouched.

    Typical usage:
        1. Construct the builder.
        2. Call `to_mlir` to obtain the updated `ir.Module`.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.
    mlir_module: ir.Module
        `ir.Module` that receives the MLIR generated by this builder.

    Attributes
    ----------
    py_module: torch.fx.GraphModule
    mlir_module: ir.Module
    context: ir.Context
        The context owning `mlir_module`.
    module_ip: ir.InsertionPoint
        Insertion point at the body of `mlir_module`.
    loc: ir.Location
        Source location attached to every emitted op.
    class_type_name: str
        Name of the class of `py_module` (unqualified).

    Methods
    -------
    to_mlir() -> ir.Module
        Emit this builder's MLIR into `mlir_module`.
    """

    def __init__(self, py_module: torch.fx.GraphModule,
                 mlir_module: ir.Module):
        self.py_module = py_module
        self.mlir_module = mlir_module

        ctx = mlir_module.context
        self.context = ctx
        self.module_ip = ir.InsertionPoint(mlir_module.body)

        # TODO: find a way to get a real location
        self.loc = ir.Location.unknown(ctx)

        # TODO: is qualified module name necessary?
        py_class = type(py_module)
        self.class_type_name = py_class.__name__

    @abc.abstractmethod
    def to_mlir(self) -> ir.Module:
        """
        Emit this builder's MLIR into `mlir_module`.

        Returns
        -------
        ir.Module
            `mlir_module`, now containing the newly generated MLIR.
        """
|
|
|
|
|
|
class _ClassDeclAndInitBuilder(_Builder):
    """
    Builder for creating a class in MLIR with attributes initialized.

    This builder performs the following two tasks:

    1. Create an MLIR class declaration (`torch_d.ClassTypeOp`) based on
       the public attributes of the `py_module` as well as the `forward`
       method.
    2. Create MLIR (`torch_d.NnModuleOp`) that initializes each attribute
       of the declaration.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.
    mlir_module: ir.Module
        `ir.Module` that will be modified to include the
        MLIR generated by this builder.

    Attributes
    ----------
    class_type_ip : Optional[ir.InsertionPoint]
        Insertion point for `torch_d.ClassTypeOp`.
    nn_module_ip : Optional[ir.InsertionPoint]
        Insertion point for `torch_d.NnModuleOp`.

    Methods
    -------
    to_mlir() -> ir.Module
        Insert into `mlir_module` the MLIR produced by the builder.
    """

    def __init__(self, py_module: torch.fx.GraphModule,
                 mlir_module: ir.Module):
        super().__init__(py_module, mlir_module)
        # Both insertion points are filled in by `to_mlir` once the
        # enclosing ops (and their body blocks) have been created.
        self.class_type_ip: Optional[ir.InsertionPoint] = None
        self.nn_module_ip: Optional[ir.InsertionPoint] = None

    def to_mlir(self) -> ir.Module:
        """
        Emit the class declaration and its `nn_module` initializer.

        Returns
        -------
        ir.Module
            `mlir_module` with a `ClassTypeOp` and an `NnModuleOp` appended.

        Raises
        ------
        NotImplementedError
            If `py_module` has an attribute whose type is not supported.
        """
        with self.context:
            class_name_attr = ir.StringAttr.get(self.class_type_name)
            class_type_op = torch_d.ClassTypeOp(class_name_attr,
                                                loc=self.loc,
                                                ip=self.module_ip)
            new_class_block = class_type_op.regions[0].blocks.append()
            self.class_type_ip = ir.InsertionPoint(new_class_block)

            module_type = TorchNnModuleType(self.class_type_name
                                            ).to_mlir(self.context)
            nn_module_op = torch_d.NnModuleOp(module_type,
                                              loc=self.loc,
                                              ip=self.module_ip)
            new_nn_module_block = nn_module_op.regions[0].blocks.append()
            self.nn_module_ip = ir.InsertionPoint(new_nn_module_block)

            self._insert_attr_declarations_and_definitions()
            self._insert_forward_method_declaration()

            # Close both bodies with their terminator ops.
            torch_d.ClassTypeTerminatorOp(loc=self.loc, ip=self.class_type_ip)
            torch_d.NnModuleTerminatorOp(loc=self.loc, ip=self.nn_module_ip)

        return self.mlir_module

    def _insert_attr_declarations_and_definitions(self):
        """
        Declare each public attribute in the class and initialize it.

        For every attribute, a `torch_d.AttrOp` is added to the class body
        and a `torch_d.SlotOp` holding its initial value is added to the
        `nn_module` body. Only `bool` values are currently handled; any
        other value, including the tensors yielded by `named_parameters()`,
        raises NotImplementedError.
        """
        # TODO: not sure how good this definition is for unhidden vars
        unhidden_vars = ((name, value)
                         for name, value in self.py_module.__dict__.items()
                         if not name.startswith('_'))

        # TODO: is anything else needed? There are some hidden
        # attributes that get added by the torch.jit.script
        # compilation pipeline, such as:
        # torch.attr private "_is_full_backward_hook"
        # that are not being added here
        attrs = chain(unhidden_vars, self.py_module.named_parameters())

        for name, value in attrs:
            type_attr: Optional[ir.TypeAttr] = None
            operand: Optional[ir.OpResult] = None
            # TODO: this should be meta-programmable
            if isinstance(value, bool):
                with self.context:
                    bool_type = PythonType(bool).to_mlir(self.context)
                    type_attr = ir.TypeAttr.get(bool_type)
                    bool_attr = ir.BoolAttr.get(value)
                    # The initial value is materialized at module level and
                    # referenced from the slot below.
                    operand = torch_d.ConstantBoolOp(
                        bool_type,
                        bool_attr,
                        loc=self.loc,
                        ip=self.module_ip).result
            else:
                err = f'Unsupported attribute type: {type(value)}'
                raise NotImplementedError(err)

            assert type_attr is not None and operand is not None, \
                'Each clause must specify a value for `type_attr` and `operand`'
            with self.context:
                name_attr = ir.StringAttr.get(name)
                # TODO: don't hardcode `private` field in `AttrOp`
                torch_d.AttrOp(name_attr, type_attr, True,
                               loc=self.loc, ip=self.class_type_ip)
                torch_d.SlotOp(name_attr, operand, loc=self.loc,
                               ip=self.nn_module_ip)

    def _insert_forward_method_declaration(self):
        """
        Declare a `forward` method on the class, if `py_module` has one.

        The method is bound to the symbol `@<class_type_name>.forward`, which
        is the name under which the forward function itself is emitted.
        """
        if not hasattr(self.py_module, 'forward'):
            return

        with self.context:
            method_name = 'forward'
            name_attr = ir.StringAttr.get(method_name)
            qualified_name = f'{self.class_type_name}.{method_name}'
            # TODO: is there a nice python binding for this?
            function_attr = ir.Attribute.parse(f'@{qualified_name}')
            # TODO: don't hardcode `private` field in `MethodOp`
            torch_d.MethodOp(name_attr, function_attr, False,
                             loc=self.loc, ip=self.class_type_ip)
|
|
|
|
|
|
class _ForwardFunctionBuilderError(Exception):
    """
    Error raised by `_ForwardFunctionBuilder` when a graph cannot be translated.

    Parameters
    ----------
    value: str
        Human-readable description of the problem; also the `str()` of the
        exception.
    """

    def __init__(self, value: str):
        # Forward the message to Exception so that `args`, `repr()` and
        # pickling carry it (with an empty `args`, unpickling would call
        # `cls()` and fail for lack of `value`).
        super().__init__(value)
        self.value = value

    def __str__(self) -> str:
        return self.value
|
|
|
|
|
|
class _ForwardFunctionBuilder(_Builder):
    """
    Builder for converting the forward method into MLIR.

    This builder traverses the `torch.fx.Graph` of the
    `py_module`, and translates the operations into MLIR.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.
    mlir_module: ir.Module
        `ir.Module` that will be modified to include the
        MLIR generated by this builder.

    Attributes
    ----------
    func_ip : Optional[ir.InsertionPoint]
        Insertion point for `torch_d.FuncOp` representing the forward method.
    env : Environment
        Used to keep track of the `ir.Value` corresponding to each
        `torch.fx.Node` that has already been handled.

    Methods
    -------
    to_mlir() -> ir.Module
        Insert into `mlir_module` the MLIR produced by the builder.
    """

    def __init__(self, py_module: torch.fx.GraphModule,
                 mlir_module: ir.Module):
        super().__init__(py_module, mlir_module)
        self.func_ip: Optional[ir.InsertionPoint] = None
        self.env: Environment = {}

    def to_mlir(self) -> ir.Module:
        """
        Emit a function `@<class_type_name>.forward` for the graph.

        The function's first argument has the module type; the remaining
        arguments come from the graph's placeholder nodes.

        Returns
        -------
        ir.Module
            `mlir_module` with the forward function appended.

        Raises
        ------
        NotImplementedError
            For node kinds other than 'call_function', 'output' and
            'placeholder', or for an unsupported call target.
        _ForwardFunctionBuilderError
            If the graph is not normalized to acc_ops, or a placeholder
            lacks a `torch_mlir_type`.
        """
        tensor_type = TorchTensorType().to_mlir(self.context)
        module_type = TorchNnModuleType(self.class_type_name
                                        ).to_mlir(self.context)
        # TODO: currently I am assuming that forward always returns a tensor
        func_type = ([module_type] + self._get_arg_types(), [tensor_type])
        with self.context:
            # TODO: Don't hardcode method name
            # TODO: is visibility always private?
            func_op = builtin.FuncOp(f'{self.class_type_name}.forward',
                                     func_type, visibility='private',
                                     loc=self.loc, ip=self.module_ip)

            func_op.add_entry_block()
            self.func_ip = ir.InsertionPoint(func_op.entry_block)
            self._initialize_environment(func_op.entry_block.arguments)

            for node in self.py_module.graph.nodes:
                if node.op == 'call_function':
                    self.env[node] = self._insert_function_call(node)
                elif node.op == 'output':
                    std.ReturnOp([self.env[node_arg] for node_arg in node.args],
                                 loc=self.loc, ip=self.func_ip)
                elif node.op == 'placeholder':
                    # Already bound to block arguments by
                    # `_initialize_environment`.
                    continue
                else:
                    # 'call_module', 'get_attr' and any other node kind are
                    # not supported yet.
                    err = f'Unsupported node.op type: {node.op}'
                    raise NotImplementedError(err)

        return self.mlir_module

    def _initialize_environment(self, arg_list: ir.BlockArgumentList) -> None:
        """
        Bind each placeholder node to its entry-block argument.

        Block arguments whose type equals the module type are skipped, so the
        remaining ones line up, in order, with the graph's placeholder nodes.
        NOTE(review): a non-self argument that also had the module type would
        be skipped as well.
        """
        placeholders = (node for node in self.py_module.graph.nodes
                        if node.op == 'placeholder')

        self_type = TorchNnModuleType(self.class_type_name
                                      ).to_mlir(self.context)
        non_self_args = (arg for arg in arg_list if arg.type != self_type)

        self.env.update(zip(placeholders, non_self_args))

    def _get_arg_types(self) -> List[ir.Type]:
        """
        Return the MLIR types of the graph's placeholder nodes, in order.

        Each placeholder must carry a `torch_mlir_type` entry in its kwargs
        (an object with a `to_mlir(context)` method).

        Raises
        ------
        _ForwardFunctionBuilderError
            If a placeholder has no `torch_mlir_type`.
        """
        types = []
        for node in self.py_module.graph.nodes:
            if node.op != 'placeholder':
                continue
            type_ = node.kwargs.get('torch_mlir_type')
            if type_ is None:
                raise _ForwardFunctionBuilderError(
                    f'Placeholder `{node.name}` has no `torch_mlir_type` '
                    'entry in its kwargs; its MLIR type cannot be determined')
            types.append(type_.to_mlir(self.context))

        return types

    def _insert_function_call(self, f_node: torch.fx.Node) -> ir.OpResult:
        """
        Translate a 'call_function' node with the handler for its target.

        Only keyword arguments that are `torch.fx.Node`s are forwarded to the
        handler (looked up in `env`); other keyword values and positional
        arguments are ignored.

        Raises
        ------
        _ForwardFunctionBuilderError
            If the target is a string, i.e. the graph is not normalized.
        NotImplementedError
            If no handler is registered for the target.
        """
        assert f_node.op == 'call_function'
        args: MutableMapping[str, ir.Value] = {
            name: self.env[arg_node]
            for name, arg_node in f_node.kwargs.items()
            if isinstance(arg_node, torch.fx.Node)
        }

        if isinstance(f_node.target, str):
            err = (f'f_node.target = {f_node.target} must be of type '
                   'Callable[..., Any], not str. Make sure the torch.fx.Graph '
                   'has been normalized to using '
                   'torch.fx.experimental.fx_acc.acc_ops')
            raise _ForwardFunctionBuilderError(err)

        handler = ACC_OP_HANDLERS.get(f_node.target)
        if handler is not None:
            return handler(self, args)

        raise NotImplementedError(f'Unsupported function: {f_node.target}')
|
|
|
|
|
|
# A handler receives the builder and the translated keyword operands of a
# 'call_function' node, and returns the result of the op it emitted.
_AccOpHandler = Callable[
    [_ForwardFunctionBuilder, Mapping[str, ir.Value]],
    ir.OpResult,
]
# Registry type: acc_ops function -> handler.
_AccOpHandlerTable = MutableMapping[Callable[..., Any], _AccOpHandler]

# Global registry consulted by `_ForwardFunctionBuilder`; filled in by the
# `@_add_handler(ACC_OP_HANDLERS, ...)` decorations below.
ACC_OP_HANDLERS: _AccOpHandlerTable = dict()
|
|
|
|
|
|
def _add_handler(table: _AccOpHandlerTable, acc_op: Callable[..., Any]):
    """
    Return a decorator that registers its target as the handler of `acc_op`.

    The decorated function is stored in `table[acc_op]` (replacing any
    previous entry) and returned unchanged.
    """
    def register(handler: _AccOpHandler):
        table[acc_op] = handler
        return handler
    return register
|
|
|
|
|
|
# TODO: these handlers should be meta-programmed
@_add_handler(ACC_OP_HANDLERS, acc_ops.sigmoid)
def _sigmoid_handler(func_builder: _ForwardFunctionBuilder,
                     args: Mapping[str, ir.Value]) -> ir.OpResult:
    """
    Lower `acc_ops.sigmoid` to a `torch_d.AtenSigmoidOp`.

    Expects the operand under the key 'input'. The op is inserted at
    `func_builder.func_ip` and its result is returned.
    """
    operand = args.get('input')
    assert operand is not None, \
        'A call to this handler must include an argument named `input`'
    op = torch_d.AtenSigmoidOp(TorchTensorType().to_mlir(func_builder.context),
                               operand,
                               loc=func_builder.loc,
                               ip=func_builder.func_ip)
    return op.result
|
|
|
|
|
|
@_add_handler(ACC_OP_HANDLERS, acc_ops.tanh)
def _tanh_handler(func_builder: _ForwardFunctionBuilder,
                  args: Mapping[str, ir.Value]) -> ir.OpResult:
    """
    Lower `acc_ops.tanh` to a `torch_d.AtenTanhOp`.

    Expects the operand under the key 'input'. The op is inserted at
    `func_builder.func_ip` and its result is returned.
    """
    operand = args.get('input')
    assert operand is not None, \
        'A call to this handler must include an argument named `input`'
    op = torch_d.AtenTanhOp(TorchTensorType().to_mlir(func_builder.context),
                            operand,
                            loc=func_builder.loc,
                            ip=func_builder.func_ip)
    return op.result
|
|
|
|
|
|
# Named `_add_tensor_handler` (not `_add_handler`) so that defining it does
# not rebind the module-level `_add_handler` decorator factory; otherwise any
# later `@_add_handler(...)` registration would call this function instead.
@_add_handler(ACC_OP_HANDLERS, acc_ops.add)
def _add_tensor_handler(func_builder: _ForwardFunctionBuilder,
                        args: Mapping[str, ir.Value]) -> ir.OpResult:
    """
    Lower `acc_ops.add` to a `torch_d.AtenAddTensorOp`.

    Expects operands under the keys 'input' and 'other'. An integer constant
    1 is emitted and passed as the op's `alpha` operand. All ops are inserted
    at `func_builder.func_ip`; the add op's result is returned.
    """
    input_arg = args.get('input')
    other_arg = args.get('other')
    assert input_arg is not None and other_arg is not None, \
        'A call to this handler must include an argument named `input` ' \
        'and an argument named `other`'
    tensor_type = TorchTensorType().to_mlir(func_builder.context)
    int_type = PythonType(int).to_mlir(func_builder.context)
    int_attr = ir.IntegerAttr.get(int_type, 1)
    alpha_arg = torch_d.ConstantIntOp(int_type,
                                      int_attr,
                                      loc=func_builder.loc,
                                      ip=func_builder.func_ip).result
    result = torch_d.AtenAddTensorOp(tensor_type,
                                     input_arg,
                                     other_arg,
                                     alpha_arg,
                                     loc=func_builder.loc,
                                     ip=func_builder.func_ip).result
    return result
|
|
|
|
|
|
def build_module(py_module: torch.fx.GraphModule) -> ir.Module:
    """
    Translate `py_module` into a new MLIR module in the `torch` dialect.

    A fresh context and empty module are created, the `torch` dialect is
    registered, and the class-declaration builder followed by the
    forward-function builder are run over the same module.

    Parameters
    ----------
    py_module: torch.fx.GraphModule
        GraphModule produced by the `acc_tracer` from
        `torch.fx.experimental.fx_acc`.

    Returns
    -------
    ir.Module
        The MLIR translation of `py_module`.
    """
    with ir.Context():
        module = ir.Module.create(ir.Location.unknown())
        torch_d.register_dialect(module.context)

        # Order matters: the class declaration precedes the forward function.
        for builder_cls in (_ClassDeclAndInitBuilder, _ForwardFunctionBuilder):
            module = builder_cls(py_module, module).to_mlir()
        return module
|