mirror of https://github.com/llvm/torch-mlir
Add `extra_library` kwarg to `torch_mlir.compile` (#1986)
This commit adds the ability to specify extra abstract interpretation functions in `torch_mlir.compile` to use during type refinement. This allows users to easily add custom ops without having to interact with MLIR or C++ directly.pull/1988/head
parent
6bb9965a41
commit
e0f301c890
|
@ -3,11 +3,12 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union, List, Dict, Tuple
|
from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from torch._functorch.compile_utils import strip_overloads
|
from torch._functorch.compile_utils import strip_overloads
|
||||||
import torch
|
import torch
|
||||||
|
@ -15,6 +16,7 @@ import torch
|
||||||
from torch_mlir.passmanager import PassManager
|
from torch_mlir.passmanager import PassManager
|
||||||
from .compiler_utils import run_pipeline_with_repro_report
|
from .compiler_utils import run_pipeline_with_repro_report
|
||||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library
|
||||||
|
|
||||||
|
|
||||||
class OutputType(Enum):
|
class OutputType(Enum):
|
||||||
|
@ -252,7 +254,7 @@ def compile(model: torch.nn.Module,
|
||||||
use_tracing: bool = False,
|
use_tracing: bool = False,
|
||||||
ignore_traced_shapes=False,
|
ignore_traced_shapes=False,
|
||||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||||
_completely_unsupported_in_progress_extra_library: Optional[str] = None,
|
extra_library: Iterable[Callable] = [],
|
||||||
verbose: bool = False):
|
verbose: bool = False):
|
||||||
"""Convert a PyTorch model to MLIR.
|
"""Convert a PyTorch model to MLIR.
|
||||||
|
|
||||||
|
@ -278,12 +280,28 @@ def compile(model: torch.nn.Module,
|
||||||
backend_legal_ops: A list of ops that should be considered legal for
|
backend_legal_ops: A list of ops that should be considered legal for
|
||||||
the backend. An op that is considered legal will not be decomposed.
|
the backend. An op that is considered legal will not be decomposed.
|
||||||
This option is only valid with the `"torch"` output type.
|
This option is only valid with the `"torch"` output type.
|
||||||
|
extra_library: List of abstract interpretation functions to splice
|
||||||
|
into the abstract interpretation library. See
|
||||||
|
`docs/adding_abstract_interpretation_functions.md` for more info
|
||||||
|
on the format the functions should have.
|
||||||
verbose: If true, print extra information about the conversion.
|
verbose: If true, print extra information about the conversion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An MLIR module that contains the converted model in the specified
|
An MLIR module that contains the converted model in the specified
|
||||||
output type.
|
output type.
|
||||||
"""
|
"""
|
||||||
|
extra_library_file_name = ""
|
||||||
|
if len(extra_library) != 0:
|
||||||
|
extra_library_dict = {}
|
||||||
|
for library_func in extra_library:
|
||||||
|
extra_library_dict[library_func.__name__] = library_func
|
||||||
|
mlir_library = generate_library(extra_library_dict)
|
||||||
|
|
||||||
|
extra_library_file_name = \
|
||||||
|
tempfile.gettempdir() + "/custom_op_extra_library.mlir"
|
||||||
|
with open(extra_library_file_name, "w") as f:
|
||||||
|
f.write(mlir_library)
|
||||||
|
|
||||||
output_type = OutputType.get(output_type)
|
output_type = OutputType.get(output_type)
|
||||||
example_args = ExampleArgs.get(example_args)
|
example_args = ExampleArgs.get(example_args)
|
||||||
if ignore_traced_shapes and not use_tracing:
|
if ignore_traced_shapes and not use_tracing:
|
||||||
|
@ -368,11 +386,8 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
||||||
if output_type == OutputType.RAW:
|
if output_type == OutputType.RAW:
|
||||||
return mb.module
|
return mb.module
|
||||||
|
|
||||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + (
|
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
|
||||||
(" extra-library=" + _completely_unsupported_in_progress_extra_library)
|
" extra-library=" + extra_library_file_name + "}"
|
||||||
if (_completely_unsupported_in_progress_extra_library is not None)
|
|
||||||
else ""
|
|
||||||
) + "}"
|
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mb.module,
|
mb.module,
|
||||||
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union, Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -138,25 +138,30 @@ def _verify_signature_matches_registry(f, registry: Registry):
|
||||||
atoms = function_name.split("〇")
|
atoms = function_name.split("〇")
|
||||||
if len(atoms) == 2:
|
if len(atoms) == 2:
|
||||||
atoms += [""]
|
atoms += [""]
|
||||||
operator = registry.get_by_triple(tuple(atoms))
|
try:
|
||||||
|
operator = registry.get_by_triple(tuple(atoms))
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry")
|
||||||
if function_kind == "shape":
|
if function_kind == "shape":
|
||||||
expected_signature = operator.get_shape_function_signature()
|
expected_signature = operator.get_shape_function_signature()
|
||||||
elif function_kind == "dtype":
|
elif function_kind == "dtype":
|
||||||
expected_signature = operator.get_dtype_function_signature()
|
expected_signature = operator.get_dtype_function_signature()
|
||||||
elif function_kind == "decomposition":
|
elif function_kind == "decomposition":
|
||||||
expected_signature = operator.get_decomposition_function_signature()
|
expected_signature = operator.get_decomposition_function_signature()
|
||||||
|
elif function_kind == "has_value_semantics":
|
||||||
|
expected_signature = operator.get_has_value_semantics_function_signature()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
||||||
if signature != expected_signature:
|
if signature != expected_signature:
|
||||||
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
|
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
|
||||||
|
|
||||||
def generate_library(globals_) -> str:
|
def generate_library(functions: Dict[str, Any]) -> str:
|
||||||
"""Convert all op functions in `globals()` into MLIR."""
|
"""Convert all op functions in `functions` into MLIR."""
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
# We use the registry to ensure that the shape functions are consistent
|
# We use the registry to ensure that the shape functions are consistent
|
||||||
# with the ops.
|
# with the ops.
|
||||||
registry = Registry.load()
|
registry = Registry.load()
|
||||||
for k, v in globals_.items():
|
for k, v in functions.items():
|
||||||
if "〇" not in k:
|
if "〇" not in k:
|
||||||
continue
|
continue
|
||||||
if not hasattr(v, "_not_present_in_registry"):
|
if not hasattr(v, "_not_present_in_registry"):
|
||||||
|
|
|
@ -267,6 +267,23 @@ class JitOperator:
|
||||||
return self._get_function_signature(
|
return self._get_function_signature(
|
||||||
"decomposition", parameter_decl_builder, ret_decl_builder)
|
"decomposition", parameter_decl_builder, ret_decl_builder)
|
||||||
|
|
||||||
|
def get_has_value_semantics_function_signature(self):
|
||||||
|
"""Gets the Python function signature for this op's has_value_semantics function.
|
||||||
|
|
||||||
|
While this is technically debug-only output, it is useful to copy-paste
|
||||||
|
it from the debug dump into the library definitions, as many
|
||||||
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
|
right.
|
||||||
|
"""
|
||||||
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
return self._get_function_signature(
|
||||||
|
"has_value_semantics", parameter_decl_builder, ret_decl_builder)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
f = io.StringIO()
|
f = io.StringIO()
|
||||||
emitter = TextEmitter(f)
|
emitter = TextEmitter(f)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
|
@ -18,6 +19,18 @@ goofy_lib = torch.library.Library("goofy", "DEF")
|
||||||
goofy_lib.define("identity(Tensor t) -> Tensor")
|
goofy_lib.define("identity(Tensor t) -> Tensor")
|
||||||
goofy_lib.impl("identity", identity)
|
goofy_lib.impl("identity", identity)
|
||||||
|
|
||||||
|
def goofy〇identity〡shape(t: List[int]) -> List[int]:
|
||||||
|
return t
|
||||||
|
|
||||||
|
def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int:
|
||||||
|
t_rank, t_dtype = t_rank_dtype
|
||||||
|
return t_dtype
|
||||||
|
|
||||||
|
def goofy〇identity〡has_value_semantics() -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
extra_library = [
|
||||||
|
goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics]
|
||||||
|
|
||||||
class CustomOpExampleModule(torch.nn.Module):
|
class CustomOpExampleModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -38,26 +51,12 @@ class CustomOpExampleModule(torch.nn.Module):
|
||||||
mod = CustomOpExampleModule()
|
mod = CustomOpExampleModule()
|
||||||
mod.eval()
|
mod.eval()
|
||||||
|
|
||||||
abstract_interp_src = """\
|
|
||||||
func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list<int>) -> !torch.list<int> {
|
|
||||||
return %arg0 : !torch.list<int>
|
|
||||||
}
|
|
||||||
func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple<int, int>) -> !torch.int {
|
|
||||||
%0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int
|
|
||||||
return %0#1 : !torch.int
|
|
||||||
}
|
|
||||||
func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return }
|
|
||||||
"""
|
|
||||||
|
|
||||||
with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp:
|
|
||||||
tmp.write(abstract_interp_src)
|
|
||||||
|
|
||||||
module = torch_mlir.compile(
|
module = torch_mlir.compile(
|
||||||
mod,
|
mod,
|
||||||
torch.ones(3, 4),
|
torch.ones(3, 4),
|
||||||
output_type="torch",
|
output_type="torch",
|
||||||
backend_legal_ops=["goofy.identity"],
|
backend_legal_ops=["goofy.identity"],
|
||||||
_completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir",
|
extra_library=extra_library,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(module)
|
print(module)
|
||||||
|
|
Loading…
Reference in New Issue