From e0f301c8909caca4e11983445399421206e21b9b Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Thu, 30 Mar 2023 09:20:19 -0700 Subject: [PATCH] Add `extra_library` kwarg to `torch_mlir.compile` (#1986) This commit adds the ability to specify extra abstract interpretation functions in `torch_mlir.compile` to use during type refinement. This allows users to easily add custom ops without having to interact with MLIR or C++ directly. --- python/torch_mlir/__init__.py | 29 ++++++++++++++----- .../jit_ir/build_tools/library_generator.py | 15 ++++++---- .../importer/jit_ir/build_tools/registry.py | 17 +++++++++++ test/python/custom_op_shape_dtype_fn.py | 29 +++++++++---------- 4 files changed, 63 insertions(+), 27 deletions(-) diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index a4a47e157..2d8b9e882 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -3,11 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import Optional, Sequence, Union, List, Dict, Tuple +from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable from enum import Enum import sys from io import StringIO +import tempfile from torch._functorch.compile_utils import strip_overloads import torch @@ -15,6 +16,7 @@ import torch from torch_mlir.passmanager import PassManager from .compiler_utils import run_pipeline_with_repro_report from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library class OutputType(Enum): @@ -252,7 +254,7 @@ def compile(model: torch.nn.Module, use_tracing: bool = False, ignore_traced_shapes=False, backend_legal_ops: Optional[Sequence[str]] = None, - _completely_unsupported_in_progress_extra_library: Optional[str] = None, + extra_library: Iterable[Callable] = [], verbose: bool = False): """Convert a PyTorch model to MLIR. @@ -278,12 +280,28 @@ def compile(model: torch.nn.Module, backend_legal_ops: A list of ops that should be considered legal for the backend. An op that is considered legal will not be decomposed. This option is only valid with the `"torch"` output type. + extra_library: List of abstract interpretation functions to splice + into the abstract interpretation library. See + `docs/adding_abstract_interpretation_functions.md` for more info + on the format the functions should have. verbose: If true, print extra information about the conversion. Returns: An MLIR module that contains the converted model in the specified output type. """ + extra_library_file_name = "" + if len(extra_library) != 0: + extra_library_dict = {} + for library_func in extra_library: + extra_library_dict[library_func.__name__] = library_func + mlir_library = generate_library(extra_library_dict) + + extra_library_file_name = \ + tempfile.gettempdir() + "/custom_op_extra_library.mlir" + with open(extra_library_file_name, "w") as f: + f.write(mlir_library) + output_type = OutputType.get(output_type) example_args = ExampleArgs.get(example_args) if ignore_traced_shapes and not use_tracing: @@ -368,11 +386,8 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: if output_type == OutputType.RAW: return mb.module - option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + ( - (" extra-library=" + _completely_unsupported_in_progress_extra_library) - if (_completely_unsupported_in_progress_extra_library is not None) - else "" - ) + "}" + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \ + " extra-library=" + extra_library_file_name + "}" run_pipeline_with_repro_report( mb.module, f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 6bd19fb2b..f87a7019d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -5,7 +5,7 @@ import inspect import re -from typing import List, Optional, Union +from typing import List, Optional, Union, Any, Dict import torch @@ -138,25 +138,30 @@ def _verify_signature_matches_registry(f, registry: Registry): atoms = function_name.split("〇") if len(atoms) == 2: atoms += [""] - operator = registry.get_by_triple(tuple(atoms)) + try: + operator = registry.get_by_triple(tuple(atoms)) + except KeyError as e: + raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry") if function_kind == "shape": expected_signature = operator.get_shape_function_signature() elif function_kind == "dtype": expected_signature = operator.get_dtype_function_signature() elif function_kind == "decomposition": expected_signature = operator.get_decomposition_function_signature() + elif function_kind == "has_value_semantics": + expected_signature = operator.get_has_value_semantics_function_signature() else: raise ValueError(f"Invalid Op signature function kind: '{function_kind}'") if signature != expected_signature: raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}") -def generate_library(globals_) -> str: - """Convert all op functions in `globals()` into MLIR.""" +def generate_library(functions: Dict[str, Any]) -> str: + """Convert all op functions in `functions` into MLIR.""" mb = ModuleBuilder() # We use the registry to ensure that the shape functions are consistent # with the ops. registry = Registry.load() - for k, v in globals_.items(): + for k, v in functions.items(): if "〇" not in k: continue if not hasattr(v, "_not_present_in_registry"): diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 550b47802..0396df1a0 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -267,6 +267,23 @@ class JitOperator: return self._get_function_signature( "decomposition", parameter_decl_builder, ret_decl_builder) + def get_has_value_semantics_function_signature(self): + """Gets the Python function signature for this op's has_value_semantics function. + + While this is technically debug-only output, it is useful to copy-paste + it from the debug dump into the library definitions, as many + ops have extra default arguments and stuff that are tedious to write out + right. + """ + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + return "" + + def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + return "None" + + return self._get_function_signature( + "has_value_semantics", parameter_decl_builder, ret_decl_builder) + def __repr__(self): f = io.StringIO() emitter = TextEmitter(f) diff --git a/test/python/custom_op_shape_dtype_fn.py b/test/python/custom_op_shape_dtype_fn.py index b71a649d9..d955ec7a2 100644 --- a/test/python/custom_op_shape_dtype_fn.py +++ b/test/python/custom_op_shape_dtype_fn.py @@ -1,5 +1,6 @@ import os import tempfile +from typing import List, Tuple import torch import torch.utils.cpp_extension @@ -18,6 +19,18 @@ goofy_lib = torch.library.Library("goofy", "DEF") goofy_lib.define("identity(Tensor t) -> Tensor") goofy_lib.impl("identity", identity) +def goofy〇identity〡shape(t: List[int]) -> List[int]: + return t + +def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int: + t_rank, t_dtype = t_rank_dtype + return t_dtype + +def goofy〇identity〡has_value_semantics() -> None: + return + +extra_library = [ + goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics] class CustomOpExampleModule(torch.nn.Module): def __init__(self): @@ -38,26 +51,12 @@ class CustomOpExampleModule(torch.nn.Module): mod = CustomOpExampleModule() mod.eval() -abstract_interp_src = """\ -func.func @__torch_mlir_shape_fn.goofy.identity(%arg0: !torch.list) -> !torch.list { - return %arg0 : !torch.list -} -func.func @__torch_mlir_dtype_fn.goofy.identity(%arg0 : !torch.tuple) -> !torch.int { - %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int - return %0#1 : !torch.int -} -func.func @__torch_mlir_has_value_semantics_fn.goofy.identity() { return } -""" - -with open("/tmp/custom_op_shape_dtype_fn.mlir", "w") as tmp: - tmp.write(abstract_interp_src) - module = torch_mlir.compile( mod, torch.ones(3, 4), output_type="torch", backend_legal_ops=["goofy.identity"], - _completely_unsupported_in_progress_extra_library="/tmp/custom_op_shape_dtype_fn.mlir", + extra_library=extra_library, ) print(module)