mirror of https://github.com/llvm/torch-mlir
This PR implements an eager mode backend for PyTorch through the torch-mlir framework. This is accomplished by overriding the `__torch_dispatch__` class method on wrapper subclass `TorchMLIRTensor(torch.Tensor)`.
Effectively, this mode works by compiling op by op as the NN is eagerly executed by PyTorch. Entailed in that compilation is building a representation of the op that can be `torch.jit.script`ed, importing using `ModuleBuilder`, and then executing (e.g., with `RefBackendLinalgOnTensorsBackend`). This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported op). Currently, all e2e tests pass execpt for two that involve an upstream PyTorch bug (https://github.com/pytorch/pytorch/issues/74400). High priority next steps: 1. A compile cache in order to speed up reruns of the same NN. 2. Integration with IREE (though not in this repo). 3. Integration with `torch.distributed`.pull/687/head snapshot-20220322.340
parent
f9d34596e8
commit
fe8ac57e6d
|
@ -15,13 +15,13 @@ from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
# Available test configs.
|
# Available test configs.
|
||||||
from torch_mlir_e2e_test.torchscript.configs import (
|
from torch_mlir_e2e_test.torchscript.configs import (
|
||||||
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig
|
LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||||
|
|
||||||
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, COMMON_TORCH_MLIR_LOWERING_XFAILS, EAGER_MODE_XFAIL_SET
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||||
|
@ -57,7 +57,7 @@ from . import cast
|
||||||
from . import index_put
|
from . import index_put
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external', 'eager_mode']
|
||||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
parser.add_argument('-c', '--config',
|
parser.add_argument('-c', '--config',
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
|
@ -121,6 +121,9 @@ def main():
|
||||||
elif args.config == 'torchscript':
|
elif args.config == 'torchscript':
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
xfail_set = {}
|
xfail_set = {}
|
||||||
|
elif args.config == 'eager_mode':
|
||||||
|
config = EagerModeTestConfig()
|
||||||
|
xfail_set = EAGER_MODE_XFAIL_SET
|
||||||
elif args.config == 'external':
|
elif args.config == 'external':
|
||||||
with open(args.external_config, 'r') as f:
|
with open(args.external_config, 'r') as f:
|
||||||
code = compile(f.read(), args.external_config, 'exec')
|
code = compile(f.read(), args.external_config, 'exec')
|
||||||
|
|
|
@ -113,9 +113,9 @@ class BernoulliModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: BernoulliModule())
|
@register_test_case(module_factory=lambda: BernoulliModule())
|
||||||
def BernoulliModule_basic(module, tu: TestUtils):
|
def BernoulliModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(256, 512, 8).double(),
|
tu.rand(512, 1024, 8).double(),
|
||||||
tu.rand(512, 1024, 4).double(),
|
tu.rand(1024, 2048, 4).double(),
|
||||||
tu.rand(512, 256, 4).double())
|
tu.rand(1024, 256, 4).double())
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
@ -188,9 +188,9 @@ class BernoulliFloatModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
||||||
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(256, 512, 8).double(),
|
tu.rand(512, 1024, 8).double(),
|
||||||
tu.rand(512, 1024, 4).double(),
|
tu.rand(1024, 2048, 4).double(),
|
||||||
tu.rand(512, 256, 4).double())
|
tu.rand(1024, 512, 4).double())
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
@ -228,9 +228,9 @@ class BernoulliTensorModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(512, 512, 8).double(),
|
tu.rand(1024, 1024, 16).double(),
|
||||||
tu.rand(512, 512, 8).double(),
|
tu.rand(1024, 1024, 16).double(),
|
||||||
tu.rand(512, 1024, 4).double(),
|
tu.rand(1024, 2048, 8).double(),
|
||||||
tu.rand(512, 1024, 4).double(),
|
tu.rand(1024, 2048, 8).double(),
|
||||||
tu.rand(512, 256, 4).double(),
|
tu.rand(1024, 512, 8).double(),
|
||||||
tu.rand(512, 256, 4).double())
|
tu.rand(1024, 512, 8).double())
|
||||||
|
|
|
@ -21,6 +21,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||||
}
|
}
|
||||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||||
|
|
||||||
|
EAGER_MODE_XFAIL_SET = REFBACKEND_XFAIL_SET.union({
|
||||||
|
# These fail because an upstream pytorch bug; more information at the following issue
|
||||||
|
# https://github.com/pytorch/pytorch/issues/74400
|
||||||
|
"ElementwiseMulScalarModule_basic",
|
||||||
|
"ElementwiseSubScalarIntModule_basic",
|
||||||
|
})
|
||||||
|
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
|
|
@ -57,6 +57,12 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||||
add_subdirectory(torch_mlir_e2e_test)
|
add_subdirectory(torch_mlir_e2e_test)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Eager mode
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
add_subdirectory(torch_mlir/eager_mode)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Generate packages and shared library
|
# Generate packages and shared library
|
||||||
# Downstreams typically will not use these, but they are useful for local
|
# Downstreams typically will not use these, but they are useful for local
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
# Setup PyTorch
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
|
||||||
|
find_package(Torch 1.8 REQUIRED)
|
||||||
|
|
||||||
|
TorchMLIRConfigurePyTorch()
|
||||||
|
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
# Subdirectories
|
||||||
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
## Declare the sources of the Python module.
|
||||||
|
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
|
||||||
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonSources
|
||||||
|
SOURCES_GLOB eager_mode/*.py lazytensor/*.py
|
||||||
|
)
|
|
@ -0,0 +1,230 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
"""
|
||||||
|
Translator from torch.jit.ScriptFunction to MLIR.
|
||||||
|
|
||||||
|
|
||||||
|
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
|
||||||
|
`torch` dialect.
|
||||||
|
|
||||||
|
The expected use of this module is to create an instance of one of the classes below, and then calling the
|
||||||
|
`to_mlir` method to generate the MLIR representation of the type.
|
||||||
|
|
||||||
|
Information about what types are supported by each class can be found in docstrings of each of the classes.
|
||||||
|
|
||||||
|
In addition this module defines a function that take a torch.jit.ScriptFunction and converts it into an MLIR module.
|
||||||
|
|
||||||
|
The expected use for this module is to use the function
|
||||||
|
`build_module(jit_function: torch.jit.ScriptFunction annotation: Annotation) -> ir.Module`
|
||||||
|
to convert the TorchScript function into MLIR using the `torch` dialect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import Any, Optional, Iterable
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.jit import ScriptFunction
|
||||||
|
|
||||||
|
from torch_mlir import ir
|
||||||
|
from torch_mlir.dialects.builtin import FuncOp
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class TorchMlirType(abc.ABC):
|
||||||
|
"""
|
||||||
|
A `TorchMlirType` is an object that produces MLIR
|
||||||
|
types in the `torch` dialect. The only requirement
|
||||||
|
for a class to be a subclass of `TorchMlirType` is
|
||||||
|
to define a `to_mlir(self, ir.Context) -> ir.Type`.
|
||||||
|
Each class is allowed to have different types of
|
||||||
|
__init__ methods depending on the information they
|
||||||
|
require to produce the given MLIR representation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TorchTensorTypeError(Exception):
|
||||||
|
def __init__(self, value: str):
|
||||||
|
super().__init__()
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class TorchTensorType(TorchMlirType):
|
||||||
|
"""
|
||||||
|
This class is used to generate types of the form
|
||||||
|
!torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
|
||||||
|
where SHAPE is a list representing the shape of the tensor,
|
||||||
|
and DTYPE is an MLIR data type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
shape: Optional[Iterable[Optional[int]]] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
if dtype is None and shape is not None:
|
||||||
|
err = "If shape is specified, dtype must also be specified"
|
||||||
|
raise TorchTensorTypeError(err)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
|
||||||
|
|
||||||
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
|
if self.dtype is None:
|
||||||
|
return ir.Type.parse("!torch.tensor", context=context)
|
||||||
|
|
||||||
|
shape_asm = self._shape_to_mlir_asm()
|
||||||
|
dtype_asm = self._dtype_to_mlir_asm()
|
||||||
|
return ir.Type.parse(
|
||||||
|
f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
|
||||||
|
)
|
||||||
|
|
||||||
|
def _shape_to_mlir_asm(self) -> str:
|
||||||
|
if self.shape is None:
|
||||||
|
return "*"
|
||||||
|
|
||||||
|
str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
|
||||||
|
return f'[{",".join(str_sizes)}]'
|
||||||
|
|
||||||
|
def _dtype_to_mlir_asm(self) -> str:
|
||||||
|
if self.dtype in [torch.float64]:
|
||||||
|
return "f64"
|
||||||
|
if self.dtype in [torch.float, torch.float32]:
|
||||||
|
return "f32"
|
||||||
|
if self.dtype in [torch.int, torch.int32]:
|
||||||
|
return "si32"
|
||||||
|
if self.dtype in [torch.int64]:
|
||||||
|
return "si64"
|
||||||
|
if self.dtype in [torch.bool]:
|
||||||
|
return "i1"
|
||||||
|
|
||||||
|
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
class TorchNnModuleType(TorchMlirType):
|
||||||
|
"""This class is used to generate types for `!torch.nn.Module`s."""
|
||||||
|
|
||||||
|
def __init__(self, module_name: str):
|
||||||
|
self.module_name = module_name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "torch.nn.Module"
|
||||||
|
|
||||||
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
|
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonType(TorchMlirType):
|
||||||
|
"""
|
||||||
|
This class is used to convert regular Python types
|
||||||
|
into their corresponding `torch` dialect representation.
|
||||||
|
The list of supported types can be found in the dictionary
|
||||||
|
`_type_to_asm_dict`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_type_to_asm_dict = {
|
||||||
|
bool: "!torch.bool",
|
||||||
|
int: "!torch.int",
|
||||||
|
type(None): "!torch.none",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, type_: Any):
|
||||||
|
self.type_ = type_
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.type_)
|
||||||
|
|
||||||
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
|
asm = self._type_to_asm_dict.get(self.type_)
|
||||||
|
if asm is None:
|
||||||
|
raise NotImplementedError(f"Unsupported type: {self.type_}")
|
||||||
|
return ir.Type.parse(asm, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
|
||||||
|
class Annotation:
|
||||||
|
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
|
||||||
|
self.types = list(
|
||||||
|
map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
result = f"Annotation instance with {len(self.types)} types\n"
|
||||||
|
for e, type_ in enumerate(self.types):
|
||||||
|
result += f" Type of argument {e + 1}: {str(type_)}\n"
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.types)
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationConverter:
|
||||||
|
@staticmethod
|
||||||
|
def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
|
||||||
|
dict_attrs = []
|
||||||
|
for type_ in annotation:
|
||||||
|
if not isinstance(type_, TorchTensorType):
|
||||||
|
dict_attrs.append(ir.DictAttr.get({}, context=context))
|
||||||
|
continue
|
||||||
|
|
||||||
|
ir_type = type_.to_mlir(context)
|
||||||
|
with context:
|
||||||
|
type_attr = ir.TypeAttr.get(ir_type)
|
||||||
|
dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
|
||||||
|
dict_attrs.append(dict_attr)
|
||||||
|
|
||||||
|
return ir.ArrayAttr.get(dict_attrs, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
|
||||||
|
with module.context:
|
||||||
|
name_attr = ir.StringAttr.get(name)
|
||||||
|
for op in module.body.operations:
|
||||||
|
if isinstance(op, FuncOp) and op.name == name_attr:
|
||||||
|
return op
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def build_module(jit_function: ScriptFunction, annotations) -> ir.Module:
|
||||||
|
"""Translate input function into an MLIR module in the `torch` dialect.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
jit_function: ScriptFunction
|
||||||
|
Function in TorchScript IR to turn into MLIR.
|
||||||
|
annotation: Annotation
|
||||||
|
Annotation object representing the types of
|
||||||
|
the operands of `jit_function`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ir.Module
|
||||||
|
Translation of the input module into an MLIR module
|
||||||
|
"""
|
||||||
|
mb = ModuleBuilder()
|
||||||
|
mb.import_function(jit_function)
|
||||||
|
|
||||||
|
func_op = get_func_op_with_name(mb.module, jit_function.name)
|
||||||
|
assert (
|
||||||
|
func_op is not None
|
||||||
|
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
|
||||||
|
|
||||||
|
func_annotation = Annotation(annotations)
|
||||||
|
arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
|
||||||
|
func_op.attributes["arg_attrs"] = arg_attrs
|
||||||
|
|
||||||
|
return mb.module
|
|
@ -0,0 +1,265 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
from typing import Any, Callable, Tuple, Union
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch._C
|
||||||
|
from torch.fx.node import map_aggregate
|
||||||
|
from torch.fx.operator_schemas import normalize_function, create_type_hint
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
from torch_mlir._mlir_libs._mlir.passmanager import PassManager
|
||||||
|
|
||||||
|
from torch_mlir.dialects import torch as torch_dialect
|
||||||
|
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
|
||||||
|
from torch_mlir.eager_mode.ir_building import build_module, TorchTensorType
|
||||||
|
|
||||||
|
OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
|
||||||
|
SUPPORTED_OPS = frozenset(
|
||||||
|
[
|
||||||
|
member.OPERATION_NAME
|
||||||
|
for member in vars(torch_dialect).values()
|
||||||
|
if hasattr(member, "OPERATION_NAME")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedByTorchMlirEagerMode(Exception):
|
||||||
|
def __init__(self, value: str):
|
||||||
|
super().__init__()
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
def check_supported_op(schema: torch._C.FunctionSchema) -> bool:
|
||||||
|
return (
|
||||||
|
"torch."
|
||||||
|
+ schema.name.replace("::", ".")
|
||||||
|
+ ("." + schema.overload_name if schema.overload_name else "")
|
||||||
|
) in SUPPORTED_OPS
|
||||||
|
|
||||||
|
|
||||||
|
def is_tensor_type(typ: torch._C.Type):
|
||||||
|
return typ.isSubtypeOf(torch.TensorType.get()) or (
|
||||||
|
isinstance(typ, torch.OptionalType)
|
||||||
|
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
|
||||||
|
"""Fill in default values for optional args, which are dependent on the schema."""
|
||||||
|
|
||||||
|
arg_types = map_aggregate(args, type)
|
||||||
|
assert isinstance(arg_types, tuple)
|
||||||
|
arg_types = tuple([create_type_hint(i) for i in arg_types])
|
||||||
|
kwarg_types = {k: type(v) for k, v in kwargs.items()}
|
||||||
|
|
||||||
|
new_args_and_kwargs = normalize_function(
|
||||||
|
target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
|
||||||
|
)
|
||||||
|
assert new_args_and_kwargs, "Couldn't normalize args and kwargs"
|
||||||
|
new_args, new_kwargs = new_args_and_kwargs
|
||||||
|
return new_args, new_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def build_script_function(
|
||||||
|
schema: torch._C.FunctionSchema,
|
||||||
|
args: List[torch._C.Argument],
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> torch.jit.ScriptFunction:
|
||||||
|
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
|
||||||
|
|
||||||
|
Constants are inlined for the purposes of invalidating the compile cache when they change.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Creates empty TS graph.
|
||||||
|
graph = torch._C.Graph()
|
||||||
|
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
|
||||||
|
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
|
||||||
|
# Associate graph inputs/outputs with node inputs/outputs.
|
||||||
|
for i, arg in enumerate(schema.arguments):
|
||||||
|
# Find value corresponding to schema arg, either in positional or kw args.
|
||||||
|
kwarg = False
|
||||||
|
if arg.name in kwargs:
|
||||||
|
val = kwargs[arg.name]
|
||||||
|
kwarg = True
|
||||||
|
else:
|
||||||
|
val = args[i]
|
||||||
|
|
||||||
|
# If arg is a tensor, then add input to the graph corresponding to arg.
|
||||||
|
if is_tensor_type(arg.type) and val is not None:
|
||||||
|
inp = graph.addInput()
|
||||||
|
if isinstance(arg.type, torch.OptionalType):
|
||||||
|
inp.setType(arg.type.getElementType())
|
||||||
|
else:
|
||||||
|
inp.setType(arg.type)
|
||||||
|
|
||||||
|
if kwarg:
|
||||||
|
# Rename for debugging aid.
|
||||||
|
inp.setDebugName(arg.name)
|
||||||
|
# If arg is a constant, inline (at the top of the graph).
|
||||||
|
else:
|
||||||
|
inp = graph.insertConstant(val)
|
||||||
|
inp.node().moveBefore(node)
|
||||||
|
|
||||||
|
node.addInput(inp)
|
||||||
|
|
||||||
|
if node.hasMultipleOutputs():
|
||||||
|
for outp in node.outputs():
|
||||||
|
graph.registerOutput(outp)
|
||||||
|
else:
|
||||||
|
graph.registerOutput(node.output())
|
||||||
|
|
||||||
|
fn_name = str(node).strip()
|
||||||
|
fn = torch._C._create_function_from_graph(fn_name, graph)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_args_kwargs(
|
||||||
|
script_fun: torch._C.ScriptFunction,
|
||||||
|
normalized_args: List[Any],
|
||||||
|
normalized_kwargs: Dict[str, Any],
|
||||||
|
):
|
||||||
|
unwrapped_normalized_args = tree_map(
|
||||||
|
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
|
||||||
|
normalized_args,
|
||||||
|
)
|
||||||
|
unwrapped_normalized_kwargs = tree_map(
|
||||||
|
lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
|
||||||
|
normalized_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
annotations = []
|
||||||
|
tensor_args = []
|
||||||
|
for i, arg in enumerate(unwrapped_normalized_args):
|
||||||
|
if isinstance(arg, np.ndarray):
|
||||||
|
# TODO: Remove once size zero dimensions are handled by torch-mlir.
|
||||||
|
shape = tuple(map(lambda x: x or -1, arg.shape))
|
||||||
|
annotations.append(
|
||||||
|
TorchTensorType(shape=shape, dtype=normalized_args[i].dtype)
|
||||||
|
)
|
||||||
|
tensor_args.append(arg)
|
||||||
|
|
||||||
|
# Pull out tensor kwargs and put them in positional order.
|
||||||
|
tensor_kwargs_flat = []
|
||||||
|
if unwrapped_normalized_kwargs:
|
||||||
|
tensor_kwargs = {}
|
||||||
|
arg_idxs = {
|
||||||
|
arg_name: i
|
||||||
|
for i, arg_name in enumerate(
|
||||||
|
[arg.name for arg in script_fun.schema.arguments]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
for i, (kw, arg) in enumerate(unwrapped_normalized_kwargs.items()):
|
||||||
|
if isinstance(arg, np.ndarray):
|
||||||
|
tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype)
|
||||||
|
|
||||||
|
for i in range(len(tensor_kwargs)):
|
||||||
|
arg, arg_dtype = tensor_kwargs[i]
|
||||||
|
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype))
|
||||||
|
tensor_kwargs_flat.append(arg)
|
||||||
|
|
||||||
|
return annotations, tensor_args, tensor_kwargs_flat
|
||||||
|
|
||||||
|
|
||||||
|
def write_back_to_mutable(
|
||||||
|
registered_op: Dict,
|
||||||
|
out: Union[np.ndarray, List[np.ndarray]],
|
||||||
|
all_tensor_args: List[np.ndarray],
|
||||||
|
):
|
||||||
|
"""Write back to mutable args that aren't properly handled otherwise.
|
||||||
|
|
||||||
|
Because of how we pass values to the backend (by copying the tensor to a numpy array) we don't currently support
|
||||||
|
ops that mutate operands. That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
|
||||||
|
as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
|
||||||
|
we can for these special cases. Hence, the solution is to manually write back to the same operand that the
|
||||||
|
conventional pytorch op variant would write to.
|
||||||
|
|
||||||
|
Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
|
||||||
|
mutate running_mean and running_var). We don't currently handle those.
|
||||||
|
"""
|
||||||
|
if len(registered_op["returns"]) > 1:
|
||||||
|
raise UnsupportedByTorchMlirEagerMode(
|
||||||
|
"TorchMLIR doesn't handle multiple aliased returns yet."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
aliased_arg = next(
|
||||||
|
arg
|
||||||
|
for arg in registered_op["arguments"]
|
||||||
|
if "alias_info" in arg and arg["alias_info"]["is_write"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"alias_info" in registered_op["returns"][0]
|
||||||
|
and registered_op["returns"][0]["alias_info"]["is_write"]
|
||||||
|
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
|
||||||
|
and registered_op["returns"][0]["alias_info"]["after"][0]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(aliased_arg["alias_info"]["after"]) == 1
|
||||||
|
and aliased_arg["alias_info"]["after"][0]
|
||||||
|
== registered_op["returns"][0]["alias_info"]["after"][0]
|
||||||
|
)
|
||||||
|
np.copyto(all_tensor_args[0], out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def try_torch_mlir_eager(op, args, kwargs, backend):
|
||||||
|
if hasattr(op, "op_name"):
|
||||||
|
op_name = op.op_name
|
||||||
|
elif hasattr(op, "__name__"):
|
||||||
|
# Handle builtin_function_or_method.
|
||||||
|
op_name = op.__name__
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"op {op} has no name")
|
||||||
|
|
||||||
|
if op_name == "detach":
|
||||||
|
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
|
||||||
|
raise UnsupportedByTorchMlirEagerMode("detaching")
|
||||||
|
|
||||||
|
if not hasattr(op, "_schema"):
|
||||||
|
raise RuntimeError(f"op {op} has no schema.")
|
||||||
|
|
||||||
|
new_args, new_kwargs = normalize_args_kwargs(op.overloadpacket, args, kwargs)
|
||||||
|
|
||||||
|
if "layout" in new_kwargs and new_kwargs["layout"] not in {0, None}:
|
||||||
|
raise UnsupportedByTorchMlirEagerMode(
|
||||||
|
f"{new_kwargs['layout']} layout not supported."
|
||||||
|
)
|
||||||
|
if "memory_format" in new_kwargs and new_kwargs["memory_format"] not in {0, None}:
|
||||||
|
raise UnsupportedByTorchMlirEagerMode(
|
||||||
|
f"{new_kwargs['memory_format']} memory format not supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
script_fun = build_script_function(op._schema, new_args, new_kwargs)
|
||||||
|
annotations, np_tensor_args, np_tensor_kwargs_flat = annotate_args_kwargs(
|
||||||
|
script_fun, new_args, new_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
eager_module = build_module(script_fun, annotations)
|
||||||
|
with eager_module.context:
|
||||||
|
pm = PassManager.parse(
|
||||||
|
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
|
||||||
|
)
|
||||||
|
pm.run(eager_module)
|
||||||
|
compiled_module = backend.compile(eager_module)
|
||||||
|
loaded_module = backend.load(compiled_module)
|
||||||
|
op_mlir_backend_callable = getattr(loaded_module, script_fun.name)
|
||||||
|
assert (
|
||||||
|
op_mlir_backend_callable is not None
|
||||||
|
), f"Couldn't find function {script_fun.name} in module."
|
||||||
|
|
||||||
|
all_tensor_args = np_tensor_args + np_tensor_kwargs_flat
|
||||||
|
out = op_mlir_backend_callable(*all_tensor_args)
|
||||||
|
|
||||||
|
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
|
||||||
|
if registered_op["is_mutable"]:
|
||||||
|
out = write_back_to_mutable(registered_op, out, all_tensor_args)
|
||||||
|
|
||||||
|
return out
|
|
@ -0,0 +1,107 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
from torch_mlir.eager_mode.torch_mlir_dispatch import (
|
||||||
|
try_torch_mlir_eager,
|
||||||
|
UnsupportedByTorchMlirEagerMode,
|
||||||
|
)
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||||
|
|
||||||
|
|
||||||
|
class TorchMLIRTensor(torch.Tensor):
|
||||||
|
"""Wrap torch.Tensor in orer to dispatch through torch-mlir instead of aten.
|
||||||
|
|
||||||
|
This class uses the _make_wrapper_subclass pattern to override __torch_dispatch__
|
||||||
|
in order to dispatch through torch-mlir instead of aten. Here we basically only unwrap and wrap
|
||||||
|
torch.Tensors. Most of the heavy lifting is done in the adjacent torch_mlir_dispatch module.
|
||||||
|
|
||||||
|
More documentation on how this pattern works can be found in this RFC
|
||||||
|
https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
|
||||||
|
and this repo with many examples
|
||||||
|
https://github.com/albanD/subclass_zoo
|
||||||
|
"""
|
||||||
|
|
||||||
|
elem: torch.Tensor
|
||||||
|
|
||||||
|
__slots__ = ["elem"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __new__(cls, elem, *args, **kwargs):
|
||||||
|
r = torch.Tensor._make_wrapper_subclass(
|
||||||
|
cls,
|
||||||
|
elem.size(),
|
||||||
|
strides=elem.stride(),
|
||||||
|
storage_offset=elem.storage_offset(),
|
||||||
|
dtype=elem.dtype,
|
||||||
|
layout=elem.layout,
|
||||||
|
device=elem.device,
|
||||||
|
requires_grad=kwargs.get("requires_grad", False) or elem.requires_grad,
|
||||||
|
)
|
||||||
|
r.elem = elem.detach() if r.requires_grad else elem
|
||||||
|
return r
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.grad_fn:
|
||||||
|
return f"TorchMLIRTensor({self.elem}, grad_fn={self.grad_fn})"
|
||||||
|
else:
|
||||||
|
return f"TorchMLIRTensor({self.elem})"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
|
||||||
|
requires_grad = False
|
||||||
|
|
||||||
|
def check_grad(e):
|
||||||
|
nonlocal requires_grad
|
||||||
|
if isinstance(e, TorchMLIRTensor):
|
||||||
|
requires_grad |= e.requires_grad
|
||||||
|
|
||||||
|
tree_map(check_grad, args)
|
||||||
|
tree_map(check_grad, kwargs)
|
||||||
|
|
||||||
|
def unwrap(e):
|
||||||
|
if isinstance(e, TorchMLIRTensor):
|
||||||
|
return e.elem
|
||||||
|
if isinstance(e, torch.nn.Parameter):
|
||||||
|
return e.detach()
|
||||||
|
return e
|
||||||
|
|
||||||
|
def wrap(e):
|
||||||
|
nonlocal requires_grad
|
||||||
|
return (
|
||||||
|
TorchMLIRTensor(e, requires_grad=requires_grad)
|
||||||
|
if isinstance(e, torch.Tensor)
|
||||||
|
else e
|
||||||
|
)
|
||||||
|
|
||||||
|
unwrapped_args = tree_map(unwrap, args)
|
||||||
|
unwrapped_kwargs = tree_map(unwrap, kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = try_torch_mlir_eager(
|
||||||
|
func,
|
||||||
|
unwrapped_args,
|
||||||
|
unwrapped_kwargs,
|
||||||
|
backend=refbackend.RefBackendLinalgOnTensorsBackend(),
|
||||||
|
)
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
out = [torch.from_numpy(o) for o in out]
|
||||||
|
else:
|
||||||
|
out = torch.from_numpy(out)
|
||||||
|
return tree_map(wrap, out)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, UnsupportedByTorchMlirEagerMode):
|
||||||
|
warnings.warn(
|
||||||
|
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
|
||||||
|
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
|
||||||
|
)
|
||||||
|
return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
|
|
@ -7,3 +7,4 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||||
from .native_torch import NativeTorchTestConfig
|
from .native_torch import NativeTorchTestConfig
|
||||||
from .torchscript import TorchScriptTestConfig
|
from .torchscript import TorchScriptTestConfig
|
||||||
from .tosa_backend import TosaBackendTestConfig
|
from .tosa_backend import TosaBackendTestConfig
|
||||||
|
from .eager_mode import EagerModeTestConfig
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
|
|
||||||
|
def wrap(e):
|
||||||
|
return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap(e):
|
||||||
|
return e.elem.clone() if isinstance(e, TorchMLIRTensor) else e
|
||||||
|
|
||||||
|
|
||||||
|
class EagerModeTestConfig(TestConfig):
|
||||||
|
"""Trivial test config that exercises eager mode plumbing"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||||
|
return program
|
||||||
|
|
||||||
|
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
attr = artifact
|
||||||
|
for part in item.symbol.split('.'):
|
||||||
|
attr = getattr(attr, part)
|
||||||
|
|
||||||
|
inps = tree_map(wrap, item.inputs)
|
||||||
|
outps = attr(*inps)
|
||||||
|
output = tree_map(unwrap, outps)
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
output=output))
|
||||||
|
return result
|
Loading…
Reference in New Issue