diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 3baca1d23..a6e30bb9e 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -15,13 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY # Available test configs. from torch_mlir_e2e_test.torchscript.configs import ( - LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig + LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS +from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS, EAGER_MODE_XFAIL_SET # Import tests to register them in the global registry. # Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking @@ -57,7 +57,7 @@ from . import cast from . import index_put def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external', 'eager_mode'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -121,6 +121,9 @@ def main(): elif args.config == 'torchscript': config = TorchScriptTestConfig() xfail_set = {} + elif args.config == 'eager_mode': + config = EagerModeTestConfig() + xfail_set = EAGER_MODE_XFAIL_SET elif args.config == 'external': with open(args.external_config, 'r') as f: code = compile(f.read(), args.external_config, 'exec') diff --git a/e2e_testing/torchscript/rng.py b/e2e_testing/torchscript/rng.py index ca3133229..edaf43ce1 100644 --- a/e2e_testing/torchscript/rng.py +++ b/e2e_testing/torchscript/rng.py @@ -113,9 +113,9 @@ class BernoulliModule(torch.nn.Module): @register_test_case(module_factory=lambda: BernoulliModule()) def BernoulliModule_basic(module, tu: TestUtils): module.forward( - tu.rand(256, 512, 8).double(), - tu.rand(512, 1024, 4).double(), - tu.rand(512, 256, 4).double()) + tu.rand(512, 1024, 8).double(), + tu.rand(1024, 2048, 4).double(), + tu.rand(1024, 256, 4).double()) # ============================================================================== @@ -188,9 +188,9 @@ class BernoulliFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: BernoulliFloatModule()) def BernoulliFloatModule_basic(module, tu: TestUtils): module.forward( - tu.rand(256, 512, 8).double(), - tu.rand(512, 1024, 4).double(), - tu.rand(512, 256, 4).double()) + tu.rand(512, 1024, 8).double(), + tu.rand(1024, 2048, 4).double(), + tu.rand(1024, 512, 4).double()) # ============================================================================== @@ -228,9 +228,9 @@ class BernoulliTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: BernoulliTensorModule()) def BernoulliTensorModule_basic(module, tu: TestUtils): module.forward( - tu.rand(512, 512, 8).double(), - tu.rand(512, 512, 8).double(), - tu.rand(512, 1024, 4).double(), - tu.rand(512, 1024, 4).double(), - tu.rand(512, 256, 4).double(), - tu.rand(512, 256, 4).double()) + tu.rand(1024, 1024, 16).double(), + tu.rand(1024, 1024, 16).double(), + tu.rand(1024, 2048, 8).double(), + tu.rand(1024, 2048, 8).double(), + tu.rand(1024, 512, 8).double(), + tu.rand(1024, 512, 8).double()) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index bba6c737d..7220b8a4a 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -21,6 +21,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { } REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS +EAGER_MODE_XFAIL_SET = REFBACKEND_XFAIL_SET.union({ + # These fail because an upstream pytorch bug; more information at the following issue + # https://github.com/pytorch/pytorch/issues/74400 + "ElementwiseMulScalarModule_basic", + "ElementwiseSubScalarIntModule_basic", +}) + # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 5f25d7e61..168479653 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -57,6 +57,12 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) add_subdirectory(torch_mlir_e2e_test) endif() +################################################################################ +# Eager mode +################################################################################ + +add_subdirectory(torch_mlir/eager_mode) + ################################################################################ # Generate packages and shared library # Downstreams typically will not use these, but they are useful for local diff --git a/python/torch_mlir/eager_mode/CMakeLists.txt b/python/torch_mlir/eager_mode/CMakeLists.txt new file mode 100644 index 000000000..84e864e39 --- /dev/null +++ b/python/torch_mlir/eager_mode/CMakeLists.txt @@ -0,0 +1,20 @@ +#------------------------------------------------------------------------------- +# Setup PyTorch +#------------------------------------------------------------------------------- + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") +find_package(Torch 1.8 REQUIRED) + +TorchMLIRConfigurePyTorch() + +#------------------------------------------------------------------------------- +# Subdirectories +#------------------------------------------------------------------------------- + +## Declare the sources of the Python module. + +declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES_GLOB eager_mode/*.py lazytensor/*.py +) diff --git a/python/torch_mlir/eager_mode/__init__.py b/python/torch_mlir/eager_mode/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/torch_mlir/eager_mode/ir_building.py b/python/torch_mlir/eager_mode/ir_building.py new file mode 100644 index 000000000..35afd2929 --- /dev/null +++ b/python/torch_mlir/eager_mode/ir_building.py @@ -0,0 +1,230 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +""" +Translator from torch.jit.ScriptFunction to MLIR. + + +The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the +`torch` dialect. + +The expected use of this module is to create an instance of one of the classes below, and then calling the +`to_mlir` method to generate the MLIR representation of the type. + +Information about what types are supported by each class can be found in docstrings of each of the classes. + +In addition this module defines a function that take a torch.jit.ScriptFunction and converts it into an MLIR module. + +The expected use for this module is to use the function +`build_module(jit_function: torch.jit.ScriptFunction annotation: Annotation) -> ir.Module` +to convert the TorchScript function into MLIR using the `torch` dialect. +""" + +import abc +from typing import Any, Optional, Iterable +from typing import Union + +import torch +from torch.jit import ScriptFunction + +from torch_mlir import ir +from torch_mlir.dialects.builtin import FuncOp +from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder + + +class TorchMlirType(abc.ABC): + """ + A `TorchMlirType` is an object that produces MLIR + types in the `torch` dialect. The only requirement + for a class to be a subclass of `TorchMlirType` is + to define a `to_mlir(self, ir.Context) -> ir.Type`. + Each class is allowed to have different types of + __init__ methods depending on the information they + require to produce the given MLIR representation. + """ + + @abc.abstractmethod + def to_mlir(self, context: ir.Context) -> ir.Type: + pass + + +class TorchTensorTypeError(Exception): + def __init__(self, value: str): + super().__init__() + self.value = value + + def __str__(self) -> str: + return self.value + + +class TorchTensorType(TorchMlirType): + """ + This class is used to generate types of the form + !torch.tensor and !torch.vtensor, + where SHAPE is a list representing the shape of the tensor, + and DTYPE is an MLIR data type. + """ + + def __init__( + self, + *, + shape: Optional[Iterable[Optional[int]]] = None, + dtype: Optional[torch.dtype] = None, + ): + self.shape = shape + self.dtype = dtype + + if dtype is None and shape is not None: + err = "If shape is specified, dtype must also be specified" + raise TorchTensorTypeError(err) + + def __str__(self): + return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})" + + def to_mlir(self, context: ir.Context) -> ir.Type: + if self.dtype is None: + return ir.Type.parse("!torch.tensor", context=context) + + shape_asm = self._shape_to_mlir_asm() + dtype_asm = self._dtype_to_mlir_asm() + return ir.Type.parse( + f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context + ) + + def _shape_to_mlir_asm(self) -> str: + if self.shape is None: + return "*" + + str_sizes = map(lambda x: "?" if x is None else str(x), self.shape) + return f'[{",".join(str_sizes)}]' + + def _dtype_to_mlir_asm(self) -> str: + if self.dtype in [torch.float64]: + return "f64" + if self.dtype in [torch.float, torch.float32]: + return "f32" + if self.dtype in [torch.int, torch.int32]: + return "si32" + if self.dtype in [torch.int64]: + return "si64" + if self.dtype in [torch.bool]: + return "i1" + + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + + +class TorchNnModuleType(TorchMlirType): + """This class is used to generate types for `!torch.nn.Module`s.""" + + def __init__(self, module_name: str): + self.module_name = module_name + + def __str__(self): + return "torch.nn.Module" + + def to_mlir(self, context: ir.Context) -> ir.Type: + return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context) + + +class PythonType(TorchMlirType): + """ + This class is used to convert regular Python types + into their corresponding `torch` dialect representation. + The list of supported types can be found in the dictionary + `_type_to_asm_dict`. + """ + + _type_to_asm_dict = { + bool: "!torch.bool", + int: "!torch.int", + type(None): "!torch.none", + } + + def __init__(self, type_: Any): + self.type_ = type_ + + def __str__(self): + return str(self.type_) + + def to_mlir(self, context: ir.Context) -> ir.Type: + asm = self._type_to_asm_dict.get(self.type_) + if asm is None: + raise NotImplementedError(f"Unsupported type: {self.type_}") + return ir.Type.parse(asm, context=context) + + +# TODO: This functionality should be incorporated into ModuleBuilder.import_function. +class Annotation: + def __init__(self, types: Iterable[Union[TorchTensorType, type]]): + self.types = list( + map(lambda t: PythonType(t) if isinstance(t, type) else t, types) + ) + + def __str__(self): + result = f"Annotation instance with {len(self.types)} types\n" + for e, type_ in enumerate(self.types): + result += f" Type of argument {e + 1}: {str(type_)}\n" + return result + + def __iter__(self): + return iter(self.types) + + +class AnnotationConverter: + @staticmethod + def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr: + dict_attrs = [] + for type_ in annotation: + if not isinstance(type_, TorchTensorType): + dict_attrs.append(ir.DictAttr.get({}, context=context)) + continue + + ir_type = type_.to_mlir(context) + with context: + type_attr = ir.TypeAttr.get(ir_type) + dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr}) + dict_attrs.append(dict_attr) + + return ir.ArrayAttr.get(dict_attrs, context=context) + + +def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]: + with module.context: + name_attr = ir.StringAttr.get(name) + for op in module.body.operations: + if isinstance(op, FuncOp) and op.name == name_attr: + return op + + return None + + +def build_module(jit_function: ScriptFunction, annotations) -> ir.Module: + """Translate input function into an MLIR module in the `torch` dialect. + + Parameters + ---------- + jit_function: ScriptFunction + Function in TorchScript IR to turn into MLIR. + annotation: Annotation + Annotation object representing the types of + the operands of `jit_function`. + + Returns + ------- + ir.Module + Translation of the input module into an MLIR module + """ + mb = ModuleBuilder() + mb.import_function(jit_function) + + func_op = get_func_op_with_name(mb.module, jit_function.name) + assert ( + func_op is not None + ), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder" + + func_annotation = Annotation(annotations) + arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context) + func_op.attributes["arg_attrs"] = arg_attrs + + return mb.module diff --git a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py new file mode 100644 index 000000000..a6e020520 --- /dev/null +++ b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py @@ -0,0 +1,265 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from typing import Any, Callable, Tuple, Union +from typing import List, Dict + +import numpy as np +import torch +import torch._C +from torch.fx.node import map_aggregate +from torch.fx.operator_schemas import normalize_function, create_type_hint +from torch.utils._pytree import tree_map +from torch_mlir._mlir_libs._mlir.passmanager import PassManager + +from torch_mlir.dialects import torch as torch_dialect +from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error +from torch_mlir.eager_mode.ir_building import build_module, TorchTensorType + +OP_REGISTRY = {op["name"]: op for op in get_registered_ops()} +SUPPORTED_OPS = frozenset( + [ + member.OPERATION_NAME + for member in vars(torch_dialect).values() + if hasattr(member, "OPERATION_NAME") + ] +) + + +class UnsupportedByTorchMlirEagerMode(Exception): + def __init__(self, value: str): + super().__init__() + self.value = value + + def __str__(self) -> str: + return self.value + + +def check_supported_op(schema: torch._C.FunctionSchema) -> bool: + return ( + "torch." + + schema.name.replace("::", ".") + + ("." + schema.overload_name if schema.overload_name else "") + ) in SUPPORTED_OPS + + +def is_tensor_type(typ: torch._C.Type): + return typ.isSubtypeOf(torch.TensorType.get()) or ( + isinstance(typ, torch.OptionalType) + and typ.getElementType().isSubtypeOf(torch._C.TensorType.get()) + ) + + +def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]): + """Fill in default values for optional args, which are dependent on the schema.""" + + arg_types = map_aggregate(args, type) + assert isinstance(arg_types, tuple) + arg_types = tuple([create_type_hint(i) for i in arg_types]) + kwarg_types = {k: type(v) for k, v in kwargs.items()} + + new_args_and_kwargs = normalize_function( + target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False + ) + assert new_args_and_kwargs, "Couldn't normalize args and kwargs" + new_args, new_kwargs = new_args_and_kwargs + return new_args, new_kwargs + + +def build_script_function( + schema: torch._C.FunctionSchema, + args: List[torch._C.Argument], + kwargs: Dict[str, Any], +) -> torch.jit.ScriptFunction: + """Build a torch.jit.ScriptFunction that corresponds to the schema. + + Constants are inlined for the purposes of invalidating the compile cache when they change. + """ + + # Creates empty TS graph. + graph = torch._C.Graph() + # Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point. + node = graph.insertNode(graph.create(schema.name, len(schema.returns))) + # Associate graph inputs/outputs with node inputs/outputs. + for i, arg in enumerate(schema.arguments): + # Find value corresponding to schema arg, either in positional or kw args. + kwarg = False + if arg.name in kwargs: + val = kwargs[arg.name] + kwarg = True + else: + val = args[i] + + # If arg is a tensor, then add input to the graph corresponding to arg. + if is_tensor_type(arg.type) and val is not None: + inp = graph.addInput() + if isinstance(arg.type, torch.OptionalType): + inp.setType(arg.type.getElementType()) + else: + inp.setType(arg.type) + + if kwarg: + # Rename for debugging aid. + inp.setDebugName(arg.name) + # If arg is a constant, inline (at the top of the graph). + else: + inp = graph.insertConstant(val) + inp.node().moveBefore(node) + + node.addInput(inp) + + if node.hasMultipleOutputs(): + for outp in node.outputs(): + graph.registerOutput(outp) + else: + graph.registerOutput(node.output()) + + fn_name = str(node).strip() + fn = torch._C._create_function_from_graph(fn_name, graph) + return fn + + +def annotate_args_kwargs( + script_fun: torch._C.ScriptFunction, + normalized_args: List[Any], + normalized_kwargs: Dict[str, Any], +): + unwrapped_normalized_args = tree_map( + lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x, + normalized_args, + ) + unwrapped_normalized_kwargs = tree_map( + lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x, + normalized_kwargs, + ) + + annotations = [] + tensor_args = [] + for i, arg in enumerate(unwrapped_normalized_args): + if isinstance(arg, np.ndarray): + # TODO: Remove once size zero dimensions are handled by torch-mlir. + shape = tuple(map(lambda x: x or -1, arg.shape)) + annotations.append( + TorchTensorType(shape=shape, dtype=normalized_args[i].dtype) + ) + tensor_args.append(arg) + + # Pull out tensor kwargs and put them in positional order. + tensor_kwargs_flat = [] + if unwrapped_normalized_kwargs: + tensor_kwargs = {} + arg_idxs = { + arg_name: i + for i, arg_name in enumerate( + [arg.name for arg in script_fun.schema.arguments] + ) + } + for i, (kw, arg) in enumerate(unwrapped_normalized_kwargs.items()): + if isinstance(arg, np.ndarray): + tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype) + + for i in range(len(tensor_kwargs)): + arg, arg_dtype = tensor_kwargs[i] + annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype)) + tensor_kwargs_flat.append(arg) + + return annotations, tensor_args, tensor_kwargs_flat + + +def write_back_to_mutable( + registered_op: Dict, + out: Union[np.ndarray, List[np.ndarray]], + all_tensor_args: List[np.ndarray], +): + """Write back to mutable args that aren't properly handled otherwise. + + Because of how we pass values to the backend (by copying the tensor to a numpy array) we don't currently support + ops that mutate operands. That includes both inplace variants and outplace variants. Additionally, Torch-MLIR, + as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which + we can for these special cases. Hence, the solution is to manually write back to the same operand that the + conventional pytorch op variant would write to. + + Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that + mutate running_mean and running_var). We don't currently handle those. + """ + if len(registered_op["returns"]) > 1: + raise UnsupportedByTorchMlirEagerMode( + "TorchMLIR doesn't handle multiple aliased returns yet." + ) + else: + aliased_arg = next( + arg + for arg in registered_op["arguments"] + if "alias_info" in arg and arg["alias_info"]["is_write"] + ) + assert ( + "alias_info" in registered_op["returns"][0] + and registered_op["returns"][0]["alias_info"]["is_write"] + and len(registered_op["returns"][0]["alias_info"]["after"]) == 1 + and registered_op["returns"][0]["alias_info"]["after"][0] + ) + assert ( + len(aliased_arg["alias_info"]["after"]) == 1 + and aliased_arg["alias_info"]["after"][0] + == registered_op["returns"][0]["alias_info"]["after"][0] + ) + np.copyto(all_tensor_args[0], out) + + return out + + +def try_torch_mlir_eager(op, args, kwargs, backend): + if hasattr(op, "op_name"): + op_name = op.op_name + elif hasattr(op, "__name__"): + # Handle builtin_function_or_method. + op_name = op.__name__ + else: + raise RuntimeError(f"op {op} has no name") + + if op_name == "detach": + # We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch. + raise UnsupportedByTorchMlirEagerMode("detaching") + + if not hasattr(op, "_schema"): + raise RuntimeError(f"op {op} has no schema.") + + new_args, new_kwargs = normalize_args_kwargs(op.overloadpacket, args, kwargs) + + if "layout" in new_kwargs and new_kwargs["layout"] not in {0, None}: + raise UnsupportedByTorchMlirEagerMode( + f"{new_kwargs['layout']} layout not supported." + ) + if "memory_format" in new_kwargs and new_kwargs["memory_format"] not in {0, None}: + raise UnsupportedByTorchMlirEagerMode( + f"{new_kwargs['memory_format']} memory format not supported." + ) + + script_fun = build_script_function(op._schema, new_args, new_kwargs) + annotations, np_tensor_args, np_tensor_kwargs_flat = annotate_args_kwargs( + script_fun, new_args, new_kwargs + ) + + eager_module = build_module(script_fun, annotations) + with eager_module.context: + pm = PassManager.parse( + "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline" + ) + pm.run(eager_module) + compiled_module = backend.compile(eager_module) + loaded_module = backend.load(compiled_module) + op_mlir_backend_callable = getattr(loaded_module, script_fun.name) + assert ( + op_mlir_backend_callable is not None + ), f"Couldn't find function {script_fun.name} in module." + + all_tensor_args = np_tensor_args + np_tensor_kwargs_flat + out = op_mlir_backend_callable(*all_tensor_args) + + registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)] + if registered_op["is_mutable"]: + out = write_back_to_mutable(registered_op, out, all_tensor_args) + + return out diff --git a/python/torch_mlir/eager_mode/torch_mlir_tensor.py b/python/torch_mlir/eager_mode/torch_mlir_tensor.py new file mode 100644 index 000000000..b05e61e0c --- /dev/null +++ b/python/torch_mlir/eager_mode/torch_mlir_tensor.py @@ -0,0 +1,107 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +import warnings + +import torch +from torch.utils._pytree import tree_map + +from torch_mlir.eager_mode.torch_mlir_dispatch import ( + try_torch_mlir_eager, + UnsupportedByTorchMlirEagerMode, +) +from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend + + +class TorchMLIRTensor(torch.Tensor): + """Wrap torch.Tensor in orer to dispatch through torch-mlir instead of aten. + + This class uses the _make_wrapper_subclass pattern to override __torch_dispatch__ + in order to dispatch through torch-mlir instead of aten. Here we basically only unwrap and wrap + torch.Tensors. Most of the heavy lifting is done in the adjacent torch_mlir_dispatch module. + + More documentation on how this pattern works can be found in this RFC + https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call + and this repo with many examples + https://github.com/albanD/subclass_zoo + """ + + elem: torch.Tensor + + __slots__ = ["elem"] + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=kwargs.get("requires_grad", False) or elem.requires_grad, + ) + r.elem = elem.detach() if r.requires_grad else elem + return r + + def __repr__(self): + if self.grad_fn: + return f"TorchMLIRTensor({self.elem}, grad_fn={self.grad_fn})" + else: + return f"TorchMLIRTensor({self.elem})" + + @classmethod + def __torch_dispatch__(cls, func, _types, args=(), kwargs=None): + requires_grad = False + + def check_grad(e): + nonlocal requires_grad + if isinstance(e, TorchMLIRTensor): + requires_grad |= e.requires_grad + + tree_map(check_grad, args) + tree_map(check_grad, kwargs) + + def unwrap(e): + if isinstance(e, TorchMLIRTensor): + return e.elem + if isinstance(e, torch.nn.Parameter): + return e.detach() + return e + + def wrap(e): + nonlocal requires_grad + return ( + TorchMLIRTensor(e, requires_grad=requires_grad) + if isinstance(e, torch.Tensor) + else e + ) + + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + + try: + out = try_torch_mlir_eager( + func, + unwrapped_args, + unwrapped_kwargs, + backend=refbackend.RefBackendLinalgOnTensorsBackend(), + ) + if isinstance(out, tuple): + out = [torch.from_numpy(o) for o in out] + else: + out = torch.from_numpy(out) + return tree_map(wrap, out) + except Exception as e: + if isinstance(e, UnsupportedByTorchMlirEagerMode): + warnings.warn( + f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager." + ) + else: + warnings.warn( + f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; " + f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues" + ) + return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs)) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py index b201dc923..14c2f48c3 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py @@ -7,3 +7,4 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig from .tosa_backend import TosaBackendTestConfig +from .eager_mode import EagerModeTestConfig diff --git a/python/torch_mlir_e2e_test/torchscript/configs/eager_mode.py b/python/torch_mlir_e2e_test/torchscript/configs/eager_mode.py new file mode 100644 index 000000000..ffabe90f2 --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/eager_mode.py @@ -0,0 +1,45 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch +from torch.utils._pytree import tree_map + +from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem + + +def wrap(e): + return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e + + +def unwrap(e): + return e.elem.clone() if isinstance(e, TorchMLIRTensor) else e + + +class EagerModeTestConfig(TestConfig): + """Trivial test config that exercises eager mode plumbing""" + + def __init__(self): + super().__init__() + + def compile(self, program: torch.nn.Module) -> torch.nn.Module: + return program + + def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: + result: Trace = [] + for item in trace: + attr = artifact + for part in item.symbol.split('.'): + attr = getattr(attr, part) + + inps = tree_map(wrap, item.inputs) + outps = attr(*inps) + output = tree_map(unwrap, outps) + + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result