mirror of https://github.com/llvm/torch-mlir
Add `extra_library` kwarg to `torch_mlir.compile` (#1986)
This commit adds the ability to specify extra abstract interpretation functions in `torch_mlir.compile` to use during type refinement. This allows users to easily add custom ops without having to interact with MLIR or C++ directly.pull/1988/head
parent
6bb9965a41
commit
e0f301c890
|
@ -3,11 +3,12 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Optional, Sequence, Union, List, Dict, Tuple
|
||||
from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
from io import StringIO
|
||||
import tempfile
|
||||
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
import torch
|
||||
|
@ -15,6 +16,7 @@ import torch
|
|||
from torch_mlir.passmanager import PassManager
|
||||
from .compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
|
||||
|
||||
|
||||
class OutputType(Enum):
|
||||
|
@ -252,7 +254,7 @@ def compile(model: torch.nn.Module,
|
|||
use_tracing: bool = False,
|
||||
ignore_traced_shapes=False,
|
||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||
_completely_unsupported_in_progress_extra_library: Optional[str] = None,
|
||||
extra_library: Iterable[Callable] = [],
|
||||
verbose: bool = False):
|
||||
"""Convert a PyTorch model to MLIR.
|
||||
|
||||
|
@ -278,12 +280,28 @@ def compile(model: torch.nn.Module,
|
|||
backend_legal_ops: A list of ops that should be considered legal for
|
||||
the backend. An op that is considered legal will not be decomposed.
|
||||
This option is only valid with the `"torch"` output type.
|
||||
extra_library: List of abstract interpretation functions to splice
|
||||
into the abstract interpretation library. See
|
||||
`docs/adding_abstract_interpretation_functions.md` for more info
|
||||
on the format the functions should have.
|
||||
verbose: If true, print extra information about the conversion.
|
||||
|
||||
Returns:
|
||||
An MLIR module that contains the converted model in the specified
|
||||
output type.
|
||||
"""
|
||||
extra_library_file_name = ""
|
||||
if len(extra_library) != 0:
|
||||
extra_library_dict = {}
|
||||
for library_func in extra_library:
|
||||
extra_library_dict[library_func.__name__] = library_func
|
||||
mlir_library = generate_library(extra_library_dict)
|
||||
|
||||
extra_library_file_name = \
|
||||
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
|
||||
with open(extra_library_file_name, "w") as f:
|
||||
f.write(mlir_library)
|
||||
|
||||
output_type = OutputType.get(output_type)
|
||||
example_args = ExampleArgs.get(example_args)
|
||||
if ignore_traced_shapes and not use_tracing:
|
||||
|
@ -368,11 +386,8 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
if output_type == OutputType.RAW:
|
||||
return mb.module
|
||||
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
|
||||
(" extra-library=" + _completely_unsupported_in_progress_extra_library)
|
||||
if (_completely_unsupported_in_progress_extra_library is not None)
|
||||
else ""
|
||||
) + "}"
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
|
||||
" extra-library=" + extra_library_file_name + "}"
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
import inspect
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -138,25 +138,30 @@ def _verify_signature_matches_registry(f, registry: Registry):
|
|||
atoms = function_name.split("〇")
|
||||
if len(atoms) == 2:
|
||||
atoms += [""]
|
||||
operator = registry.get_by_triple(tuple(atoms))
|
||||
try:
|
||||
operator = registry.get_by_triple(tuple(atoms))
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry")
|
||||
if function_kind == "shape":
|
||||
expected_signature = operator.get_shape_function_signature()
|
||||
elif function_kind == "dtype":
|
||||
expected_signature = operator.get_dtype_function_signature()
|
||||
elif function_kind == "decomposition":
|
||||
expected_signature = operator.get_decomposition_function_signature()
|
||||
elif function_kind == "has_value_semantics":
|
||||
expected_signature = operator.get_has_value_semantics_function_signature()
|
||||
else:
|
||||
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
||||
if signature != expected_signature:
|
||||
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
|
||||
|
||||
def generate_library(globals_) -> str:
|
||||
"""Convert all op functions in `globals()` into MLIR."""
|
||||
def generate_library(functions: Dict[str, Any]) -> str:
|
||||
"""Convert all op functions in `functions` into MLIR."""
|
||||
mb = ModuleBuilder()
|
||||
# We use the registry to ensure that the shape functions are consistent
|
||||
# with the ops.
|
||||
registry = Registry.load()
|
||||
for k, v in globals_.items():
|
||||
for k, v in functions.items():
|
||||
if "〇" not in k:
|
||||
continue
|
||||
if not hasattr(v, "_not_present_in_registry"):
|
||||
|
|
|
@ -267,6 +267,23 @@ class JitOperator:
|
|||
return self._get_function_signature(
|
||||
"decomposition", parameter_decl_builder, ret_decl_builder)
|
||||
|
||||
def get_has_value_semantics_function_signature(self):
|
||||
"""Gets the Python function signature for this op's has_value_semantics function.
|
||||
|
||||
While this is technically debug-only output, it is useful to copy-paste
|
||||
it from the debug dump into the library definitions, as many
|
||||
ops have extra default arguments and stuff that are tedious to write out
|
||||
right.
|
||||
"""
|
||||
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
return ""
|
||||
|
||||
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||
return "None"
|
||||
|
||||
return self._get_function_signature(
|
||||
"has_value_semantics", parameter_decl_builder, ret_decl_builder)
|
||||
|
||||
def __repr__(self):
|
||||
f = io.StringIO()
|
||||
emitter = TextEmitter(f)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import tempfile
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
|
@ -18,6 +19,18 @@ goofy_lib = torch.library.Library("goofy", "DEF")
|
|||
goofy_lib.define("identity(Tensor t) -> Tensor")
|
||||
goofy_lib.impl("identity", identity)
|
||||
|
||||
def goofy〇identity〡shape(t: List[int]) -> List[int]:
|
||||
return t
|
||||
|
||||
def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
|
||||
t_rank, t_dtype = t_rank_dtype
|
||||
return t_dtype
|
||||
|
||||
def goofy〇identity〡has_value_semantics() -> None:
|
||||
return
|
||||
|
||||
extra_library = [
|
||||
goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics]
|
||||
|
||||
class CustomOpExampleModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -38,26 +51,12 @@ class CustomOpExampleModule(torch.nn.Module):
|
|||
mod = CustomOpExampleModule()
|
||||
mod.eval()
|
||||
|
||||
abstract_interp_src = """\
|
||||
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||
return %arg0 : !torch.list<int>
|
||||
}
|
||||
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
|
||||
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
|
||||
return %0#1 : !torch.int
|
||||
}
|
||||
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
|
||||
"""
|
||||
|
||||
with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
|
||||
tmp.write(abstract_interp_src)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
mod,
|
||||
torch.ones(3, 4),
|
||||
output_type="torch",
|
||||
backend_legal_ops=["goofy.identity"],
|
||||
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
|
||||
extra_library=extra_library,
|
||||
)
|
||||
|
||||
print(module)
|
||||
|
|
Loading…
Reference in New Issue