torch-mlir/python/torch_mlir/eager_mode/ir_building.py

360 lines
12 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
`torch` dialect.
The expected use of these classes is to create an instance of one of them and then call its
`to_mlir` method to generate the MLIR representation of the type.
Information about which types each class supports can be found in that class's docstring.
In addition, this module defines functions that build a torch.jit.ScriptFunction for a single PyTorch op and
convert it into an MLIR module. The expected entry point is
`build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module`,
which converts the op into MLIR using the `torch` dialect.
"""
import abc
2022-05-26 03:59:59 +08:00
import re
from typing import Any, Optional, Iterable, Dict
from typing import Union
2022-05-26 03:59:59 +08:00
import numpy as np
import torch
2022-05-26 03:59:59 +08:00
import torch._C
import torch.jit
from torch._ops import OpOverload
from torch_mlir import ir
from torch_mlir.dialects.func import FuncOp
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
class TorchMlirType(abc.ABC):
    """
    Abstract base for objects that produce MLIR types in the `torch` dialect.

    The single requirement on a subclass is an implementation of
    `to_mlir(self, context: ir.Context) -> ir.Type`. Subclasses are free to
    accept whatever constructor arguments they need in order to build
    their MLIR representation.
    """

    @abc.abstractmethod
    def to_mlir(self, context: ir.Context) -> ir.Type:
        """Return the MLIR type this object describes, created in `context`."""
class TorchTensorTypeError(Exception):
    """Error raised for an invalid `TorchTensorType` construction (e.g. a shape without a dtype)."""

    def __init__(self, value: str):
        # Forward the message to `Exception.__init__` so that `args` is populated.
        # Without it, `repr()` shows no message and the exception cannot be
        # unpickled (unpickling calls `cls(*self.args)`, which would omit `value`).
        super().__init__(value)
        # Human-readable error message; also returned by `__str__`.
        self.value = value

    def __str__(self) -> str:
        return self.value
class TorchTensorType(TorchMlirType):
    """
    This class is used to generate types of the form
    !torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
    where SHAPE is a list representing the shape of the tensor
    (`None` entries are rendered as `?`, and a missing shape as `*`),
    and DTYPE is an MLIR data type.

    When no dtype is given, the plain `!torch.tensor` type is produced.
    Supported dtypes are the keys of `_dtype_to_asm_dict`.
    """

    # Mapping from PyTorch dtypes to the element-type spelling used inside
    # `!torch.vtensor<...>`. Integer types carry explicit signedness
    # (`si`/`ui`), while `bool` maps to the signless 1-bit integer `i1`.
    # `torch.float`/`torch.int` are aliases of `torch.float32`/`torch.int32`.
    _dtype_to_asm_dict = {
        torch.float16: "f16",
        torch.bfloat16: "bf16",
        torch.float32: "f32",
        torch.float64: "f64",
        torch.int8: "si8",
        torch.uint8: "ui8",
        torch.int16: "si16",
        torch.int32: "si32",
        torch.int64: "si64",
        torch.bool: "i1",
    }

    def __init__(
        self,
        *,
        shape: Optional[Iterable[Optional[int]]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Parameters
        ----------
        shape: Optional[Iterable[Optional[int]]]
            Dimension sizes; `None` entries denote dynamic sizes.
        dtype: Optional[torch.dtype]
            Element type. Required whenever `shape` is given.

        Raises
        ------
        TorchTensorTypeError
            If `shape` is given without a `dtype`.
        """
        self.shape = shape
        self.dtype = dtype
        if dtype is None and shape is not None:
            err = "If shape is specified, dtype must also be specified"
            raise TorchTensorTypeError(err)

    def __str__(self):
        return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"

    def to_mlir(self, context: ir.Context) -> ir.Type:
        """Return `!torch.tensor` if no dtype is set, else the matching `!torch.vtensor`."""
        if self.dtype is None:
            return ir.Type.parse("!torch.tensor", context=context)
        shape_asm = self._shape_to_mlir_asm()
        dtype_asm = self._dtype_to_mlir_asm()
        return ir.Type.parse(
            f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
        )

    def _shape_to_mlir_asm(self) -> str:
        """Render the shape as `[d0,d1,...]` (`?` for dynamic dims) or `*` if unset."""
        if self.shape is None:
            return "*"
        str_sizes = ("?" if size is None else str(size) for size in self.shape)
        return f'[{",".join(str_sizes)}]'

    def _dtype_to_mlir_asm(self) -> str:
        """Return the MLIR element-type spelling for `self.dtype`.

        Raises
        ------
        NotImplementedError
            If the dtype has no entry in `_dtype_to_asm_dict`.
        """
        asm = self._dtype_to_asm_dict.get(self.dtype)
        if asm is None:
            raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
        return asm
