This PR implements an eager mode backend for PyTorch through the torch-mlir framework. This is accomplished by overriding the `__torch_dispatch__` class method on wrapper subclass `TorchMLIRTensor(torch.Tensor)`.

Effectively, this mode works by compiling op by op as the NN is eagerly executed by PyTorch. Entailed in that compilation is building a representation of the op that can be `torch.jit.script`ed, importing using `ModuleBuilder`, and then executing (e.g., with `RefBackendLinalgOnTensorsBackend`). This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported op).

Currently, all e2e tests pass execpt for two that involve an upstream PyTorch bug (https://github.com/pytorch/pytorch/issues/74400).

High priority next steps:

1. A compile cache in order to speed up reruns of the same NN.
2. Integration with IREE (though not in this repo).
3. Integration with `torch.distributed`.
pull/687/head snapshot-20220322.340
max 2022-03-02 17:02:18 -06:00 committed by Sean Silva
parent f9d34596e8
commit fe8ac57e6d
11 changed files with 699 additions and 15 deletions

View File

@ -15,13 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
# Available test configs. # Available test configs.
from torch_mlir_e2e_test.torchscript.configs import ( from torch_mlir_e2e_test.torchscript.configs import (
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
) )
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS, EAGER_MODE_XFAIL_SET
# Import tests to register them in the global registry. # Import tests to register them in the global registry.
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking # Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
@ -57,7 +57,7 @@ from . import cast
from . import index_put from . import index_put
def _get_argparse(): def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external', 'eager_mode']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config', parser.add_argument('-c', '--config',
choices=config_choices, choices=config_choices,
@ -121,6 +121,9 @@ def main():
elif args.config == 'torchscript': elif args.config == 'torchscript':
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
xfail_set = {} xfail_set = {}
elif args.config == 'eager_mode':
config = EagerModeTestConfig()
xfail_set = EAGER_MODE_XFAIL_SET
elif args.config == 'external': elif args.config == 'external':
with open(args.external_config, 'r') as f: with open(args.external_config, 'r') as f:
code = compile(f.read(), args.external_config, 'exec') code = compile(f.read(), args.external_config, 'exec')

View File

@ -113,9 +113,9 @@ class BernoulliModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BernoulliModule()) @register_test_case(module_factory=lambda: BernoulliModule())
def BernoulliModule_basic(module, tu: TestUtils): def BernoulliModule_basic(module, tu: TestUtils):
module.forward( module.forward(
tu.rand(256, 512, 8).double(), tu.rand(512, 1024, 8).double(),
tu.rand(512, 1024, 4).double(), tu.rand(1024, 2048, 4).double(),
tu.rand(512, 256, 4).double()) tu.rand(1024, 256, 4).double())
# ============================================================================== # ==============================================================================
@ -188,9 +188,9 @@ class BernoulliFloatModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BernoulliFloatModule()) @register_test_case(module_factory=lambda: BernoulliFloatModule())
def BernoulliFloatModule_basic(module, tu: TestUtils): def BernoulliFloatModule_basic(module, tu: TestUtils):
module.forward( module.forward(
tu.rand(256, 512, 8).double(), tu.rand(512, 1024, 8).double(),
tu.rand(512, 1024, 4).double(), tu.rand(1024, 2048, 4).double(),
tu.rand(512, 256, 4).double()) tu.rand(1024, 512, 4).double())
# ============================================================================== # ==============================================================================
@ -228,9 +228,9 @@ class BernoulliTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BernoulliTensorModule()) @register_test_case(module_factory=lambda: BernoulliTensorModule())
def BernoulliTensorModule_basic(module, tu: TestUtils): def BernoulliTensorModule_basic(module, tu: TestUtils):
module.forward( module.forward(
tu.rand(512, 512, 8).double(), tu.rand(1024, 1024, 16).double(),
tu.rand(512, 512, 8).double(), tu.rand(1024, 1024, 16).double(),
tu.rand(512, 1024, 4).double(), tu.rand(1024, 2048, 8).double(),
tu.rand(512, 1024, 4).double(), tu.rand(1024, 2048, 8).double(),
tu.rand(512, 256, 4).double(), tu.rand(1024, 512, 8).double(),
tu.rand(512, 256, 4).double()) tu.rand(1024, 512, 8).double())

View File

