torch-mlir/python/torch_mlir/eager_mode/ir_building.py


# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
`torch` dialect.
The expected use of this module is to create an instance of one of the classes below and then call its
`to_mlir` method to generate the MLIR representation of the type.
Information about which types each class supports can be found in the docstring of that class.
In addition, this module defines the function
`build_module(jit_function: torch.jit.ScriptFunction, annotations) -> ir.Module`
which converts a TorchScript function into an MLIR module using the `torch` dialect.
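A minimal usage sketch (illustrative only; the function `f` and the chosen
shapes/dtypes are examples, not part of this module):

    import torch
    from torch_mlir.eager_mode.ir_building import TorchTensorType, build_module

    @torch.jit.script
    def f(x, y):
        return x + y

    mlir_module = build_module(f, [
        TorchTensorType(shape=[2, 3], dtype=torch.float32),
        TorchTensorType(shape=[2, 3], dtype=torch.float32),
    ])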
"""
import abc
from typing import Any, Optional, Iterable
from typing import Union
import torch
from torch.jit import ScriptFunction
from torch_mlir import ir
from torch_mlir.dialects.builtin import FuncOp
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder


class TorchMlirType(abc.ABC):
"""
A `TorchMlirType` is an object that produces MLIR
types in the `torch` dialect. The only requirement
for a class to be a subclass of `TorchMlirType` is
    to define a `to_mlir(self, context: ir.Context) -> ir.Type` method.
    Each subclass may define whatever `__init__` it needs,
    depending on the information required to produce
    its MLIR representation.
"""
@abc.abstractmethod
def to_mlir(self, context: ir.Context) -> ir.Type:
        pass


class TorchTensorTypeError(Exception):
def __init__(self, value: str):
super().__init__()
self.value = value
def __str__(self) -> str:
        return self.value


class TorchTensorType(TorchMlirType):
"""
This class is used to generate types of the form
!torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
where SHAPE is a list representing the shape of the tensor,
and DTYPE is an MLIR data type.
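    For example (illustrative values; parsing the type requires a context with
    the `torch` dialect registered, such as the one created by `ModuleBuilder`):
        TorchTensorType()                                      -> !torch.tensor
        TorchTensorType(shape=[2, None], dtype=torch.float32)  -> !torch.vtensor<[2,?],f32>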
"""
def __init__(
self,
*,
shape: Optional[Iterable[Optional[int]]] = None,
dtype: Optional[torch.dtype] = None,
):
self.shape = shape
self.dtype = dtype
if dtype is None and shape is not None:
err = "If shape is specified, dtype must also be specified"
raise TorchTensorTypeError(err)
def __str__(self):
return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
def to_mlir(self, context: ir.Context) -> ir.Type:
if self.dtype is None:
return ir.Type.parse("!torch.tensor", context=context)
shape_asm = self._shape_to_mlir_asm()
dtype_asm = self._dtype_to_mlir_asm()
return ir.Type.parse(
f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
)
def _shape_to_mlir_asm(self) -> str:
if self.shape is None:
return "*"
str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
return f'[{",".join(str_sizes)}]'
def _dtype_to_mlir_asm(self) -> str:
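        # torch.float and torch.int are aliases of torch.float32 and torch.int32,
        # so the membership checks below also cover those spellings.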
if self.dtype in [torch.float64]:
return "f64"
if self.dtype in [torch.float, torch.float32]:
return "f32"
if self.dtype in [torch.int, torch.int32]:
return "si32"
if self.dtype in [torch.int64]:
return "si64"
if self.dtype in [torch.bool]:
return "i1"
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
class TorchNnModuleType(TorchMlirType):
"""This class is used to generate types for `!torch.nn.Module`s."""
def __init__(self, module_name: str):
self.module_name = module_name
def __str__(self):
return "torch.nn.Module"
def to_mlir(self, context: ir.Context) -> ir.Type:
        return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)


class PythonType(TorchMlirType):
"""
This class is used to convert regular Python types
into their corresponding `torch` dialect representation.
The list of supported types can be found in the dictionary
`_type_to_asm_dict`.
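    For example (illustrative):
        PythonType(int).to_mlir(ctx)        -> !torch.int
        PythonType(type(None)).to_mlir(ctx) -> !torch.none
    Unsupported types raise NotImplementedError.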
"""
_type_to_asm_dict = {
bool: "!torch.bool",
int: "!torch.int",
type(None): "!torch.none",
}
def __init__(self, type_: Any):
self.type_ = type_
def __str__(self):
return str(self.type_)
def to_mlir(self, context: ir.Context) -> ir.Type:
asm = self._type_to_asm_dict.get(self.type_)
if asm is None:
raise NotImplementedError(f"Unsupported type: {self.type_}")
        return ir.Type.parse(asm, context=context)


# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
class Annotation:
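    """
    Ordered collection of per-operand types for a function. Plain Python types
    (e.g. `int`) are wrapped in `PythonType`; `TorchMlirType` instances are kept
    as-is. Example (illustrative):
        Annotation([TorchTensorType(shape=[2], dtype=torch.float32), int])
    """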
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
self.types = list(
map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
)
def __str__(self):
result = f"Annotation instance with {len(self.types)} types\n"
for e, type_ in enumerate(self.types):
result += f" Type of argument {e + 1}: {str(type_)}\n"
return result
def __iter__(self):
        return iter(self.types)


class AnnotationConverter:
@staticmethod
def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
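        """
        Build one DictAttr per entry in `annotation`: tensor types contribute a
        `torch.type_bound` attribute holding their MLIR type, while every other
        entry contributes an empty DictAttr. The resulting ArrayAttr is used as
        a FuncOp's `arg_attrs` (see `build_module` below).
        """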
dict_attrs = []
for type_ in annotation:
if not isinstance(type_, TorchTensorType):
dict_attrs.append(ir.DictAttr.get({}, context=context))
continue
ir_type = type_.to_mlir(context)
with context:
type_attr = ir.TypeAttr.get(ir_type)
dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
dict_attrs.append(dict_attr)
        return ir.ArrayAttr.get(dict_attrs, context=context)


def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
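    """Return the FuncOp in `module` whose symbol name is `name`, or None if absent."""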
with module.context:
name_attr = ir.StringAttr.get(name)
for op in module.body.operations:
if isinstance(op, FuncOp) and op.name == name_attr:
return op
        return None


def build_module(jit_function: ScriptFunction, annotations) -> ir.Module:
"""Translate input function into an MLIR module in the `torch` dialect.
Parameters
----------
jit_function: ScriptFunction
Function in TorchScript IR to turn into MLIR.
    annotations: Iterable[Union[TorchTensorType, type]]
        Types of the operands of `jit_function`, in the order the operands
        appear; the iterable is wrapped into an `Annotation` internally.

    Returns
    -------
    ir.Module
        MLIR module containing the translation of `jit_function`.
    """
mb = ModuleBuilder()
mb.import_function(jit_function)
func_op = get_func_op_with_name(mb.module, jit_function.name)
assert (
func_op is not None
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
func_annotation = Annotation(annotations)
arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
func_op.attributes["arg_attrs"] = arg_attrs
return mb.module