mirror of https://github.com/llvm/torch-mlir
65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
"""
|
|
Translator from torch.jit.ScriptFunction to MLIR.
|
|
|
|
The following defines a function that takes a torch.jit.ScriptFunction
|
|
and converts it into an MLIR module.
|
|
|
|
The expected use for this module is to use the function
|
|
`build_module(jit_function: torch.jit.ScriptFunction,
|
|
annotation: Annotation) -> ir.Module`
|
|
to convert the TorchScript function into MLIR using the `torch`
|
|
dialect.
|
|
"""
|
|
|
|
from typing import Optional
|
|
|
|
from torch.jit import ScriptFunction
|
|
|
|
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
|
from torch_mlir.dialects.builtin import FuncOp
|
|
from torch_mlir import ir
|
|
|
|
from utils.annotator import AnnotationConverter as ac
|
|
from utils.annotator import Annotation
|
|
|
|
def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
    """
    Look up a function operation in `module` by its name.

    Parameters
    ----------
    module: ir.Module
        Module whose body operations are searched.
    name: str
        Name the returned `FuncOp` must carry.

    Returns
    -------
    Optional[FuncOp]
        The first `FuncOp` in the module body whose name attribute equals
        `name`, or None when no operation matches.
    """
    # The string attribute used for comparison is created under the
    # module's own context.
    with module.context:
        wanted = ir.StringAttr.get(name)
        matching_funcs = (
            candidate
            for candidate in module.body.operations
            if isinstance(candidate, FuncOp) and candidate.name == wanted
        )
        # Stops at the first hit; falls back to None if nothing matches.
        return next(matching_funcs, None)
|
|
|
|
def build_module(jit_function: ScriptFunction,
                 annotation: Annotation) -> ir.Module:
    """
    Import a TorchScript function into a new MLIR module (`torch` dialect)
    and attach the argument annotations to the imported function.

    Parameters
    ----------
    jit_function: ScriptFunction
        TorchScript function to import.
    annotation: Annotation
        Annotation object describing the types of the operands of
        `jit_function`; it is converted to an MLIR array attribute and
        stored on the imported function as `arg_attrs`.

    Returns
    -------
    ir.Module
        The module owned by the builder, containing the imported function.

    Raises
    ------
    AssertionError
        If no `FuncOp` named `jit_function.name` is found in the module
        after the import.
    """
    builder = ModuleBuilder()
    builder.import_function(jit_function)

    # Locate the operation that the import step just produced so the
    # annotation can be attached to it.
    imported_func = _get_func_op_with_name(builder.module, jit_function.name)
    assert imported_func is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder'

    imported_func.attributes['arg_attrs'] = ac.to_mlir_array_attr(
        annotation, builder.context)

    return builder.module
|