@ -21,6 +21,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
} }
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
EAGER_MODE_XFAIL_SET = REFBACKEND_XFAIL_SET.union({
# These fail because an upstream pytorch bug; more information at the following issue
# https://github.com/pytorch/pytorch/issues/74400
"ElementwiseMulScalarModule_basic",
"ElementwiseSubScalarIntModule_basic",
})
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {

View File

@ -57,6 +57,12 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
add_subdirectory(torch_mlir_e2e_test) add_subdirectory(torch_mlir_e2e_test)
endif() endif()
################################################################################
# Eager mode
################################################################################
add_subdirectory(torch_mlir/eager_mode)
################################################################################ ################################################################################
# Generate packages and shared library # Generate packages and shared library
# Downstreams typically will not use these, but they are useful for local # Downstreams typically will not use these, but they are useful for local

View File

@ -0,0 +1,20 @@
#-------------------------------------------------------------------------------
# Setup PyTorch
#-------------------------------------------------------------------------------
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
find_package(Torch 1.8 REQUIRED)
TorchMLIRConfigurePyTorch()
#-------------------------------------------------------------------------------
# Subdirectories
#-------------------------------------------------------------------------------
## Declare the sources of the Python module.
declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES_GLOB eager_mode/*.py lazytensor/*.py
)

View File

@ -0,0 +1,230 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
`torch` dialect.
The expected use of this module is to create an instance of one of the classes below, and then calling the
`to_mlir` method to generate the MLIR representation of the type.
Information about what types are supported by each class can be found in docstrings of each of the classes.
In addition this module defines a function that take a torch.jit.ScriptFunction and converts it into an MLIR module.
The expected use for this module is to use the function
`build_module(jit_function: torch.jit.ScriptFunction annotation: Annotation) -> ir.Module`
to convert the TorchScript function into MLIR using the `torch` dialect.
"""
import abc
from typing import Any, Optional, Iterable
from typing import Union
import torch
from torch.jit import ScriptFunction
from torch_mlir import ir
from torch_mlir.dialects.builtin import FuncOp
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
class TorchMlirType(abc.ABC):
"""
A `TorchMlirType` is an object that produces MLIR
types in the `torch` dialect. The only requirement
for a class to be a subclass of `TorchMlirType` is
to define a `to_mlir(self, ir.Context) -> ir.Type`.
Each class is allowed to have different types of
__init__ methods depending on the information they
require to produce the given MLIR representation.
"""
@abc.abstractmethod
def to_mlir(self, context: ir.Context) -> ir.Type:
pass
class TorchTensorTypeError(Exception):
def __init__(self, value: str):
super().__init__()
self.value = value
def __str__(self) -> str:
return self.value
class TorchTensorType(TorchMlirType):
"""
This class is used to generate types of the form
!torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
where SHAPE is a list representing the shape of the tensor,
and DTYPE is an MLIR data type.
"""
def __init__(
self,
*,
shape: Optional[Iterable[Optional[int]]] = None,
dtype: Optional[torch.dtype] = None,
):
self.shape = shape
self.dtype = dtype
if dtype is None and shape is not None:
err = "If shape is specified, dtype must also be specified"
raise TorchTensorTypeError(err)
def __str__(self):
return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
def to_mlir(self, context: ir.Context) -> ir.Type:
if self.dtype is None:
return ir.Type.parse("!torch.tensor", context=context)
shape_asm = self._shape_to_mlir_asm()
dtype_asm = self._dtype_to_mlir_asm()
return ir.Type.parse(
f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
)
def _shape_to_mlir_asm(self) -> str:
if self.shape is None:
return "*"
str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
return f'[{",".join(str_sizes)}]'
def _dtype_to_mlir_asm(self) -> str:
if self.dtype in [torch.float64]:
return "f64"
if self.dtype in [torch.float, torch.float32]:
return "f32"
if self.dtype in [torch.int, torch.int32]:
return "si32"
if self.dtype in [torch.int64]:
return "si64"
if self.dtype in [torch.bool]:
return "i1"
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
class TorchNnModuleType(TorchMlirType):
"""This class is used to generate types for `!torch.nn.Module`s."""
def __init__(self, module_name: str):
self.module_name = module_name
def __str__(self):
return "torch.nn.Module"
def to_mlir(self, context: ir.Context) -> ir.Type:
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)
class PythonType(TorchMlirType):
"""
This class is used to convert regular Python types
into their corresponding `torch` dialect representation.
The list of supported types can be found in the dictionary
`_type_to_asm_dict`.
"""
_type_to_asm_dict = {
bool: "!torch.bool",
int: "!torch.int",
type(None): "!torch.none",
}
def __init__(self, type_: Any):
self.type_ = type_
def __str__(self):
return str(self.type_)
def to_mlir(self, context: ir.Context) -> ir.Type:
asm = self._type_to_asm_dict.get(self.type_)
if asm is None:
raise NotImplementedError(f"Unsupported type: {self.type_}")
return ir.Type.parse(asm, context=context)
# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
class Annotation:
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
self.types = list(
map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
)
def __str__(self):
result = f"Annotation instance with {len(self.types)} types\n"
for e, type_ in enumerate(self.types):
result += f" Type of argument {e + 1}: {str(type_)}\n"
return result
def __iter__(self):
return iter(self.types)
class AnnotationConverter:
@staticmethod
def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
dict_attrs = []
for type_ in annotation:
if not isinstance(type_, TorchTensorType):
dict_attrs.append(ir.DictAttr.get({}, context=context))
continue
ir_type = type_.to_mlir(context)
with context:
type_attr = ir.TypeAttr.get(ir_type)
dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
dict_attrs.append(dict_attr)
return ir.ArrayAttr.get(dict_attrs, context=context)
def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
with module.context:
name_attr = ir.StringAttr.get(name)
for op in module.body.operations:
if isinstance(op, FuncOp) and op.name == name_attr:
return op
return None
def build_module(jit_function: ScriptFunction, annotations) -> ir.Module:
"""Translate input function into an MLIR module in the `torch` dialect.
Parameters
----------
jit_function: ScriptFunction
Function in TorchScript IR to turn into MLIR.
annotation: Annotation
Annotation object representing the types of
the operands of `jit_function`.
Returns
-------
ir.Module
Translation of the input module into an MLIR module
"""
mb = ModuleBuilder()
mb.import_function(jit_function)
func_op = get_func_op_with_name(mb.module, jit_function.name)
assert (
func_op is not None
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
func_annotation = Annotation(annotations)
arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
func_op.attributes["arg_attrs"] = arg_attrs
return mb.module

View File

@ -0,0 +1,265 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Any, Callable, Tuple, Union
from typing import List, Dict
import numpy as np
import torch
import torch._C
from torch.fx.node import map_aggregate
from torch.fx.operator_schemas import normalize_function, create_type_hint
from torch.utils._pytree import tree_map
from torch_mlir._mlir_libs._mlir.passmanager import PassManager
from torch_mlir.dialects import torch as torch_dialect
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
from torch_mlir.eager_mode.ir_building import build_module, TorchTensorType
OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
SUPPORTED_OPS = frozenset(
[
member.OPERATION_NAME
for member in vars(torch_dialect).values()
if hasattr(member, "OPERATION_NAME")
]
)
class UnsupportedByTorchMlirEagerMode(Exception):
def __init__(self, value: str):
super().__init__()
self.value = value
def __str__(self) -> str:
return self.value
def check_supported_op(schema: torch._C.FunctionSchema) -> bool:
return (
"torch."
+ schema.name.replace("::", ".")
+ ("." + schema.overload_name if schema.overload_name else "")
) in SUPPORTED_OPS
def is_tensor_type(typ: torch._C.Type):
return typ.isSubtypeOf(torch.TensorType.get()) or (
isinstance(typ, torch.OptionalType)
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
)
def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
"""Fill in default values for optional args, which are dependent on the schema."""
arg_types = map_aggregate(args, type)
assert isinstance(arg_types, tuple)
arg_types = tuple([create_type_hint(i) for i in arg_types])
kwarg_types = {k: type(v) for k, v in kwargs.items()}
new_args_and_kwargs = normalize_function(
target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
)
assert new_args_and_kwargs, "Couldn't normalize args and kwargs"
new_args, new_kwargs = new_args_and_kwargs
return new_args, new_kwargs
def build_script_function(
schema: torch._C.FunctionSchema,
args: List[torch._C.Argument],
kwargs: Dict[str, Any],
) -> torch.jit.ScriptFunction:
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
Constants are inlined for the purposes of invalidating the compile cache when they change.
"""
# Creates empty TS graph.
graph = torch._C.Graph()
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
# Associate graph inputs/outputs with node inputs/outputs.
for i, arg in enumerate(schema.arguments):
# Find value corresponding to schema arg, either in positional or kw args.
kwarg = False
if arg.name in kwargs:
val = kwargs[arg.name]
kwarg = True
else:
val = args[i]
# If arg is a tensor, then add input to the graph corresponding to arg.
if is_tensor_type(arg.type) and val is not None:
inp = graph.addInput()
if isinstance(arg.type, torch.OptionalType):
inp.setType(arg.type.getElementType())
else:
inp.setType(arg.type)
if kwarg:
# Rename for debugging aid.
inp.setDebugName(arg.name)
# If arg is a constant, inline (at the top of the graph).
else:
inp = graph.insertConstant(val)
inp.node().moveBefore(node)
node.addInput(inp)
if node.hasMultipleOutputs():
for outp in node.outputs():
graph.registerOutput(outp)
else:
graph.registerOutput(node.output())
fn_name = str(node).strip()
fn = torch._C._create_function_from_graph(fn_name, graph)
return fn
def annotate_args_kwargs(
script_fun: torch._C.ScriptFunction,
normalized_args: List[Any],
normalized_kwargs: Dict[str, Any],
):
unwrapped_normalized_args = tree_map(
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
normalized_args,
)
unwrapped_normalized_kwargs = tree_map(
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
normalized_kwargs,
)
annotations = []
tensor_args = []
for i, arg in enumerate(unwrapped_normalized_args):
if isinstance(arg, np.ndarray):
# TODO: Remove once size zero dimensions are handled by torch-mlir.
shape = tuple(map(lambda x: x or -1, arg.shape))
annotations.append(
TorchTensorType(shape=shape, dtype=normalized_args[i].dtype)
)
tensor_args.append(arg)
# Pull out tensor kwargs and put them in positional order.
tensor_kwargs_flat = []
if unwrapped_normalized_kwargs:
tensor_kwargs = {}
arg_idxs = {
arg_name: i
for i, arg_name in enumerate(
[arg.name for arg in script_fun.schema.arguments]
)
}
for i, (kw, arg) in enumerate(unwrapped_normalized_kwargs.items()):
if isinstance(arg, np.ndarray):
tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype)
for i in range(len(tensor_kwargs)):
arg, arg_dtype = tensor_kwargs[i]
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype))
tensor_kwargs_flat.append(arg)
return annotations, tensor_args, tensor_kwargs_flat
def write_back_to_mutable(
registered_op: Dict,
out: Union[np.ndarray, List[np.ndarray]],
all_tensor_args: List[np.ndarray],
):
"""Write back to mutable args that aren't properly handled otherwise.
Because of how we pass values to the backend (by copying the tensor to a numpy array) we don't currently support
ops that mutate operands. That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
we can for these special cases. Hence, the solution is to manually write back to the same operand that the
conventional pytorch op variant would write to.
Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
mutate running_mean and running_var). We don't currently handle those.
"""
if len(registered_op["returns"]) > 1:
raise UnsupportedByTorchMlirEagerMode(
"TorchMLIR doesn't handle multiple aliased returns yet."
)
else:
aliased_arg = next(
arg
for arg in registered_op["arguments"]
if "alias_info" in arg and arg["alias_info"]["is_write"]
)
assert (
"alias_info" in registered_op["returns"][0]
and registered_op["returns"][0]["alias_info"]["is_write"]
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
and registered_op["returns"][0]["alias_info"]["after"][0]
)
assert (
len(aliased_arg["alias_info"]["after"]) == 1
and aliased_arg["alias_info"]["after"][0]
== registered_op["returns"][0]["alias_info"]["after"][0]
)
np.copyto(all_tensor_args[0], out)
return out
def try_torch_mlir_eager(op, args, kwargs, backend):
if hasattr(op, "op_name"):
op_name = op.op_name
elif hasattr(op, "__name__"):
# Handle builtin_function_or_method.
op_name = op.__name__
else:
raise RuntimeError(f"op {op} has no name")
if op_name == "detach":
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
raise UnsupportedByTorchMlirEagerMode("detaching")
if not hasattr(op, "_schema"):
raise RuntimeError(f"op {op} has no schema.")
new_args, new_kwargs = normalize_args_kwargs(op.overloadpacket, args, kwargs)
if "layout" in new_kwargs and new_kwargs["layout"] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{new_kwargs['layout']} layout not supported."
)
if "memory_format" in new_kwargs and new_kwargs["memory_format"] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{new_kwargs['memory_format']} memory format not supported."
)
script_fun = build_script_function(op._schema, new_args, new_kwargs)
annotations, np_tensor_args, np_tensor_kwargs_flat = annotate_args_kwargs(
script_fun, new_args, new_kwargs
)
eager_module = build_module(script_fun, annotations)
with eager_module.context:
pm = PassManager.parse(
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
)
pm.run(eager_module)
compiled_module = backend.compile(eager_module)
loaded_module = backend.load(compiled_module)
op_mlir_backend_callable = getattr(loaded_module, script_fun.name)
assert (
op_mlir_backend_callable is not None
), f"Couldn't find function {script_fun.name} in module."
all_tensor_args = np_tensor_args + np_tensor_kwargs_flat
out = op_mlir_backend_callable(*all_tensor_args)
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
if registered_op["is_mutable"]:
out = write_back_to_mutable(registered_op, out, all_tensor_args)
return out

View File

@ -0,0 +1,107 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import warnings
import torch
from torch.utils._pytree import tree_map
from torch_mlir.eager_mode.torch_mlir_dispatch import (
try_torch_mlir_eager,
UnsupportedByTorchMlirEagerMode,
)
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
class TorchMLIRTensor(torch.Tensor):
"""Wrap torch.Tensor in orer to dispatch through torch-mlir instead of aten.
This class uses the _make_wrapper_subclass pattern to override __torch_dispatch__
in order to dispatch through torch-mlir instead of aten. Here we basically only unwrap and wrap
torch.Tensors. Most of the heavy lifting is done in the adjacent torch_mlir_dispatch module.
More documentation on how this pattern works can be found in this RFC
https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
and this repo with many examples
https://github.com/albanD/subclass_zoo
"""
elem: torch.Tensor
__slots__ = ["elem"]
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=kwargs.get("requires_grad", False) or elem.requires_grad,
)
r.elem = elem.detach() if r.requires_grad else elem
return r
def __repr__(self):
if self.grad_fn:
return f"TorchMLIRTensor({self.elem}, grad_fn={self.grad_fn})"
else:
return f"TorchMLIRTensor({self.elem})"
@classmethod
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
requires_grad = False
def check_grad(e):
nonlocal requires_grad
if isinstance(e, TorchMLIRTensor):
requires_grad |= e.requires_grad
tree_map(check_grad, args)
tree_map(check_grad, kwargs)
def unwrap(e):
if isinstance(e, TorchMLIRTensor):
return e.elem
if isinstance(e, torch.nn.Parameter):
return e.detach()
return e
def wrap(e):
nonlocal requires_grad
return (
TorchMLIRTensor(e, requires_grad=requires_grad)
if isinstance(e, torch.Tensor)
else e
)
unwrapped_args = tree_map(unwrap, args)
unwrapped_kwargs = tree_map(unwrap, kwargs)
try:
out = try_torch_mlir_eager(
func,
unwrapped_args,
unwrapped_kwargs,
backend=refbackend.RefBackendLinalgOnTensorsBackend(),
)
if isinstance(out, tuple):
out = [torch.from_numpy(o) for o in out]
else:
out = torch.from_numpy(out)
return tree_map(wrap, out)
except Exception as e:
if isinstance(e, UnsupportedByTorchMlirEagerMode):
warnings.warn(
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
)
else:
warnings.warn(
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
)
return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))

View File

@ -7,3 +7,4 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
from .native_torch import NativeTorchTestConfig from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig from .torchscript import TorchScriptTestConfig
from .tosa_backend import TosaBackendTestConfig from .tosa_backend import TosaBackendTestConfig
from .eager_mode import EagerModeTestConfig

View File

@ -0,0 +1,45 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch.utils._pytree import tree_map
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
def wrap(e):
return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e
def unwrap(e):
return e.elem.clone() if isinstance(e, TorchMLIRTensor) else e
class EagerModeTestConfig(TestConfig):
"""Trivial test config that exercises eager mode plumbing"""
def __init__(self):
super().__init__()
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
return program
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
attr = artifact
for part in item.symbol.split('.'):
attr = getattr(attr, part)
inps = tree_map(wrap, item.inputs)
outps = attr(*inps)
output = tree_map(unwrap, outps)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
return result