mirror of https://github.com/llvm/torch-mlir
[torch_mlir.compile] Add support for multiple exported methods
For AoT deployments models often have multiple exported methods. This patch enables something like this: ``` class TwoMethodsModule(torch.nn.Module): def sin(self, x): return torch.ops.aten.sin(x) def cos(self, x): return torch.ops.aten.cos(x) example_args = torch_mlir.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) example_args.add_method("cos", torch.ones(2, 4)) print(torch_mlir.compile(TwoMethodsModule(), example_args)) ``` In the [long-term](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md#tools-for-advanced-aot-deployments) we will need to reconcile this with our story for stateful models and the backend contract being purely functional. For now, this provides some basic infra that seems harmless. Arguably, we could tighten up the backend contract even more to only allow a single compiled function which would prohibit this or require building out a layer above. Fixes #1557pull/1569/head
parent
2793a2bd41
commit
64914603fa
|
@ -0,0 +1,35 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
|
||||
class TwoMethodsModule(torch.nn.Module):
|
||||
def sin(self, x):
|
||||
return torch.ops.aten.sin(x)
|
||||
|
||||
def cos(self, x):
|
||||
return torch.ops.aten.cos(x)
|
||||
|
||||
|
||||
example_args = torch_mlir.ExampleArgs()
|
||||
example_args.add_method("sin", torch.ones(2, 3))
|
||||
example_args.add_method("cos", torch.ones(2, 4))
|
||||
|
||||
# Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to
|
||||
# check the `use_tracing` case first.
|
||||
|
||||
print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @sin
|
||||
# CHECK-DAG: func.func @cos
|
||||
|
||||
print(torch_mlir.compile(TwoMethodsModule(), example_args))
|
||||
# CHECK: module
|
||||
# CHECK-DAG: func.func @sin
|
||||
# CHECK-DAG: func.func @cos
|
|
@ -59,14 +59,14 @@ class DictModule(torch.nn.Module):
|
|||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
|
||||
# CHECK: Only Tensors, TensorPlaceholder's, or sequences of Tensors and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||
torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
try:
|
||||
# CHECK: Only Tensors, TensorPlaceholders, or a sequences of Tensors and TensorPlaceholders are supported as inputs.
|
||||
# CHECK: Only Tensors, TensorPlaceholder's, or sequences of Tensors and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}'
|
||||
torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(e)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Optional, Sequence, Union, List
|
||||
from typing import Optional, Sequence, Union, List, Dict, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
|
@ -70,7 +70,6 @@ class OutputType(Enum):
|
|||
return OutputType[spec]
|
||||
|
||||
|
||||
|
||||
class TensorPlaceholder:
|
||||
"""A class that represents a formal parameter of a given shape and dtype.
|
||||
|
||||
|
@ -119,6 +118,120 @@ class TensorPlaceholder:
|
|||
return TensorPlaceholder(shape, tensor.dtype)
|
||||
|
||||
|
||||
_example_arg = Union[TensorPlaceholder, torch.Tensor]
|
||||
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
|
||||
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]
|
||||
|
||||
|
||||
class ExampleArgs:
|
||||
"""A class representing the example arguments to an nn.Module.
|
||||
|
||||
In general, an nn.Module may have multiple methods that need to be compiled.
|
||||
This requires example args for each method. This class is a lightweight
|
||||
wrapper around a dictionary that maps method names to example arguments.
|
||||
|
||||
In user-facing API's, this type can always be passed interchangeably with a
|
||||
single arg or list of args, which normalizes to an ExampleArgs for just
|
||||
the `forward` method via the `ExampleArgs.get` static method.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._example_args = {}
|
||||
|
||||
def add_method(self, method_name: str, example_args: _example_args_for_one_method):
|
||||
"""Adds example args for a method.
|
||||
|
||||
Args:
|
||||
method_name: The name of the method. Must have not already been
|
||||
added previously as a method.
|
||||
example_args: The example args for the method.
|
||||
Returns:
|
||||
self, for chaining.
|
||||
"""
|
||||
assert method_name not in self._example_args
|
||||
self._example_args[method_name] = ExampleArgs._canonicalize_args(
|
||||
example_args)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def get(example_args: _example_args) -> "ExampleArgs":
|
||||
"""Gets an ExampleArgs from one of the permissible ways to specify one.
|
||||
|
||||
Args:
|
||||
example_args: An ExampleArgs instance or a single arg or list of args.
|
||||
Returns:
|
||||
An ExampleArgs instance.
|
||||
"""
|
||||
if isinstance(example_args, ExampleArgs):
|
||||
return example_args
|
||||
return ExampleArgs().add_method("forward", example_args)
|
||||
|
||||
@staticmethod
|
||||
def _canonicalize_args(example_args: _example_args_for_one_method):
|
||||
"""Canonicalize the args for one method into a tuple."""
|
||||
if not isinstance(example_args, Sequence):
|
||||
example_args = [example_args]
|
||||
for arg in example_args:
|
||||
if not isinstance(arg, _example_arg):
|
||||
raise Exception(f"Only Tensors, TensorPlaceholder's, or sequences of "
|
||||
f"Tensors and TensorPlaceholder's are supported as "
|
||||
f"example args for method inputs. "
|
||||
f"Got '{arg}'.")
|
||||
return tuple(example_args)
|
||||
|
||||
def _get_methods(self):
|
||||
return self._example_args.keys()
|
||||
|
||||
def _get_for_annotation(self):
|
||||
result = {}
|
||||
for method_name, example_args in self._example_args.items():
|
||||
placeholders = []
|
||||
for arg in example_args:
|
||||
if isinstance(arg, TensorPlaceholder):
|
||||
placeholders.append(arg)
|
||||
else:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
placeholders.append(TensorPlaceholder.like(arg))
|
||||
result[method_name] = placeholders
|
||||
return result
|
||||
|
||||
def _get_for_tracing(
|
||||
self,
|
||||
use_tracing: bool,
|
||||
ignore_traced_shapes: bool,
|
||||
) -> Dict[str, Tuple[_example_arg, ...]]:
|
||||
result = {}
|
||||
for method_name, example_args in self._example_args.items():
|
||||
# If we are tracing, then we need to convert any placeholders into
|
||||
# concrete values.
|
||||
if use_tracing:
|
||||
example_args_for_trace = []
|
||||
for arg in example_args:
|
||||
if isinstance(arg, TensorPlaceholder):
|
||||
if not ignore_traced_shapes:
|
||||
# To avoid accidental footguns, we require
|
||||
# `ignore_traced_shapes` to be true if we're using
|
||||
# TensorPlaceholder's, as it falls into the same
|
||||
# "hopefully the trace works for different inputs"
|
||||
# bucket of concerns.
|
||||
raise Exception(
|
||||
"TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`")
|
||||
# For any dynamic dimensions, replace them with "7"
|
||||
# arbitrarily. If a user is using dynamic dimensions with
|
||||
# tracing, they are walking on thin ice already -- assume
|
||||
# they know what they are doing and that their trace is
|
||||
# correct for any specific concrete size.
|
||||
shape = [s if s != -1 else 7 for s in arg.shape]
|
||||
example_args_for_trace.append(
|
||||
torch.ones(*shape, dtype=arg.dtype))
|
||||
else:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
example_args_for_trace.append(arg)
|
||||
example_args = tuple(example_args_for_trace)
|
||||
result[method_name] = example_args
|
||||
return result
|
||||
|
||||
|
||||
# The set of ops that are considered legal for each backend.
|
||||
# These are currently quite load-bearing, since different backends might be
|
||||
# missing patterns for decomposed forms of certain ops.
|
||||
|
@ -126,20 +239,17 @@ class TensorPlaceholder:
|
|||
# ops in the backend contract, and move these lists somewhere deeper in the
|
||||
# compiler where each backend can "own" its set of legal ops.
|
||||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm','torch.aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
|
||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
||||
OutputType.MHLO: [],
|
||||
}
|
||||
|
||||
|
||||
_example_arg = Union[TensorPlaceholder, torch.Tensor]
|
||||
|
||||
|
||||
def compile(model: torch.nn.Module,
|
||||
example_args: Union[_example_arg, Sequence[_example_arg]],
|
||||
example_args: _example_args,
|
||||
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
||||
use_tracing: bool = False,
|
||||
ignore_traced_shapes = False,
|
||||
ignore_traced_shapes=False,
|
||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||
verbose: bool = False):
|
||||
"""Convert a PyTorch model to MLIR.
|
||||
|
@ -150,7 +260,8 @@ def compile(model: torch.nn.Module,
|
|||
shapes of the arguments to `forward` method of the model.
|
||||
A single tensor is treated as a list of a single tensor.
|
||||
A TensorPlaceholder object is also allowed in the place of any
|
||||
Tensor.
|
||||
Tensor. For models with multiple methods, an `ExampleArgs` object
|
||||
can be passed.
|
||||
output_type: The kind of output to produce. See `OutputType` for more
|
||||
details.
|
||||
use_tracing: If True, use `torch.jit.trace` to convert the model to
|
||||
|
@ -172,6 +283,7 @@ def compile(model: torch.nn.Module,
|
|||
output type.
|
||||
"""
|
||||
output_type = OutputType.get(output_type)
|
||||
example_args = ExampleArgs.get(example_args)
|
||||
if ignore_traced_shapes and not use_tracing:
|
||||
raise Exception("`ignore_traced_shapes` requires `use_tracing`")
|
||||
|
||||
|
@ -188,66 +300,27 @@ def compile(model: torch.nn.Module,
|
|||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
# Special case -- many models have just one input, so canonicalize a single
|
||||
# tensor to a list of a single tensor to make the API more ergonomic.
|
||||
if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
|
||||
example_args = (example_args,)
|
||||
|
||||
# If users passed in anything other than tensors or a list of tensors (e.g.
|
||||
# a dictionary), we can't handle it.
|
||||
if not isinstance(example_args, Sequence):
|
||||
raise Exception(
|
||||
"Only Tensors, TensorPlaceholders, or a sequences of Tensors and "
|
||||
"TensorPlaceholders are supported as inputs.")
|
||||
|
||||
# TODO: Don't hardcode "forward". See `torch.onnx.export` and
|
||||
# `torch.jit.trace_module` for API inspiration.
|
||||
if use_tracing:
|
||||
example_args_for_trace = []
|
||||
for arg in example_args:
|
||||
if isinstance(arg, TensorPlaceholder):
|
||||
if not ignore_traced_shapes:
|
||||
# To avoid accidental footguns, we require
|
||||
# `ignore_traced_shapes` to be true if we're using
|
||||
# TensorPlaceholder's, as it falls into the same
|
||||
# "hopefully the trace works for different inputs" bucket
|
||||
# of concerns.
|
||||
raise Exception(
|
||||
"TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`")
|
||||
# For any dynamic dimensions, replace them with "7" arbitrarily.
|
||||
# If a user is using dynamic dimensions with tracing, they are
|
||||
# walking on thin ice already -- assume they know what they are
|
||||
# doing.
|
||||
shape = [s if s != -1 else 7 for s in arg.shape]
|
||||
example_args_for_trace.append(
|
||||
torch.ones(*shape, dtype=arg.dtype))
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
example_args_for_trace.append(arg)
|
||||
else:
|
||||
raise Exception(
|
||||
"Only Tensors, TensorPlaceholders, or a sequences of "
|
||||
"Tensors and TensorPlaceholders are supported as inputs.")
|
||||
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
|
||||
scripted = torch.jit.trace_module(
|
||||
model,
|
||||
example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
|
||||
)
|
||||
else:
|
||||
# Make sure that all the methods that the user requested get scripted.
|
||||
# By default, PyTorch only scripts the `forward` method and transitive
|
||||
# callees.
|
||||
for method_name in example_args._get_methods():
|
||||
torch.jit.export(getattr(model, method_name).__func__)
|
||||
scripted = torch.jit.script(model)
|
||||
# Convert all concrete inputs to TensorPlaceholder's, for consistency.
|
||||
arg_placeholders = []
|
||||
for arg in example_args:
|
||||
if isinstance(arg, TensorPlaceholder):
|
||||
arg_placeholders.append(arg)
|
||||
else:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
arg_placeholders.append(TensorPlaceholder.like(arg))
|
||||
|
||||
class_annotator = ClassAnnotator()
|
||||
forward_annotation = [None]
|
||||
for arg in arg_placeholders:
|
||||
# Assume that all tensors have value semantics for now.
|
||||
forward_annotation.append((arg.shape, arg.dtype, True))
|
||||
class_annotator.exportNone(scripted._c._type())
|
||||
class_annotator.exportPath(scripted._c._type(), ["forward"])
|
||||
class_annotator.annotateArgs(
|
||||
scripted._c._type(), ["forward"], forward_annotation)
|
||||
for method_name, example_args in example_args._get_for_annotation().items():
|
||||
class_annotator.exportPath(scripted._c._type(), [method_name])
|
||||
annotation = [None] # `None` is always the annotation for "self".
|
||||
for arg in example_args:
|
||||
annotation.append((arg.shape, arg.dtype, True))
|
||||
class_annotator.annotateArgs(
|
||||
scripted._c._type(), [method_name], annotation)
|
||||
|
||||
mb = ModuleBuilder()
|
||||
import_options = ImportOptions()
|
||||
|
|
Loading…
Reference in New Issue