[torch_mlir.compile] Add support for multiple exported methods

For AoT deployments, models often have multiple exported methods.
This patch enables compiling such models, for example:

```
class TwoMethodsModule(torch.nn.Module):
    def sin(self, x):
        return torch.ops.aten.sin(x)

    def cos(self, x):
        return torch.ops.aten.cos(x)

example_args = torch_mlir.ExampleArgs()
example_args.add_method("sin", torch.ones(2, 3))
example_args.add_method("cos", torch.ones(2, 4))
print(torch_mlir.compile(TwoMethodsModule(), example_args))
```
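
The existing single-method path is unchanged: a bare tensor (or sequence of
tensors) is still accepted and is normalized into an `ExampleArgs` entry for
`forward` via `ExampleArgs.get`. A minimal sketch of that equivalence (the
module name and shapes here are illustrative, not part of the patch):

```
import torch
import torch_mlir

class ForwardOnlyModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.sin(x)

# These two calls are equivalent: the bare tensor is wrapped into an
# ExampleArgs with a single "forward" entry internally.
print(torch_mlir.compile(ForwardOnlyModule(), torch.ones(2, 3)))
print(torch_mlir.compile(
    ForwardOnlyModule(),
    torch_mlir.ExampleArgs().add_method("forward", torch.ones(2, 3))))
```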

In the
[long-term](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#tools-for-advanced-aot-deployments),
we will need to reconcile this with our story for stateful models and with the
backend contract being purely functional. For now, this provides some basic
infrastructure that seems harmless. Arguably, we could tighten the backend
contract even further to allow only a single compiled function, which would
prohibit this feature or require building a layer above it.

Fixes #1557
Branch: pull/1569/head
Author: Sean Silva, 2022-11-09 11:15:54 +00:00
Parent: 2793a2bd41
Commit: 64914603fa
3 changed files with 176 additions and 68 deletions

File 1 of 3 (new test file):

@@ -0,0 +1,35 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch
import torch_mlir


class TwoMethodsModule(torch.nn.Module):
    def sin(self, x):
        return torch.ops.aten.sin(x)

    def cos(self, x):
        return torch.ops.aten.cos(x)


example_args = torch_mlir.ExampleArgs()
example_args.add_method("sin", torch.ones(2, 3))
example_args.add_method("cos", torch.ones(2, 4))

# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
# check the `use_tracing` case first.
print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
# CHECK: module
# CHECK-DAG: func.func @sin
# CHECK-DAG: func.func @cos

print(torch_mlir.compile(TwoMethodsModule(), example_args))
# CHECK: module
# CHECK-DAG: func.func @sin
# CHECK-DAG: func.func @cos

File 2 of 3 (existing test exercising the example-args error messages):

@@ -59,14 +59,14 @@ class DictModule(torch.nn.Module):
 try:
-    # CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
+    # CHECK: Only Tensors, TensorPlaceholder's, or sequences of Tensors and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
     torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
 except Exception as e:
     print(e)

 try:
-    # CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
+    # CHECK: Only Tensors, TensorPlaceholder's, or sequences of Tensors and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
     torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
 except Exception as e:
     print(e)

File 3 of 3 (the `torch_mlir.compile` Python API):

@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.

-from typing import Optional, Sequence, Union, List
+from typing import Optional, Sequence, Union, List, Dict, Tuple
 from enum import Enum

 import sys
@@ -70,7 +70,6 @@ class OutputType(Enum):
         return OutputType[spec]

 class TensorPlaceholder:
     """A class that represents a formal parameter of a given shape and dtype.
@@ -119,6 +118,120 @@ class TensorPlaceholder:
         return TensorPlaceholder(shape, tensor.dtype)


+_example_arg = Union[TensorPlaceholder, torch.Tensor]
+_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
+_example_args = Union[_example_args_for_one_method, "ExampleArgs"]
+
+
+class ExampleArgs:
+    """A class representing the example arguments to an nn.Module.
+
+    In general, an nn.Module may have multiple methods that need to be compiled.
+    This requires example args for each method. This class is a lightweight
+    wrapper around a dictionary that maps method names to example arguments.
+
+    In user-facing APIs, this type can always be passed interchangeably with a
+    single arg or list of args, which is normalized to an ExampleArgs for just
+    the `forward` method via the `ExampleArgs.get` static method.
+    """
+
+    def __init__(self):
+        self._example_args = {}
+
+    def add_method(self, method_name: str, example_args: _example_args_for_one_method):
+        """Adds example args for a method.
+
+        Args:
+            method_name: The name of the method. Must not have already been
+                added.
+            example_args: The example args for the method.
+        Returns:
+            self, for chaining.
+        """
+        assert method_name not in self._example_args
+        self._example_args[method_name] = ExampleArgs._canonicalize_args(
+            example_args)
+        return self
+
+    @staticmethod
+    def get(example_args: _example_args) -> "ExampleArgs":
+        """Gets an ExampleArgs from one of the permissible ways to specify one.
+
+        Args:
+            example_args: An ExampleArgs instance or a single arg or list of args.
+        Returns:
+            An ExampleArgs instance.
+        """
+        if isinstance(example_args, ExampleArgs):
+            return example_args
+        return ExampleArgs().add_method("forward", example_args)
+
+    @staticmethod
+    def _canonicalize_args(example_args: _example_args_for_one_method):
+        """Canonicalize the args for one method into a tuple."""
+        if not isinstance(example_args, Sequence):
+            example_args = [example_args]
+        for arg in example_args:
+            if not isinstance(arg, (TensorPlaceholder, torch.Tensor)):
+                raise Exception(f"Only Tensors, TensorPlaceholder's, or sequences of "
+                                f"Tensors and TensorPlaceholder's are supported as "
+                                f"example args for method inputs. "
+                                f"Got '{arg}'.")
+        return tuple(example_args)
+
+    def _get_methods(self):
+        return self._example_args.keys()
+
+    def _get_for_annotation(self):
+        result = {}
+        for method_name, example_args in self._example_args.items():
+            placeholders = []
+            for arg in example_args:
+                if isinstance(arg, TensorPlaceholder):
+                    placeholders.append(arg)
+                else:
+                    assert isinstance(arg, torch.Tensor)
+                    placeholders.append(TensorPlaceholder.like(arg))
+            result[method_name] = placeholders
+        return result
+
+    def _get_for_tracing(
+        self,
+        use_tracing: bool,
+        ignore_traced_shapes: bool,
+    ) -> Dict[str, Tuple[_example_arg, ...]]:
+        result = {}
+        for method_name, example_args in self._example_args.items():
+            # If we are tracing, then we need to convert any placeholders into
+            # concrete values.
+            if use_tracing:
+                example_args_for_trace = []
+                for arg in example_args:
+                    if isinstance(arg, TensorPlaceholder):
+                        if not ignore_traced_shapes:
+                            # To avoid accidental footguns, we require
+                            # `ignore_traced_shapes` to be true if we're using
+                            # TensorPlaceholder's, as it falls into the same
+                            # "hopefully the trace works for different inputs"
+                            # bucket of concerns.
+                            raise Exception(
+                                "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`")
+                        # For any dynamic dimensions, replace them with "7"
+                        # arbitrarily. If a user is using dynamic dimensions with
+                        # tracing, they are walking on thin ice already -- assume
+                        # they know what they are doing and that their trace is
+                        # correct for any specific concrete size.
+                        shape = [s if s != -1 else 7 for s in arg.shape]
+                        example_args_for_trace.append(
+                            torch.ones(*shape, dtype=arg.dtype))
+                    else:
+                        assert isinstance(arg, torch.Tensor)
+                        example_args_for_trace.append(arg)
+                example_args = tuple(example_args_for_trace)
+            result[method_name] = example_args
+        return result
+
+
 # The set of ops that are considered legal for each backend.
 # These are currently quite load-bearing, since different backends might be
 # missing patterns for decomposed forms of certain ops.
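
As a usage note for the class added above: `add_method` returns `self` so
registrations can be chained, `ExampleArgs.get` accepts either an existing
instance or a bare arg, and a `TensorPlaceholder` with a dynamic (`-1`)
dimension is only usable on the tracing path together with
`ignore_traced_shapes=True`. A minimal sketch (method names and shapes are
illustrative):

```
import torch
from torch_mlir import ExampleArgs, TensorPlaceholder

# Chained registration; add_method returns self.
ea = (ExampleArgs()
      .add_method("sin", torch.ones(2, 3))
      .add_method("cos", torch.ones(2, 4)))

# A bare tensor normalizes to an ExampleArgs for just `forward`.
ea_forward = ExampleArgs.get(torch.ones(2, 3))

# A placeholder with a dynamic dim: fine for annotation; on the tracing path
# it requires ignore_traced_shapes=True, and the -1 is traced with an
# arbitrary concrete size internally.
ea_dynamic = ExampleArgs().add_method(
    "sin", TensorPlaceholder([-1, 3], torch.float32))
```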
@@ -126,20 +239,17 @@ class TensorPlaceholder:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm','torch.aten.linear'],
-    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
+    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
     OutputType.MHLO: [],
 }

-_example_arg = Union[TensorPlaceholder, torch.Tensor]
-
 def compile(model: torch.nn.Module,
-            example_args: Union[_example_arg, Sequence[_example_arg]],
+            example_args: _example_args,
             output_type: Union[str, "OutputType"] = OutputType.TORCH,
             use_tracing: bool = False,
-            ignore_traced_shapes = False,
+            ignore_traced_shapes=False,
             backend_legal_ops: Optional[Sequence[str]] = None,
             verbose: bool = False):
     """Convert a PyTorch model to MLIR.
@@ -150,7 +260,8 @@ def compile(model: torch.nn.Module,
             shapes of the arguments to `forward` method of the model.
             A single tensor is treated as a list of a single tensor.
             A TensorPlaceholder object is also allowed in the place of any
-            Tensor.
+            Tensor. For models with multiple methods, an `ExampleArgs` object
+            can be passed.
         output_type: The kind of output to produce. See `OutputType` for more
             details.
         use_tracing: If True, use `torch.jit.trace` to convert the model to
@@ -172,6 +283,7 @@ def compile(model: torch.nn.Module,
         output type.
     """
     output_type = OutputType.get(output_type)
+    example_args = ExampleArgs.get(example_args)
     if ignore_traced_shapes and not use_tracing:
         raise Exception("`ignore_traced_shapes` requires `use_tracing`")
@@ -188,66 +300,27 @@ def compile(model: torch.nn.Module,
     else:
         backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

-    # Special case -- many models have just one input, so canonicalize a single
-    # tensor to a list of a single tensor to make the API more ergonomic.
-    if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
-        example_args = (example_args,)
-
-    # If users passed in anything other than tensors or a list of tensors (e.g.
-    # a dictionary), we can't handle it.
-    if not isinstance(example_args, Sequence):
-        raise Exception(
-            "Only Tensors, TensorPlaceholders, or a sequences of Tensors and "
-            "TensorPlaceholders are supported as inputs.")
-
-    # TODO: Don't hardcode "forward". See `torch.onnx.export` and
-    # `torch.jit.trace_module` for API inspiration.
     if use_tracing:
-        example_args_for_trace = []
-        for arg in example_args:
-            if isinstance(arg, TensorPlaceholder):
-                if not ignore_traced_shapes:
-                    # To avoid accidental footguns, we require
-                    # `ignore_traced_shapes` to be true if we're using
-                    # TensorPlaceholder's, as it falls into the same
-                    # "hopefully the trace works for different inputs" bucket
-                    # of concerns.
-                    raise Exception(
-                        "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`")
-                # For any dynamic dimensions, replace them with "7" arbitrarily.
-                # If a user is using dynamic dimensions with tracing, they are
-                # walking on thin ice already -- assume they know what they are
-                # doing.
-                shape = [s if s != -1 else 7 for s in arg.shape]
-                example_args_for_trace.append(
-                    torch.ones(*shape, dtype=arg.dtype))
-            elif isinstance(arg, torch.Tensor):
-                example_args_for_trace.append(arg)
-            else:
-                raise Exception(
-                    "Only Tensors, TensorPlaceholders, or a sequences of "
-                    "Tensors and TensorPlaceholders are supported as inputs.")
-        scripted = torch.jit.trace(model, tuple(example_args_for_trace))
+        scripted = torch.jit.trace_module(
+            model,
+            example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
+        )
     else:
+        # Make sure that all the methods that the user requested get scripted.
+        # By default, PyTorch only scripts the `forward` method and transitive
+        # callees.
+        for method_name in example_args._get_methods():
+            torch.jit.export(getattr(model, method_name).__func__)
         scripted = torch.jit.script(model)

-    # Convert all concrete inputs to TensorPlaceholder's, for consistency.
-    arg_placeholders = []
-    for arg in example_args:
-        if isinstance(arg, TensorPlaceholder):
-            arg_placeholders.append(arg)
-        else:
-            assert isinstance(arg, torch.Tensor)
-            arg_placeholders.append(TensorPlaceholder.like(arg))
-
     class_annotator = ClassAnnotator()
-    forward_annotation = [None]
-    for arg in arg_placeholders:
-        # Assume that all tensors have value semantics for now.
-        forward_annotation.append((arg.shape, arg.dtype, True))
     class_annotator.exportNone(scripted._c._type())
-    class_annotator.exportPath(scripted._c._type(), ["forward"])
-    class_annotator.annotateArgs(
-        scripted._c._type(), ["forward"], forward_annotation)
+    for method_name, example_args in example_args._get_for_annotation().items():
+        class_annotator.exportPath(scripted._c._type(), [method_name])
+        annotation = [None]  # `None` is always the annotation for "self".
+        for arg in example_args:
+            annotation.append((arg.shape, arg.dtype, True))
+        class_annotator.annotateArgs(
+            scripted._c._type(), [method_name], annotation)

     mb = ModuleBuilder()
     import_options = ImportOptions()
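
For reference on the two paths above: `torch.jit.trace_module` accepts a dict
mapping method names to example inputs (which is what `_get_for_tracing`
returns), while the scripting path marks each requested method with
`torch.jit.export` so that `torch.jit.script` compiles more than just
`forward`. A hedged sketch of the equivalent direct PyTorch calls, reusing the
module from the commit message (shapes are illustrative; the patch applies
`torch.jit.export` via `getattr(model, name).__func__` on the user's instance):

```
import torch

class TwoMethodsModule(torch.nn.Module):
    def sin(self, x):
        return torch.ops.aten.sin(x)

    def cos(self, x):
        return torch.ops.aten.cos(x)

# Tracing path: one traced graph per requested method.
traced = torch.jit.trace_module(
    TwoMethodsModule(),
    {"sin": torch.ones(2, 3), "cos": torch.ones(2, 4)})

# Scripting path: mark the extra methods as exported, then script.
for name in ("sin", "cos"):
    torch.jit.export(getattr(TwoMethodsModule, name))
scripted = torch.jit.script(TwoMethodsModule())
```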