mirror of https://github.com/llvm/torch-mlir
This PR implements an eager mode backend for PyTorch through the torch-mlir framework. This is accomplished by overriding the `__torch_dispatch__` class method on wrapper subclass `TorchMLIRTensor(torch.Tensor)`.
Effectively, this mode works by compiling op by op as the NN is eagerly executed by PyTorch. Entailed in that compilation is building a representation of the op that can be `torch.jit.script`ed, importing using `ModuleBuilder`, and then executing (e.g., with `RefBackendLinalgOnTensorsBackend`). This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported op). Currently, all e2e tests pass execpt for two that involve an upstream PyTorch bug (https://github.com/pytorch/pytorch/issues/74400). High priority next steps: 1. A compile cache in order to speed up reruns of the same NN. 2. Integration with IREE (though not in this repo). 3. Integration with `torch.distributed`.pull/687/head snapshot-20220322.340
parent
f9d34596e8
commit
fe8ac57e6d
|
@ -15,13 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
|||
|
||||
# Available test configs.
|
||||
from torch_mlir_e2e_test.torchscript.configs import (
|
||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig
|
||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
|
||||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS, EAGER_MODE_XFAIL_SET
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||
|
@ -57,7 +57,7 @@ from . import cast
|
|||
from . import index_put
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external', 'eager_mode']
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
parser.add_argument('-c', '--config',
|
||||
choices=config_choices,
|
||||
|
@ -121,6 +121,9 @@ def main():
|
|||
elif args.config == 'torchscript':
|
||||
config = TorchScriptTestConfig()
|
||||
xfail_set = {}
|
||||
elif args.config == 'eager_mode':
|
||||
config = EagerModeTestConfig()
|
||||
xfail_set = EAGER_MODE_XFAIL_SET
|
||||
elif args.config == 'external':
|
||||
with open(args.external_config, 'r') as f:
|
||||
code = compile(f.read(), args.external_config, 'exec')
|
||||
|
|
|
@ -113,9 +113,9 @@ class BernoulliModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: BernoulliModule())
|
||||
def BernoulliModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(256, 512, 8).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
tu.rand(512, 1024, 8).double(),
|
||||
tu.rand(1024, 2048, 4).double(),
|
||||
tu.rand(1024, 256, 4).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -188,9 +188,9 @@ class BernoulliFloatModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
||||
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(256, 512, 8).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
tu.rand(512, 1024, 8).double(),
|
||||
tu.rand(1024, 2048, 4).double(),
|
||||
tu.rand(1024, 512, 4).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -228,9 +228,9 @@ class BernoulliTensorModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 512, 8).double(),
|
||||
tu.rand(512, 512, 8).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
tu.rand(1024, 1024, 16).double(),
|
||||
tu.rand(1024, 1024, 16).double(),
|
||||
tu.rand(1024, 2048, 8).double(),
|
||||
tu.rand(1024, 2048, 8).double(),
|
||||
tu.rand(1024, 512, 8).double(),
|
||||
tu.rand(1024, 512, 8).double())
|
||||
|
|
|
@ -21,6 +21,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
}
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
EAGER_MODE_XFAIL_SET = REFBACKEND_XFAIL_SET.union({
|
||||
# These fail because an upstream pytorch bug; more information at the following issue
|
||||
# https://github.com/pytorch/pytorch/issues/74400
|
||||
"ElementwiseMulScalarModule_basic",
|
||||
"ElementwiseSubScalarIntModule_basic",
|
||||
})
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
|
|
|
@ -57,6 +57,12 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
|||
add_subdirectory(torch_mlir_e2e_test)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Eager mode
|
||||
################################################################################
|
||||
|
||||
add_subdirectory(torch_mlir/eager_mode)
|
||||
|
||||
################################################################################
|
||||
# Generate packages and shared library
|
||||
# Downstreams typically will not use these, but they are useful for local
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
#-------------------------------------------------------------------------------
|
||||
# Setup PyTorch
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
|
||||
find_package(Torch 1.8 REQUIRED)
|
||||
|
||||
TorchMLIRConfigurePyTorch()
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Subdirectories
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
## Declare the sources of the Python module.
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES_GLOB eager_mode/*.py lazytensor/*.py
|
||||
)
|
|
@ -0,0 +1,230 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
"""
|
||||
Translator from torch.jit.ScriptFunction to MLIR.
|
||||
|
||||
|
||||
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
|
||||
`torch` dialect.
|
||||
|
||||
The expected use of this module is to create an instance of one of the classes below, and then calling the
|
||||
`to_mlir` method to generate the MLIR representation of the type.
|
||||
|
||||
Information about what types are supported by each class can be found in docstrings of each of the classes.
|
||||
|
||||
In addition this module defines a function that take a torch.jit.ScriptFunction and converts it into an MLIR module.
|
||||
|
||||
The expected use for this module is to use the function
|
||||
`build_module(jit_function: torch.jit.ScriptFunction annotation: Annotation) -> ir.Module`
|
||||
to convert the TorchScript function into MLIR using the `torch` dialect.
|
||||
"""
|
||||
|
||||
import abc
|
||||
from typing import Any, Optional, Iterable
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.jit import ScriptFunction
|
||||
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects.builtin import FuncOp
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
|
||||
class TorchMlirType(abc.ABC):
|
||||
"""
|
||||
A `TorchMlirType` is an object that produces MLIR
|
||||
types in the `torch` dialect. The only requirement
|
||||
for a class to be a subclass of `TorchMlirType` is
|
||||
to define a `to_mlir(self, ir.Context) -> ir.Type`.
|
||||
Each class is allowed to have different types of
|
||||
__init__ methods depending on the information they
|
||||
require to produce the given MLIR representation.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
pass
|
||||
|
||||
|
||||
class TorchTensorTypeError(Exception):
|
||||
def __init__(self, value: str):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class TorchTensorType(TorchMlirType):
|
||||
"""
|
||||
This class is used to generate types of the form
|
||||
!torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
|
||||
where SHAPE is a list representing the shape of the tensor,
|
||||
and DTYPE is an MLIR data type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
shape: Optional[Iterable[Optional[int]]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
if dtype is None and shape is not None:
|
||||
err = "If shape is specified, dtype must also be specified"
|
||||
raise TorchTensorTypeError(err)
|
||||
|
||||
def __str__(self):
|
||||
return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
if self.dtype is None:
|
||||
return ir.Type.parse("!torch.tensor", context=context)
|
||||
|
||||
shape_asm = self._shape_to_mlir_asm()
|
||||
dtype_asm = self._dtype_to_mlir_asm()
|
||||
return ir.Type.parse(
|
||||
f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
|
||||
)
|
||||
|
||||
def _shape_to_mlir_asm(self) -> str:
|
||||
if self.shape is None:
|
||||
return "*"
|
||||
|
||||
str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
|
||||
return f'[{",".join(str_sizes)}]'
|
||||
|
||||
def _dtype_to_mlir_asm(self) -> str:
|
||||
if self.dtype in [torch.float64]:
|
||||
return "f64"
|
||||
if self.dtype in [torch.float, torch.float32]:
|
||||
return "f32"
|
||||
if self.dtype in [torch.int, torch.int32]:
|
||||
return "si32"
|
||||
if self.dtype in [torch.int64]:
|
||||
return "si64"
|
||||
if self.dtype in [torch.bool]:
|
||||
return "i1"
|
||||
|
||||
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
|
||||
|
||||
|
||||
class TorchNnModuleType(TorchMlirType):
|
||||
"""This class is used to generate types for `!torch.nn.Module`s."""
|
||||
|
||||
def __init__(self, module_name: str):
|
||||
self.module_name = module_name
|
||||
|
||||
def __str__(self):
|
||||
return "torch.nn.Module"
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)
|
||||
|
||||
|
||||
class PythonType(TorchMlirType):
|
||||
"""
|
||||
This class is used to convert regular Python types
|
||||
into their corresponding `torch` dialect representation.
|
||||
The list of supported types can be found in the dictionary
|
||||
`_type_to_asm_dict`.
|
||||
"""
|
||||
|
||||
_type_to_asm_dict = {
|
||||
bool: "!torch.bool",
|
||||
int: "!torch.int",
|
||||
type(None): "!torch.none",
|
||||
}
|
||||
|
||||
def __init__(self, type_: Any):
|
||||
self.type_ = type_
|
||||
|
||||
def __str__(self):
|
||||
return str(self.type_)
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
asm = self._type_to_asm_dict.get(self.type_)
|
||||
if asm is None:
|
||||
raise NotImplementedError(f"Unsupported type: {self.type_}")
|
||||
return ir.Type.parse(asm, context=context)
|
||||
|
||||
|
||||
# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
|
||||
class Annotation:
|
||||
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
|
||||
self.types = list(
|
||||
map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
result = f"Annotation instance with {len(self.types)} types\n"
|
||||
for e, type_ in enumerate(self.types):
|
||||
result += f" Type of argument {e + 1}: {str(type_)}\n"
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.types)
|
||||
|
||||
|
||||
class AnnotationConverter:
|
||||
@staticmethod
|
||||
def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
|
||||
dict_attrs = []
|
||||
for type_ in annotation:
|
||||
if not isinstance(type_, TorchTensorType):
|
||||
dict_attrs.append(ir.DictAttr.get({}, context=context))
|
||||
continue
|
||||
|
||||
ir_type = type_.to_mlir(context)
|
||||
with context:
|
||||
type_attr = ir.TypeAttr.get(ir_type)
|
||||
dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
|
||||
dict_attrs.append(dict_attr)
|
||||
|
||||
return ir.ArrayAttr.get(dict_attrs, context=context)
|
||||
|
||||
|
||||
def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
|
||||
with module.context:
|
||||
name_attr = ir.StringAttr.get(name)
|
||||
for op in module.body.operations:
|
||||
if isinstance(op, FuncOp) and op.name == name_attr:
|
||||
return op
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def build_module(jit_function: ScriptFunction, annotations) -> ir.Module:
|
||||
"""Translate input function into an MLIR module in the `torch` dialect.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
jit_function: ScriptFunction
|
||||
Function in TorchScript IR to turn into MLIR.
|
||||
annotation: Annotation
|
||||
Annotation object representing the types of
|
||||
the operands of `jit_function`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ir.Module
|
||||
Translation of the input module into an MLIR module
|
||||
"""
|
||||
mb = ModuleBuilder()
|
||||
mb.import_function(jit_function)
|
||||
|
||||
func_op = get_func_op_with_name(mb.module, jit_function.name)
|
||||
assert (
|
||||
func_op is not None
|
||||
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
|
||||
|
||||
func_annotation = Annotation(annotations)
|
||||
arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
|
||||
func_op.attributes["arg_attrs"] = arg_attrs
|
||||
|
||||
return mb.module
|
|
@ -0,0 +1,265 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Any, Callable, Tuple, Union
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.fx.node import map_aggregate
|
||||
from torch.fx.operator_schemas import normalize_function, create_type_hint
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch_mlir._mlir_libs._mlir.passmanager import PassManager
|
||||
|
||||
from torch_mlir.dialects import torch as torch_dialect
|
||||
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
|
||||
from torch_mlir.eager_mode.ir_building import build_module, TorchTensorType
|
||||
|
||||
OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
|
||||
SUPPORTED_OPS = frozenset(
|
||||
[
|
||||
member.OPERATION_NAME
|
||||
for member in vars(torch_dialect).values()
|
||||
if hasattr(member, "OPERATION_NAME")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedByTorchMlirEagerMode(Exception):
|
||||
def __init__(self, value: str):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
def check_supported_op(schema: torch._C.FunctionSchema) -> bool:
|
||||
return (
|
||||
"torch."
|
||||
+ schema.name.replace("::", ".")
|
||||
+ ("." + schema.overload_name if schema.overload_name else "")
|
||||
) in SUPPORTED_OPS
|
||||
|
||||
|
||||
def is_tensor_type(typ: torch._C.Type):
|
||||
return typ.isSubtypeOf(torch.TensorType.get()) or (
|
||||
isinstance(typ, torch.OptionalType)
|
||||
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
|
||||
)
|
||||
|
||||
|
||||
def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
|
||||
"""Fill in default values for optional args, which are dependent on the schema."""
|
||||
|
||||
arg_types = map_aggregate(args, type)
|
||||
assert isinstance(arg_types, tuple)
|
||||
arg_types = tuple([create_type_hint(i) for i in arg_types])
|
||||
kwarg_types = {k: type(v) for k, v in kwargs.items()}
|
||||
|
||||
new_args_and_kwargs = normalize_function(
|
||||
target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
|
||||
)
|
||||
assert new_args_and_kwargs, "Couldn't normalize args and kwargs"
|
||||
new_args, new_kwargs = new_args_and_kwargs
|
||||
return new_args, new_kwargs
|
||||
|
||||
|
||||
def build_script_function(
|
||||
schema: torch._C.FunctionSchema,
|
||||
args: List[torch._C.Argument],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> torch.jit.ScriptFunction:
|
||||
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
|
||||
|
||||
Constants are inlined for the purposes of invalidating the compile cache when they change.
|
||||
"""
|
||||
|
||||
# Creates empty TS graph.
|
||||
graph = torch._C.Graph()
|
||||
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
|
||||
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
|
||||
# Associate graph inputs/outputs with node inputs/outputs.
|
||||
for i, arg in enumerate(schema.arguments):
|
||||
# Find value corresponding to schema arg, either in positional or kw args.
|
||||
kwarg = False
|
||||
if arg.name in kwargs:
|
||||
val = kwargs[arg.name]
|
||||
kwarg = True
|
||||
else:
|
||||
val = args[i]
|
||||
|
||||
# If arg is a tensor, then add input to the graph corresponding to arg.
|
||||
if is_tensor_type(arg.type) and val is not None:
|
||||
inp = graph.addInput()
|
||||
if isinstance(arg.type, torch.OptionalType):
|
||||
inp.setType(arg.type.getElementType())
|
||||
else:
|
||||
inp.setType(arg.type)
|
||||
|
||||
if kwarg:
|
||||
# Rename for debugging aid.
|
||||
inp.setDebugName(arg.name)
|
||||
# If arg is a constant, inline (at the top of the graph).
|
||||
else:
|
||||
inp = graph.insertConstant(val)
|
||||
inp.node().moveBefore(node)
|
||||
|
||||
node.addInput(inp)
|
||||
|
||||
if node.hasMultipleOutputs():
|
||||
for outp in node.outputs():
|
||||
graph.registerOutput(outp)
|
||||
else:
|
||||
graph.registerOutput(node.output())
|
||||
|
||||
fn_name = str(node).strip()
|
||||
fn = torch._C._create_function_from_graph(fn_name, graph)
|
||||
return fn
|
||||
|
||||
|
||||
def annotate_args_kwargs(
|
||||
script_fun: torch._C.ScriptFunction,
|
||||
normalized_args: List[Any],
|
||||
normalized_kwargs: Dict[str, Any],
|
||||
):
|
||||
unwrapped_normalized_args = tree_map(
|
||||
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
|
||||
normalized_args,
|
||||
)
|
||||
unwrapped_normalized_kwargs = tree_map(
|
||||
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
|
||||
normalized_kwargs,
|
||||
)
|
||||
|
||||
annotations = []
|
||||
tensor_args = []
|
||||
for i, arg in enumerate(unwrapped_normalized_args):
|
||||
if isinstance(arg, np.ndarray):
|
||||
# TODO: Remove once size zero dimensions are handled by torch-mlir.
|
||||
shape = tuple(map(lambda x: x or -1, arg.shape))
|
||||
annotations.append(
|
||||
TorchTensorType(shape=shape, dtype=normalized_args[i].dtype)
|
||||
)
|
||||
tensor_args.append(arg)
|
||||
|
||||
# Pull out tensor kwargs and put them in positional order.
|
||||
tensor_kwargs_flat = []
|
||||
if unwrapped_normalized_kwargs:
|
||||
tensor_kwargs = {}
|
||||
arg_idxs = {
|
||||
arg_name: i
|
||||
for i, arg_name in enumerate(
|
||||
[arg.name for arg in script_fun.schema.arguments]
|
||||
)
|
||||
}
|
||||
for i, (kw, arg) in enumerate(unwrapped_normalized_kwargs.items()):
|
||||
if isinstance(arg, np.ndarray):
|
||||
tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype)
|
||||
|
||||
for i in range(len(tensor_kwargs)):
|
||||
arg, arg_dtype = tensor_kwargs[i]
|
||||
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype))
|
||||
tensor_kwargs_flat.append(arg)
|
||||
|
||||
return annotations, tensor_args, tensor_kwargs_flat
|
||||
|
||||
|
||||
def write_back_to_mutable(
|
||||
registered_op: Dict,
|
||||
out: Union[np.ndarray, List[np.ndarray]],
|
||||
all_tensor_args: List[np.ndarray],
|
||||
):
|
||||
"""Write back to mutable args that aren't properly handled otherwise.
|
||||
|
||||
Because of how we pass values to the backend (by copying the tensor to a numpy array) we don't currently support
|
||||
ops that mutate operands. That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
|
||||
as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
|
||||
we can for these special cases. Hence, the solution is to manually write back to the same operand that the
|
||||
conventional pytorch op variant would write to.
|
||||
|
||||
Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
|
||||
mutate running_mean and running_var). We don't currently handle those.
|
||||
"""
|
||||
if len(registered_op["returns"]) > 1:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
"TorchMLIR doesn't handle multiple aliased returns yet."
|
||||
)
|
||||
else:
|
||||
aliased_arg = next(
|
||||
arg
|
||||
for arg in registered_op["arguments"]
|
||||
if "alias_info" in arg and arg["alias_info"]["is_write"]
|
||||
)
|
||||
assert (
|
||||
"alias_info" in registered_op["returns"][0]
|
||||
and registered_op["returns"][0]["alias_info"]["is_write"]
|
||||
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
|
||||
and registered_op["returns"][0]["alias_info"]["after"][0]
|
||||
)
|
||||
assert (
|
||||
len(aliased_arg["alias_info"]["after"]) == 1
|
||||
and aliased_arg["alias_info"]["after"][0]
|
||||
== registered_op["returns"][0]["alias_info"]["after"][0]
|
||||
)
|
||||
np.copyto(all_tensor_args[0], out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def try_torch_mlir_eager(op, args, kwargs, backend):
|
||||
if hasattr(op, "op_name"):
|
||||
op_name = op.op_name
|
||||
elif hasattr(op, "__name__"):
|
||||
# Handle builtin_function_or_method.
|
||||
op_name = op.__name__
|
||||
else:
|
||||
raise RuntimeError(f"op {op} has no name")
|
||||
|
||||
if op_name == "detach":
|
||||
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
|
||||
raise UnsupportedByTorchMlirEagerMode("detaching")
|
||||
|
||||
if not hasattr(op, "_schema"):
|
||||
raise RuntimeError(f"op {op} has no schema.")
|
||||
|
||||
new_args, new_kwargs = normalize_args_kwargs(op.overloadpacket, args, kwargs)
|
||||
|
||||
if "layout" in new_kwargs and new_kwargs["layout"] not in {0, None}:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
f"{new_kwargs['layout']} layout not supported."
|
||||
)
|
||||
if "memory_format" in new_kwargs and new_kwargs["memory_format"] not in {0, None}:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
f"{new_kwargs['memory_format']} memory format not supported."
|
||||
)
|
||||
|
||||
script_fun = build_script_function(op._schema, new_args, new_kwargs)
|
||||
annotations, np_tensor_args, np_tensor_kwargs_flat = annotate_args_kwargs(
|
||||
script_fun, new_args, new_kwargs
|
||||
)
|
||||
|
||||
eager_module = build_module(script_fun, annotations)
|
||||
with eager_module.context:
|
||||
pm = PassManager.parse(
|
||||
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
|
||||
)
|
||||
pm.run(eager_module)
|
||||
compiled_module = backend.compile(eager_module)
|
||||
loaded_module = backend.load(compiled_module)
|
||||
op_mlir_backend_callable = getattr(loaded_module, script_fun.name)
|
||||
assert (
|
||||
op_mlir_backend_callable is not None
|
||||
), f"Couldn't find function {script_fun.name} in module."
|
||||
|
||||
all_tensor_args = np_tensor_args + np_tensor_kwargs_flat
|
||||
out = op_mlir_backend_callable(*all_tensor_args)
|
||||
|
||||
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
|
||||
if registered_op["is_mutable"]:
|
||||
out = write_back_to_mutable(registered_op, out, all_tensor_args)
|
||||
|
||||
return out
|
|
@ -0,0 +1,107 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import (
|
||||
try_torch_mlir_eager,
|
||||
UnsupportedByTorchMlirEagerMode,
|
||||
)
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
class TorchMLIRTensor(torch.Tensor):
|
||||
"""Wrap torch.Tensor in orer to dispatch through torch-mlir instead of aten.
|
||||
|
||||
This class uses the _make_wrapper_subclass pattern to override __torch_dispatch__
|
||||
in order to dispatch through torch-mlir instead of aten. Here we basically only unwrap and wrap
|
||||
torch.Tensors. Most of the heavy lifting is done in the adjacent torch_mlir_dispatch module.
|
||||
|
||||
More documentation on how this pattern works can be found in this RFC
|
||||
https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
|
||||
and this repo with many examples
|
||||
https://github.com/albanD/subclass_zoo
|
||||
"""
|
||||
|
||||
elem: torch.Tensor
|
||||
|
||||
__slots__ = ["elem"]
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
requires_grad=kwargs.get("requires_grad", False) or elem.requires_grad,
|
||||
)
|
||||
r.elem = elem.detach() if r.requires_grad else elem
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"TorchMLIRTensor({self.elem}, grad_fn={self.grad_fn})"
|
||||
else:
|
||||
return f"TorchMLIRTensor({self.elem})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
|
||||
requires_grad = False
|
||||
|
||||
def check_grad(e):
|
||||
nonlocal requires_grad
|
||||
if isinstance(e, TorchMLIRTensor):
|
||||
requires_grad |= e.requires_grad
|
||||
|
||||
tree_map(check_grad, args)
|
||||
tree_map(check_grad, kwargs)
|
||||
|
||||
def unwrap(e):
|
||||
if isinstance(e, TorchMLIRTensor):
|
||||
return e.elem
|
||||
if isinstance(e, torch.nn.Parameter):
|
||||
return e.detach()
|
||||
return e
|
||||
|
||||
def wrap(e):
|
||||
nonlocal requires_grad
|
||||
return (
|
||||
TorchMLIRTensor(e, requires_grad=requires_grad)
|
||||
if isinstance(e, torch.Tensor)
|
||||
else e
|
||||
)
|
||||
|
||||
unwrapped_args = tree_map(unwrap, args)
|
||||
unwrapped_kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
try:
|
||||
out = try_torch_mlir_eager(
|
||||
func,
|
||||
unwrapped_args,
|
||||
unwrapped_kwargs,
|
||||
backend=refbackend.RefBackendLinalgOnTensorsBackend(),
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
out = [torch.from_numpy(o) for o in out]
|
||||
else:
|
||||
out = torch.from_numpy(out)
|
||||
return tree_map(wrap, out)
|
||||
except Exception as e:
|
||||
if isinstance(e, UnsupportedByTorchMlirEagerMode):
|
||||
warnings.warn(
|
||||
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
|
||||
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
|
||||
)
|
||||
return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
|
|
@ -7,3 +7,4 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
|||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
||||
from .tosa_backend import TosaBackendTestConfig
|
||||
from .eager_mode import EagerModeTestConfig
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
def wrap(e):
|
||||
return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
|
||||
def unwrap(e):
|
||||
return e.elem.clone() if isinstance(e, TorchMLIRTensor) else e
|
||||
|
||||
|
||||
class EagerModeTestConfig(TestConfig):
|
||||
"""Trivial test config that exercises eager mode plumbing"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||
return program
|
||||
|
||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
attr = artifact
|
||||
for part in item.symbol.split('.'):
|
||||
attr = getattr(attr, part)
|
||||
|
||||
inps = tree_map(wrap, item.inputs)
|
||||
outps = attr(*inps)
|
||||
output = tree_map(unwrap, outps)
|
||||
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=output))
|
||||
return result
|
Loading…
Reference in New Issue