class TorchNnModuleType(TorchMlirType):
    """Produces `!torch.nn.Module<"NAME">` types for a given module class name."""

    def __init__(self, module_name: str):
        # Name embedded verbatim between the double quotes of the type.
        self.module_name = module_name

    def __str__(self):
        return "torch.nn.Module"

    def to_mlir(self, context: ir.Context) -> ir.Type:
        """Parse the `!torch.nn.Module` type for `self.module_name` in `context`."""
        type_asm = f'!torch.nn.Module<"{self.module_name}">'
        return ir.Type.parse(type_asm, context=context)
class PythonType(TorchMlirType):
    """
    This class is used to convert regular Python types
    into their corresponding `torch` dialect representation.
    The list of supported types can be found in the dictionary
    `_type_to_asm_dict`; any other type makes `to_mlir` raise
    `NotImplementedError`.
    """

    # Exact-type lookup (no subclass matching), so `bool` is distinct from `int`.
    _type_to_asm_dict = {
        bool: "!torch.bool",
        int: "!torch.int",
        float: "!torch.float",
        str: "!torch.str",
        type(None): "!torch.none",
    }

    def __init__(self, type_: Any):
        self.type_ = type_

    def __str__(self):
        return str(self.type_)

    def to_mlir(self, context: ir.Context) -> ir.Type:
        """Return the `torch` dialect type for `self.type_`.

        Raises
        ------
        NotImplementedError
            If `self.type_` is not a key of `_type_to_asm_dict`.
        """
        asm = self._type_to_asm_dict.get(self.type_)
        if asm is None:
            raise NotImplementedError(f"Unsupported type: {self.type_}")
        return ir.Type.parse(asm, context=context)
# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
class Annotation:
    """
    Ordered collection of argument types for a function.

    Entries that are Python types (instances of `type`) are wrapped in
    `PythonType`; any other entry (e.g. a `TorchTensorType`) is stored as is.
    Iterating an `Annotation` yields the stored entries in order.
    """

    def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
        self.types = [
            PythonType(entry) if isinstance(entry, type) else entry
            for entry in types
        ]

    def __str__(self):
        header = f"Annotation instance with {len(self.types)} types\n"
        body = "".join(
            f" Type of argument {position}: {str(entry)}\n"
            for position, entry in enumerate(self.types, start=1)
        )
        return header + body

    def __iter__(self):
        return iter(self.types)
class AnnotationConverter:
    """Converts an `Annotation` into MLIR function-argument attributes."""

    @staticmethod
    def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
        """
        Build an `ArrayAttr` with one dictionary per annotated argument.

        `TorchTensorType` entries become `{"torch.type_bound": <type>}`; every
        other entry becomes an empty dictionary.
        """

        def _arg_attr(entry) -> ir.DictAttr:
            # Only tensor arguments carry a type bound.
            if isinstance(entry, TorchTensorType):
                bound = ir.TypeAttr.get(entry.to_mlir(context))
                return ir.DictAttr.get({"torch.type_bound": bound})
            return ir.DictAttr.get({})

        with context:
            return ir.ArrayAttr.get([_arg_attr(entry) for entry in annotation])
def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
    """
    Return the first top-level `FuncOp` in `module` whose name equals `name`.

    On a match, the module is also tagged with a `torch.debug_module_name`
    attribute set to `name` so that run_pipeline_with_repro_report can use it.
    Returns None, and leaves the module untouched, when nothing matches.
    """
    with module.context:
        wanted = ir.StringAttr.get(name)
        found = next(
            (
                candidate
                for candidate in module.body.operations
                if isinstance(candidate, FuncOp) and candidate.name == wanted
            ),
            None,
        )
        if found is not None:
            module.operation.attributes["torch.debug_module_name"] = wanted
        return found
2022-05-26 03:59:59 +08:00
def is_tensor_type(typ: torch._C.Type):
    """Return True if `typ` is a tensor type or an `Optional` wrapping a tensor type."""
    tensor_type = torch._C.TensorType.get()
    if typ.isSubtypeOf(tensor_type):
        return True
    if not isinstance(typ, torch._C.OptionalType):
        return False
    return typ.getElementType().isSubtypeOf(tensor_type)
def is_list_of_tensors_type(typ: torch._C.Type):
    """Return True if `typ` is a `List` whose element type passes `is_tensor_type`."""
    if not isinstance(typ, torch.ListType):
        return False
    return is_tensor_type(typ.getElementType())
