Add `extra_library` kwarg to `torch_mlir.compile` (#1986)

This commit adds the ability to specify extra abstract interpretation
functions in `torch_mlir.compile` to use during type refinement. This
allows users to easily add custom ops without having to interact with
MLIR or C++ directly.
pull/1988/head
Ramiro Leal-Cavazos 2023-03-30 09:20:19 -07:00 committed by GitHub
parent 6bb9965a41
commit e0f301c890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 27 deletions

View File

@ -3,11 +3,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE. # Also available under a BSD-style license. See LICENSE.
from typing import Optional, Sequence, Union, List, Dict, Tuple from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable
from enum import Enum from enum import Enum
import sys import sys
from io import StringIO from io import StringIO
import tempfile
from torch._functorch.compile_utils import strip_overloads from torch._functorch.compile_utils import strip_overloads
import torch import torch
@ -15,6 +16,7 @@ import torch
from torch_mlir.passmanager import PassManager from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
class OutputType(Enum): class OutputType(Enum):
@ -252,7 +254,7 @@ def compile(model: torch.nn.Module,
use_tracing: bool = False, use_tracing: bool = False,
ignore_traced_shapes=False, ignore_traced_shapes=False,
backend_legal_ops: Optional[Sequence[str]] = None, backend_legal_ops: Optional[Sequence[str]] = None,
_completely_unsupported_in_progress_extra_library: Optional[str] = None, extra_library: Iterable[Callable] = [],
verbose: bool = False): verbose: bool = False):
"""Convert a PyTorch model to MLIR. """Convert a PyTorch model to MLIR.
@ -278,12 +280,28 @@ def compile(model: torch.nn.Module,
backend_legal_ops: A list of ops that should be considered legal for backend_legal_ops: A list of ops that should be considered legal for
the backend. An op that is considered legal will not be decomposed. the backend. An op that is considered legal will not be decomposed.
This option is only valid with the `"torch"` output type. This option is only valid with the `"torch"` output type.
extra_library: List of abstract interpretation functions to splice
into the abstract interpretation library. See
`docs/adding_abstract_interpretation_functions.md` for more info
on the format the functions should have.
verbose: If true, print extra information about the conversion. verbose: If true, print extra information about the conversion.
Returns: Returns:
An MLIR module that contains the converted model in the specified An MLIR module that contains the converted model in the specified
output type. output type.
""" """
extra_library_file_name = ""
if len(extra_library) != 0:
extra_library_dict = {}
for library_func in extra_library:
extra_library_dict[library_func.__name__] = library_func
mlir_library = generate_library(extra_library_dict)
extra_library_file_name = \
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
with open(extra_library_file_name, "w") as f:
f.write(mlir_library)
output_type = OutputType.get(output_type) output_type = OutputType.get(output_type)
example_args = ExampleArgs.get(example_args) example_args = ExampleArgs.get(example_args)
if ignore_traced_shapes and not use_tracing: if ignore_traced_shapes and not use_tracing:
@ -368,11 +386,8 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
if output_type == OutputType.RAW: if output_type == OutputType.RAW:
return mb.module return mb.module
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + ( option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
(" extra-library=" + _completely_unsupported_in_progress_extra_library) " extra-library=" + extra_library_file_name + "}"
if (_completely_unsupported_in_progress_extra_library is not None)
else ""
) + "}"
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
mb.module, mb.module,
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",

View File

@ -5,7 +5,7 @@
import inspect import inspect
import re import re
from typing import List, Optional, Union from typing import List, Optional, Union, Any, Dict
import torch import torch
@ -138,25 +138,30 @@ def _verify_signature_matches_registry(f, registry: Registry):
atoms = function_name.split("") atoms = function_name.split("")
if len(atoms) == 2: if len(atoms) == 2:
atoms += [""] atoms += [""]
try:
operator = registry.get_by_triple(tuple(atoms)) operator = registry.get_by_triple(tuple(atoms))
except KeyError as e:
raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry")
if function_kind == "shape": if function_kind == "shape":
expected_signature = operator.get_shape_function_signature() expected_signature = operator.get_shape_function_signature()
elif function_kind == "dtype": elif function_kind == "dtype":
expected_signature = operator.get_dtype_function_signature() expected_signature = operator.get_dtype_function_signature()
elif function_kind == "decomposition": elif function_kind == "decomposition":
expected_signature = operator.get_decomposition_function_signature() expected_signature = operator.get_decomposition_function_signature()
elif function_kind == "has_value_semantics":
expected_signature = operator.get_has_value_semantics_function_signature()
else: else:
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'") raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
if signature != expected_signature: if signature != expected_signature:
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}") raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
def generate_library(globals_) -> str: def generate_library(functions: Dict[str, Any]) -> str:
"""Convert all op functions in `globals()` into MLIR.""" """Convert all op functions in `functions` into MLIR."""
mb = ModuleBuilder() mb = ModuleBuilder()
# We use the registry to ensure that the shape functions are consistent # We use the registry to ensure that the shape functions are consistent
# with the ops. # with the ops.
registry = Registry.load() registry = Registry.load()
for k, v in globals_.items(): for k, v in functions.items():
if "" not in k: if "" not in k:
continue continue
if not hasattr(v, "_not_present_in_registry"): if not hasattr(v, "_not_present_in_registry"):

