Add `extra_library` kwarg to `torch_mlir.compile` (#1986)

This commit adds the ability to specify extra abstract interpretation
functions in `torch_mlir.compile` to use during type refinement. This
allows users to easily add custom ops without having to interact with
MLIR or C++ directly.
pull/1988/head
Ramiro Leal-Cavazos 2023-03-30 09:20:19 -07:00 committed by GitHub
parent 6bb9965a41
commit e0f301c890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 27 deletions

View File

@ -3,11 +3,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Optional, Sequence, Union, List, Dict, Tuple
from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable
from enum import Enum
import sys
from io import StringIO
import tempfile
from torch._functorch.compile_utils import strip_overloads
import torch
@ -15,6 +16,7 @@ import torch
from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
class OutputType(Enum):
@ -252,7 +254,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False,
ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None,
_completely_unsupported_in_progress_extra_library: Optional[str] = None,
extra_library: Iterable[Callable] = [],
verbose: bool = False):
"""Convert a PyTorch model to MLIR.
@ -278,12 +280,28 @@ def compile(model: torch.nn.Module,
backend_legal_ops: A list of ops that should be considered legal for
the backend. An op that is considered legal will not be decomposed.
This option is only valid with the `"torch"` output type.
extra_library: List of abstract interpretation functions to splice
into the abstract interpretation library. See
`docs/adding_abstract_interpretation_functions.md` for more info
on the format the functions should have.
verbose: If true, print extra information about the conversion.
Returns:
An MLIR module that contains the converted model in the specified
output type.
"""
extra_library_file_name = ""
if len(extra_library) != 0:
extra_library_dict = {}
for library_func in extra_library:
extra_library_dict[library_func.__name__] = library_func
mlir_library = generate_library(extra_library_dict)
extra_library_file_name = \
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
with open(extra_library_file_name, "w") as f:
f.write(mlir_library)
output_type = OutputType.get(output_type)
example_args = ExampleArgs.get(example_args)
if ignore_traced_shapes and not use_tracing:
@ -368,11 +386,8 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
if output_type == OutputType.RAW:
return mb.module
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
(" extra-library=" + _completely_unsupported_in_progress_extra_library)
if (_completely_unsupported_in_progress_extra_library is not None)
else ""
) + "}"
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
" extra-library=" + extra_library_file_name + "}"
run_pipeline_with_repro_report(
mb.module,
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",

View File

@ -5,7 +5,7 @@
import inspect
import re
from typing import List, Optional, Union
from typing import List, Optional, Union, Any, Dict
import torch
@ -138,25 +138,30 @@ def _verify_signature_matches_registry(f, registry: Registry):
atoms = function_name.split("")
if len(atoms) == 2:
atoms += [""]
operator = registry.get_by_triple(tuple(atoms))
try:
operator = registry.get_by_triple(tuple(atoms))
except KeyError as e:
raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry")
if function_kind == "shape":
expected_signature = operator.get_shape_function_signature()
elif function_kind == "dtype":
expected_signature = operator.get_dtype_function_signature()
elif function_kind == "decomposition":
expected_signature = operator.get_decomposition_function_signature()
elif function_kind == "has_value_semantics":
expected_signature = operator.get_has_value_semantics_function_signature()
else:
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
if signature != expected_signature:
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
def generate_library(globals_) -> str:
"""Convert all op functions in `globals()` into MLIR."""
def generate_library(functions: Dict[str, Any]) -> str:
"""Convert all op functions in `functions` into MLIR."""
mb = ModuleBuilder()
# We use the registry to ensure that the shape functions are consistent
# with the ops.
registry = Registry.load()
for k, v in globals_.items():
for k, v in functions.items():
if "" not in k:
continue
if not hasattr(v, "_not_present_in_registry"):

View File

@ -267,6 +267,23 @@ class JitOperator:
return self._get_function_signature(
"decomposition", parameter_decl_builder, ret_decl_builder)
def get_has_value_semantics_function_signature(self):
"""Gets the Python function signature for this op's has_value_semantics function.
While this is technically debug-only output, it is useful to copy-paste
it from the debug dump into the library definitions, as many
ops have extra default arguments and stuff that are tedious to write out
right.
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return ""
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return "None"
return self._get_function_signature(
"has_value_semantics", parameter_decl_builder, ret_decl_builder)
def __repr__(self):
f = io.StringIO()
emitter = TextEmitter(f)

View File

@ -1,5 +1,6 @@
import os
import tempfile
from typing import List, Tuple
import torch
import torch.utils.cpp_extension
@ -18,6 +19,18 @@ goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity)
def goofyidentity〡shape(t: List[int]) -> List[int]:
return t
def goofyidentity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
t_rank, t_dtype = t_rank_dtype
return t_dtype
def goofyidentity〡has_value_semantics() -> None:
return
extra_library = [
goofyidentity〡shape, goofyidentity〡dtype, goofyidentity〡has_value_semantics]
class CustomOpExampleModule(torch.nn.Module):
def __init__(self):
@ -38,26 +51,12 @@ class CustomOpExampleModule(torch.nn.Module):
mod = CustomOpExampleModule()
mod.eval()
abstract_interp_src = """\
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
return %0#1 : !torch.int
}
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
"""
with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
tmp.write(abstract_interp_src)
module = torch_mlir.compile(
mod,
torch.ones(3, 4),
output_type="torch",
backend_legal_ops=["goofy.identity"],
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
extra_library=extra_library,
)
print(module)