# Matches every character that is not an ASCII letter or digit; stripping these
# from a printed TorchScript graph yields a string usable as a function name.
name_mangle_regex = re.compile(r"[^a-zA-Z0-9]")
def build_ts_script_function(
    schema: torch._C.FunctionSchema, kwargs: Dict[str, Any]
) -> torch.jit.ScriptFunction:
    """Build a torch.jit.ScriptFunction that corresponds to the schema.

    Constants are inlined for the purposes of invalidating the compile cache when they change.

    Parameters
    ----------
    schema: torch._C.FunctionSchema
        PyTorch's representation for ops, contains type information needed for inlining constants into the TS graph.
    kwargs: Dict
        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
        Elements of a list-of-tensors argument are looked up under keys starting with
        ``f"{arg_name}_flattened"`` (presumably produced by the caller when flattening; verify
        against the dispatch code). The schema argument ``self`` is looked up as ``input``.

    Returns
    -------
    torch.jit.ScriptFunction
        Fully specialized (all constants) TS graph whose only arguments are tensors,
        ordered as in `kwargs`. The function name is the printed graph with all
        non-alphanumeric characters removed, so different inlined constants yield
        different names.
    """
    # Creates empty TS graph.
    graph = torch._C.Graph()
    # Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
    node = graph.insertNode(graph.create(schema.name, len(schema.returns)))

    # Names (kwargs keys) of the graph inputs, in the order they were added.
    graph_inputs = []
    for arg in schema.arguments:
        arg_name = arg.name if arg.name != "self" else "input"
        # If arg is a flattened list of tensors, such as in the case of torch.cat
        # then add each element of the list to the graph corresponding to arg
        # and insert a ListConstruct to function as input to the op.
        if is_list_of_tensors_type(arg.type):
            # Prefix match (not substring match) so keys belonging to another
            # argument whose name merely contains `arg_name` are not captured.
            flattened_prefix = f"{arg_name}_flattened"
            el_typ = arg.type.getElementType()
            if isinstance(el_typ, torch.OptionalType):
                el_typ = el_typ.getElementType()
            inps = []
            for kwarg in kwargs:
                if not kwarg.startswith(flattened_prefix):
                    continue
                inp = graph.addInput()
                inp.setType(el_typ)
                inp.setDebugName(kwarg)
                inps.append(inp)
                graph_inputs.append(kwarg)
            list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
            list_cons.moveBefore(node)
            inp = list_cons.output()
            inp.setType(torch.ListType.ofTensors())
        # If arg is a tensor, then add input to the graph corresponding to arg.
        elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
            inp = graph.addInput()
            if isinstance(arg.type, torch.OptionalType):
                el_typ = arg.type.getElementType()
            else:
                el_typ = arg.type
            inp.setType(el_typ)
            inp.setDebugName(arg_name)
            graph_inputs.append(arg_name)
        # If arg is a constant, inline (at the top of the graph).
        else:
            val = kwargs[arg_name]
            # Explicit type check instead of `val == []`: comparing a tensor or
            # array against [] is elementwise and cannot be used as a bool.
            if isinstance(val, list) and not val:
                # Some ops have empty list default values for args
                # (such as aten::max_pool2d_with_indices with int[2] stride=[]
                # but graph.insertConstant doesn't recognize [] as an empty list IValue.
                # This might be an upstream bug but there doesn't seem to be a way to
                # build a prim::ListConstruct list that's empty.
                val = None
            inp = graph.insertConstant(val)
            inp.node().moveBefore(node)
        node.addInput(inp)

    # Reorder graph inputs to match kwargs. The position lookup is built once
    # rather than once per kwarg.
    input_position = {name: i for i, name in enumerate(graph_inputs)}
    permutes = [input_position[kwarg] for kwarg in kwargs if kwarg in input_position]
    graph.permuteInputs(permutes)

    if node.hasMultipleOutputs():
        for outp in node.outputs():
            graph.registerOutput(outp)
    else:
        graph.registerOutput(node.output())

    # The mangled printed graph doubles as the function name, so a change in any
    # inlined constant produces a distinct function.
    fn_name = name_mangle_regex.sub("", str(graph))
    return torch._C._create_function_from_graph(fn_name, graph)
def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
    """Translate a single op invocation into an MLIR module in the `torch` dialect.

    Parameters
    ----------
    op: OpOverload
        Callable from the torch.ops.aten module/namespace that has a _schema field.
    kwargs: Dict
        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).

    Returns
    -------
    ir.Module
        Translation of the op into an MLIR module. The imported function's
        `arg_attrs` carry a `torch.type_bound` with the shape and dtype of each
        tensor argument.

    Raises
    ------
    AssertionError
        If a tensor argument has a size-0 dimension, or if an internal
        consistency check fails.
    """
    # Tensor shapes with size-0 dimensions, such as those produced in the course of
    # evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic,
    # fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887. They are rejected here with
    # an explicit raise: a bare `assert` is stripped under `python -O`, which would let the
    # native assert abort the process. AssertionError is kept for compatibility with callers.
    annotations = []
    for arg_name, arg in kwargs.items():
        if isinstance(arg, torch.Tensor):
            if 0 in arg.shape:
                raise AssertionError(f"{arg_name} has invalid shape {arg.shape}")
            annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
    annotations = tuple(annotations)

    script_fun = build_ts_script_function(op._schema, kwargs)
    # Internal invariant: one annotation per tensor graph input, in kwargs order.
    assert len(annotations) == len(
        list(script_fun.graph.inputs())
    ), "Number of annotations and number of graph inputs differs."

    mb = ModuleBuilder()
    mb.import_function(script_fun)

    func_op = get_func_op_with_name(mb.module, script_fun.name)
    assert (
        func_op is not None
    ), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"

    func_annotation = Annotation(annotations)
    arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
    func_op.attributes["arg_attrs"] = arg_attrs
    return mb.module