View File

@ -267,6 +267,23 @@ class JitOperator:
return self._get_function_signature( return self._get_function_signature(
"decomposition", parameter_decl_builder, ret_decl_builder) "decomposition", parameter_decl_builder, ret_decl_builder)
def get_has_value_semantics_function_signature(self):
"""Gets the Python function signature for this op's has_value_semantics function.
While this is technically debug-only output, it is useful to copy-paste
it from the debug dump into the library definitions, as many
ops have extra default arguments and stuff that are tedious to write out
right.
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return ""
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return "None"
return self._get_function_signature(
"has_value_semantics", parameter_decl_builder, ret_decl_builder)
def __repr__(self): def __repr__(self):
f = io.StringIO() f = io.StringIO()
emitter = TextEmitter(f) emitter = TextEmitter(f)

View File

@ -1,5 +1,6 @@
import os import os
import tempfile import tempfile
from typing import List, Tuple
import torch import torch
import torch.utils.cpp_extension import torch.utils.cpp_extension
@ -18,6 +19,18 @@ goofy_lib = torch.library.Library("goofy", "DEF")
goofy_lib.define("identity(Tensor t) -> Tensor") goofy_lib.define("identity(Tensor t) -> Tensor")
goofy_lib.impl("identity", identity) goofy_lib.impl("identity", identity)
def goofyidentity〡shape(t: List[int]) -> List[int]:
return t
def goofyidentity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
t_rank, t_dtype = t_rank_dtype
return t_dtype
def goofyidentity〡has_value_semantics() -> None:
return
extra_library = [
goofyidentity〡shape, goofyidentity〡dtype, goofyidentity〡has_value_semantics]
class CustomOpExampleModule(torch.nn.Module): class CustomOpExampleModule(torch.nn.Module):
def __init__(self): def __init__(self):
@ -38,26 +51,12 @@ class CustomOpExampleModule(torch.nn.Module):
mod = CustomOpExampleModule() mod = CustomOpExampleModule()
mod.eval() mod.eval()
abstract_interp_src = """\
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
return %0#1 : !torch.int
}
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
"""
with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
tmp.write(abstract_interp_src)
module = torch_mlir.compile( module = torch_mlir.compile(
mod, mod,
torch.ones(3, 4), torch.ones(3, 4),
output_type="torch", output_type="torch",
backend_legal_ops=["goofy.identity"], backend_legal_ops=["goofy.identity"],
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir", extra_library=extra_library,
) )
print(module) print(module)