Remove eager_mode

This was an experimental attempt at rolling our own op-by-op executor
with `__torch_dispatch__`, but it proved difficult to make robust.
Op-by-op execution is now easy to implement robustly with the
PyTorch 2.0 stack, so we don't need eager_mode.
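
For reference, here is a minimal sketch (not part of this commit; all names are illustrative) of op-by-op execution on the PyTorch 2.0 stack: TorchDynamo hands a backend the captured FX graph, and `torch.fx.Interpreter` already walks that graph one node at a time, so per-op compilation or inspection can be hooked in a single place.

```python
# Illustrative sketch only: op-by-op execution via TorchDynamo + torch.fx.Interpreter.
import torch


class PerOpInterpreter(torch.fx.Interpreter):
    def run_node(self, node):
        # Each FX node is a single op; per-op compilation/inspection hooks go here.
        return super().run_node(node)


def op_by_op_backend(gm: torch.fx.GraphModule, example_inputs):
    # TorchDynamo calls this with the captured graph; return a callable that
    # executes the graph node by node.
    return lambda *args: PerOpInterpreter(gm).run(*args)


compiled = torch.compile(torch.nn.Linear(4, 4), backend=op_by_op_backend)
print(compiled(torch.randn(2, 4)))
```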

Downstream users were using eager_mode to implement lockstep numerical
accuracy debuggers. We implemented the same functionality with
TorchDynamo in https://github.com/llvm/torch-mlir/pull/1681, so there is
not much reason to continue maintaining eager_mode.
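
As a rough illustration of that approach (not the implementation in the PR; `run_op_under_test` is an assumed hook, not a real torch-mlir API), a lockstep debugger built on TorchDynamo can interpret the captured FX graph and compare each op's result against eager PyTorch:

```python
# Illustrative sketch of a lockstep numerical-accuracy debugger on TorchDynamo.
import torch


def lockstep_backend(run_op_under_test, atol=1e-5):
    """run_op_under_test(target, args, kwargs) is an assumed hook that runs one
    op through the backend being debugged (e.g. a per-op torch-mlir compile)."""

    def backend(gm: torch.fx.GraphModule, example_inputs):
        class Lockstep(torch.fx.Interpreter):
            def call_function(self, target, args, kwargs):
                golden = super().call_function(target, args, kwargs)  # eager reference
                candidate = run_op_under_test(target, args, kwargs)
                if isinstance(golden, torch.Tensor) and not torch.allclose(
                        golden, candidate, atol=atol):
                    print(f"Numerical divergence at {target}")
                return golden  # keep executing on the golden values

        return lambda *args: Lockstep(gm).run(*args)

    return backend

# Usage: torch.compile(model, backend=lockstep_backend(my_op_runner))(example_input)
```
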
pull/1692/head
Sean Silva 2022-12-08 14:54:22 +00:00
parent 109c91ae9b
commit 7731211d02
19 changed files with 2 additions and 1565 deletions

@@ -75,14 +75,6 @@ torch-mlir prediction
View examples [here](docs/ltc_examples.md).
### Eager Mode
Eager mode with TorchMLIR is a very experimental eager mode backend for PyTorch through the torch-mlir framework.
Effectively, this mode works by compiling operator by operator as the NN is eagerly executed by PyTorch.
This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported operator).
A simple example can be found at [eager_mode.py](examples/eager_mode.py).
A ResNet18 example can be found at [eager_mode_resnet18.py](examples/eager_mode_resnet18.py).
## Repository Layout
The project follows the conventions of typical MLIR-based projects:

@@ -262,9 +262,6 @@ function test_in_tree() {
echo ":::: Run Linalg e2e integration tests" echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v python -m e2e_testing.main --config=linalg -v
echo ":::: Run eager_mode e2e integration tests"
python -m e2e_testing.main --config=eager_mode -v
echo ":::: Run MHLO e2e integration tests" echo ":::: Run MHLO e2e integration tests"
python -m e2e_testing.main --config=mhlo -v python -m e2e_testing.main --config=mhlo -v

@@ -20,7 +20,6 @@ from torch_mlir_e2e_test.configs import (
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
EagerModeTestConfig,
TorchDynamoTestConfig,
)
@@ -28,14 +27,14 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'lazy_tensor_core', 'torchdynamo']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
@@ -47,7 +46,6 @@ Meaning of options:
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
"eager_mode": run through torch-mlir's eager mode frontend, using Linalg-on-Tensors for execution.
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
"torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors. "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
''') ''')
@ -91,9 +89,6 @@ def main():
elif args.config == 'torchscript': elif args.config == 'torchscript':
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
xfail_set = {} xfail_set = {}
elif args.config == 'eager_mode':
config = EagerModeTestConfig()
xfail_set = EAGER_MODE_XFAIL_SET
elif args.config == 'lazy_tensor_core':
config = LazyTensorCoreTestConfig()
xfail_set = LTC_XFAIL_SET

@@ -14,13 +14,6 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
EAGER_MODE_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# RefBackend fails for some reason.
# These tests pass in the regular RefBackend flow, so it's unclear
# why they fail here.
"Matmul_vecmat",
}
TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors

@@ -1,34 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
torch_a = torch.randn(5, requires_grad=True)
torch_b = torch.randn(5, requires_grad=True)
torch_c = torch_a + torch_b
torch_d = torch_a * torch_b
torch_e = torch_c / torch_d
torch_loss = torch_e.sum()
print("PyTorch loss: ", torch_loss)
torch_loss.backward()
print("PyTorch grad a: ", torch_a.grad)
print("PyTorch grad b: ", torch_b.grad)
a = TorchMLIRTensor(torch_a)
b = TorchMLIRTensor(torch_b)
c = a + b
d = a * b
e = c / d
loss = e.sum()
print("Torch-MLIR loss: ", loss)
loss.backward()
print("Torch-MLIR grad a: ", a.grad)
print("Torch-MLIR grad b: ", b.grad)

@@ -1,89 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
import requests
import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
def load_and_preprocess_image(url: str):
headers = {
'User-Agent':
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
}
img = Image.open(requests.get(url, headers=headers,
stream=True).raw).convert("RGB")
# preprocessing pipeline
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
img_preprocessed = preprocess(img)
return torch.unsqueeze(img_preprocessed, 0)
def load_labels():
classes_text = requests.get(
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
stream=True,
).text
labels = [line.strip() for line in classes_text.splitlines()]
return labels
def top3_possibilities(res):
_, indexes = torch.sort(res, descending=True)
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
return top3
def predictions(torch_func, img, labels):
golden_prediction = top3_possibilities(torch_func(img))
print("PyTorch prediction")
print(golden_prediction)
prediction = top3_possibilities(torch_func(TorchMLIRTensor(img)))
print("torch-mlir prediction")
print(prediction)
class ResNet18Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.resnet = models.resnet18(pretrained=True)
self.train(False)
def forward(self, img):
return self.resnet.forward(img)
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.s = ResNet18Module()
def forward(self, x):
return self.s.forward(x)
image_url = (
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
print("load image from " + image_url, file=sys.stderr)
img = load_and_preprocess_image(image_url)
labels = load_labels()
test_module = TestModule()
predictions(test_module.forward, img, labels)

@@ -102,12 +102,6 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
add_subdirectory(torch_mlir_e2e_test)
endif()
################################################################################
# Eager mode
################################################################################
add_subdirectory(torch_mlir/eager_mode)
################################################################################
# Custom op example
# Required for running the update_torch_ods.sh and update_shape_lib.sh scripts.

@@ -1,320 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import torch
from framework import run_test
from torch_mlir.eager_mode.ir_building import build_ts_script_function
# CHECK: graph(%[[A1:.*]] : Tensor,
# CHECK: %[[A2:.*]] : Tensor,
# CHECK: %[[A3:.*]] : Tensor):
# CHECK: %[[A4:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[A5:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[A0:.*]] : Tensor = aten::addmm(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]])
# CHECK: return (%[[A0]])
# -----
# CHECK: PASS - simple
@run_test
def simple():
target = torch.ops.aten.addmm.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
mat1=torch.randn(1, 3, 32, 32),
mat2=torch.randn(1, 3, 32, 32),
beta=1,
alpha=1,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: graph(%[[B1:.*]] : Tensor,
# CHECK: %[[B2:.*]] : Tensor,
# CHECK: %[[B3:.*]] : Tensor):
# CHECK: %[[B4:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[B5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[B6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[B7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[B8:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[B9:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[B0:.*]] : Tensor = aten::convolution(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[B9]])
# CHECK: return (%[[B0]])
# -----
# CHECK: PASS - handle_optional_tensor_input
@run_test
def handle_optional_tensor_input():
target = torch.ops.aten.convolution.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: FAIL - fail_not_enough_args
# CHECK: Errors: 'groups'
@run_test
def fail_not_enough_args():
target = torch.ops.aten.convolution.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
# Missing groups=1,
)
build_ts_script_function(target._schema, kwargs)
# CHECK: graph(%input : Tensor,
# CHECK: %weight : Tensor,
# CHECK: %bias : Tensor):
# CHECK: %4 : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %5 : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %6 : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %7 : bool = prim::Constant[value=0]()
# CHECK: %8 : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %9 : int = prim::Constant[value=1]()
# CHECK: %0 : Tensor = aten::convolution(%input, %weight, %bias, %4, %5, %6, %7, %8, %9)
# CHECK: return (%0)
# -----
# CHECK: PASS - simple_kwargs
@run_test
def simple_kwargs():
target = torch.ops.aten.convolution.default
script_fun1 = build_ts_script_function(
target._schema,
dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
),
)
print(script_fun1.graph)
# CHECK: graph(%[[C2:.*]] : Tensor):
# CHECK: %[[C3:.*]] : int[] = prim::Constant[value=[3, 3]]()
# CHECK: %[[C4:.*]] : NoneType = prim::Constant()
# CHECK: %[[C5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[C6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[C7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[C0:.*]] : Tensor, %[[C1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]])
# CHECK: return (%[[C0]], %[[C1]])
# -----
# CHECK: PASS - handle_empty_lists
@run_test
def handle_empty_lists():
target = torch.ops.aten.max_pool2d_with_indices.default
# print(target._schema)
input = torch.randn((1, 3, 32, 32))
kwargs = dict(
input=input,
kernel_size=[3, 3],
stride=[],
padding=[0, 0],
dilation=[1, 1],
ceil_mode=False,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: graph(%[[D2:.*]] : Tensor):
# CHECK: %[[D3:.*]] : int[] = prim::Constant[value=[3, 3]]()
# CHECK: %[[D4:.*]] : NoneType = prim::Constant()
# CHECK: %[[D5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[D6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[D7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[D0:.*]] : Tensor, %[[D1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[D2]], %[[D3]], %[[D4]], %[[D5]], %[[D6]], %[[D7]])
# CHECK: return (%[[D0]], %[[D1]])
# -----
# CHECK: PASS - handle_nones
@run_test
def handle_nones():
target = torch.ops.aten.max_pool2d_with_indices.default
# print(target._schema)
kwargs = dict(
input=torch.randn((1, 3, 32, 32)),
kernel_size=[3, 3],
stride=None,
padding=[0, 0],
dilation=[1, 1],
ceil_mode=False,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: graph(%[[E1:.*]] : Tensor,
# CHECK: %[[E2:.*]] : Tensor,
# CHECK: %[[E3:.*]] : Tensor):
# CHECK: %[[E4:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[E5:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[E6:.*]] : int[] = prim::Constant[value=[1, 1]]()
# CHECK: %[[E7:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[E8:.*]] : int[] = prim::Constant[value=[0, 0]]()
# CHECK: %[[E9:.*]] : int = prim::Constant[value=1]()
# CHECK: %[[E0:.*]] : Tensor = aten::convolution(%[[E1]], %[[E2]], %[[E3]], %[[E4]], %[[E5]], %[[E6]], %[[E7]], %[[E8]], %[[E9]])
# CHECK: return (%[[E0]])
# -----
# CHECK: PASS - handle_optional_tensors
@run_test
def handle_optional_tensors():
target = torch.ops.aten.convolution.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(3, 3, 3, 3),
bias=torch.randn(3),
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: graph(%[[F1:.*]] : Tensor):
# CHECK: %[[F2:.*]] : NoneType = prim::Constant()
# CHECK: %[[F3:.*]] : NoneType = prim::Constant()
# CHECK: %[[F4:.*]] : NoneType = prim::Constant()
# CHECK: %[[F5:.*]] : NoneType = prim::Constant()
# CHECK: %[[F6:.*]] : NoneType = prim::Constant()
# CHECK: %[[F0:.*]] : Tensor = aten::ones_like(%[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], %[[F6]])
# CHECK: return (%[[F0]])
# -----
# CHECK: PASS - handle_ones_like
@run_test
def handle_ones_like():
target = torch.ops.aten.ones_like.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: graph(%[[G3:.*]] : Tensor,
# CHECK: %[[G4:.*]] : Tensor,
# CHECK: %[[G5:.*]] : Tensor):
# CHECK: %[[G6:.*]] : NoneType = prim::Constant()
# CHECK: %[[G7:.*]] : NoneType = prim::Constant()
# CHECK: %[[G8:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[G9:.*]] : float = prim::Constant[value=1.]()
# CHECK: %[[G10:.*]] : float = prim::Constant[value=1.]()
# CHECK: %[[G0:.*]] : Tensor, %[[G1:.*]] : Tensor, %[[G2:.*]] : Tensor = aten::native_batch_norm(%[[G3]], %[[G4]], %[[G5]], %[[G6]], %[[G7]], %[[G8]], %[[G9]], %[[G10]])
# CHECK: return (%[[G0]], %[[G1]], %[[G2]])
# -----
# CHECK: PASS - handle_multiple_outputs
@run_test
def handle_multiple_outputs():
target = torch.ops.aten.native_batch_norm.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(1, 3, 32, 32),
bias=torch.randn(1, 3, 32, 32),
running_mean=None,
running_var=None,
training=False,
momentum=1.0,
eps=1.0
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)
# CHECK: f
# CHECK: PASS - check_legal_name
@run_test
def check_legal_name():
target = torch.ops.aten.native_batch_norm.default
kwargs = dict(
input=torch.randn(1, 3, 32, 32),
weight=torch.randn(1, 3, 32, 32),
bias=torch.randn(1, 3, 32, 32),
running_mean=None,
running_var=None,
training=False,
momentum=1.0,
eps=1.0
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.name)
# CHECK: graph(%[[H3:.*]] : Tensor,
# CHECK: %[[H4:.*]] : Tensor,
# CHECK: %[[H5:.*]] : Tensor,
# CHECK: %[[H6:.*]] : Tensor,
# CHECK: %[[H7:.*]] : Tensor,
# CHECK: %out : Tensor,
# CHECK: %save_mean : Tensor,
# CHECK: %save_invstd : Tensor):
# CHECK: %[[H8:.*]] : bool = prim::Constant[value=0]()
# CHECK: %[[H9:.*]] : float = prim::Constant[value=0.10000000000000001]()
# CHECK: %[[H10:.*]] : float = prim::Constant[value=0.0001]()
# CHECK: %[[H0:.*]] : Tensor, %[[H1:.*]] : Tensor, %[[H2:.*]] : Tensor = aten::native_batch_norm(%[[H3]], %[[H4]], %[[H5]], %[[H6]], %[[H7]], %[[H8]], %[[H9]], %[[H10]], %out, %save_mean, %save_invstd)
# CHECK: return (%[[H0]], %[[H1]], %[[H2]])
# -----
# CHECK: PASS - correctly_order_kwargs
@run_test
def correctly_order_kwargs():
target = torch.ops.aten.native_batch_norm.out
input = torch.randn(2, 5, 2, 3)
running_mean = torch.randn(5)
running_var = torch.randn(5)
kwargs = dict(
input=torch.randn(2, 5, 2, 3),
weight=torch.randn(5),
bias=torch.randn(5),
running_mean=running_mean,
running_var=running_var,
training=False,
momentum=0.1,
eps=0.0001,
out=torch.empty_like(input),
save_mean=torch.empty_like(running_mean),
save_invstd=torch.empty_like(running_var),
)
script_fun = build_ts_script_function(target._schema, kwargs)
print(script_fun.graph)

@@ -1,23 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: true
def run_test(*args, XPASS=False, XFAIL=False):
def _run_test(test):
test_name = test.__name__
try:
test()
print(("X" if XPASS else "") + f"PASS - {test_name}")
except Exception as e:
print(("X" if XFAIL else "") + f"FAIL - {test_name}")
print("Errors: ", e)
print()
if len(args):
_run_test(args[0])
else:
return _run_test

@@ -1,69 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s | FileCheck %s
import torch
from framework import run_test
from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
# CHECK: PASS - should_normalize
@run_test
def should_normalize():
target = torch.ops.aten.max_pool2d_with_indices.default
input = torch.randn((1, 3, 32, 32))
kwargs = {"kernel_size": [3, 3]}
golden = {
"kernel_size": [3, 3],
# This is due to the schema for max_pool2d_with_indices defining
# the stride arg as int[2] stride=[].
"stride": [],
"padding": [0, 0],
"dilation": [1, 1],
"ceil_mode": False,
}
new_kwargs = normalize_args_kwargs(target, (input,), kwargs)
assert torch.allclose(new_kwargs["input"], input)
for k, v in new_kwargs.items():
if k == "input": continue
assert v == golden[k]
# CHECK: FAIL - shouldnt_normalize1
# CHECK: Errors: missing a required argument: 'kernel_size'
@run_test
def shouldnt_normalize1():
target = torch.ops.aten.max_pool2d_with_indices.default
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"stride": []}
normalize_args_kwargs(target, args, kwargs)
# The next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
# I.e., they should fail, but in fact they pass because of the upstream bug.
# The reason for the bug is a fast-path branch in operator_schemas.normalize_function
# that doesn't do rigorous type checking, and hence lets type mismatches slip through.
# TODO(max): change these to FAIL when the upstream bug is fixed.
# CHECK: XPASS - shouldnt_normalize2
@run_test(XPASS=True)
def shouldnt_normalize2():
target = torch.ops.aten.max_pool2d_with_indices.default
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": []}
normalize_args_kwargs(target, args, kwargs)
# CHECK: XPASS - shouldnt_normalize3
@run_test(XPASS=True)
def shouldnt_normalize3():
target = torch.ops.aten.max_pool2d_with_indices.default
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": [3, 3], "padding": None}
normalize_args_kwargs(target, args, kwargs)

@@ -1,11 +0,0 @@
#-------------------------------------------------------------------------------
# Subdirectories
#-------------------------------------------------------------------------------
## Declare the sources of the Python module.
declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES_GLOB eager_mode/*.py lazytensor/*.py
)

@@ -1,3 +0,0 @@
import os
EAGER_MODE_DEBUG = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')

@@ -1,359 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
`torch` dialect.
The expected use of this module is to create an instance of one of the classes below, and then call the
`to_mlir` method to generate the MLIR representation of the type.
Information about what types are supported by each class can be found in docstrings of each of the classes.
In addition, this module defines a function that takes a torch.jit.ScriptFunction and converts it into an MLIR module.
The expected use for this module is to use the function
`build_module(jit_function: torch.jit.ScriptFunction, annotation: Annotation) -> ir.Module`
to convert the TorchScript function into MLIR using the `torch` dialect.
"""
import abc
import re
from typing import Any, Optional, Iterable, Dict
from typing import Union
import numpy as np
import torch
import torch._C
import torch.jit
from torch._ops import OpOverload
from torch_mlir import ir
from torch_mlir.dialects.func import FuncOp
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
class TorchMlirType(abc.ABC):
"""
A `TorchMlirType` is an object that produces MLIR
types in the `torch` dialect. The only requirement
for a class to be a subclass of `TorchMlirType` is
to define a `to_mlir(self, ir.Context) -> ir.Type`.
Each class is allowed to have different types of
__init__ methods depending on the information they
require to produce the given MLIR representation.
"""
@abc.abstractmethod
def to_mlir(self, context: ir.Context) -> ir.Type:
pass
class TorchTensorTypeError(Exception):
def __init__(self, value: str):
super().__init__()
self.value = value
def __str__(self) -> str:
return self.value
class TorchTensorType(TorchMlirType):
"""
This class is used to generate types of the form
!torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
where SHAPE is a list representing the shape of the tensor,
and DTYPE is an MLIR data type.
"""
def __init__(
self,
*,
shape: Optional[Iterable[Optional[int]]] = None,
dtype: Optional[torch.dtype] = None,
):
self.shape = shape
self.dtype = dtype
if dtype is None and shape is not None:
err = "If shape is specified, dtype must also be specified"
raise TorchTensorTypeError(err)
def __str__(self):
return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
def to_mlir(self, context: ir.Context) -> ir.Type:
if self.dtype is None:
return ir.Type.parse("!torch.tensor", context=context)
shape_asm = self._shape_to_mlir_asm()
dtype_asm = self._dtype_to_mlir_asm()
return ir.Type.parse(
f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
)
def _shape_to_mlir_asm(self) -> str:
if self.shape is None:
return "*"
str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
return f'[{",".join(str_sizes)}]'
def _dtype_to_mlir_asm(self) -> str:
if self.dtype in [torch.float64]:
return "f64"
if self.dtype in [torch.float, torch.float32]:
return "f32"
if self.dtype in [torch.int, torch.int32]:
return "si32"
if self.dtype in [torch.int64]:
return "si64"
if self.dtype in [torch.bool]:
return "i1"
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
class TorchNnModuleType(TorchMlirType):
"""This class is used to generate types for `!torch.nn.Module`s."""
def __init__(self, module_name: str):
self.module_name = module_name
def __str__(self):
return "torch.nn.Module"
def to_mlir(self, context: ir.Context) -> ir.Type:
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)
class PythonType(TorchMlirType):
"""
This class is used to convert regular Python types
into their corresponding `torch` dialect representation.
The list of supported types can be found in the dictionary
`_type_to_asm_dict`.
"""
_type_to_asm_dict = {
bool: "!torch.bool",
int: "!torch.int",
type(None): "!torch.none",
}
def __init__(self, type_: Any):
self.type_ = type_
def __str__(self):
return str(self.type_)
def to_mlir(self, context: ir.Context) -> ir.Type:
asm = self._type_to_asm_dict.get(self.type_)
if asm is None:
raise NotImplementedError(f"Unsupported type: {self.type_}")
return ir.Type.parse(asm, context=context)
# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
class Annotation:
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
self.types = list(
map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
)
def __str__(self):
result = f"Annotation instance with {len(self.types)} types\n"
for e, type_ in enumerate(self.types):
result += f" Type of argument {e + 1}: {str(type_)}\n"
return result
def __iter__(self):
return iter(self.types)
class AnnotationConverter:
@staticmethod
def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
dict_attrs = []
for type_ in annotation:
if not isinstance(type_, TorchTensorType):
dict_attrs.append(ir.DictAttr.get({}, context=context))
continue
ir_type = type_.to_mlir(context)
with context:
type_attr = ir.TypeAttr.get(ir_type)
dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
dict_attrs.append(dict_attr)
return ir.ArrayAttr.get(dict_attrs, context=context)
def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
with module.context:
name_attr = ir.StringAttr.get(name)
for op in module.body.operations:
if isinstance(op, FuncOp) and op.name == name_attr:
# Add name of torch op as debug_module_name so that
# run_pipeline_with_repro_report can use it.
module.operation.attributes["torch.debug_module_name"] = name_attr
return op
return None
def is_tensor_type(typ: torch._C.Type):
return typ.isSubtypeOf(torch.TensorType.get()) or (
isinstance(typ, torch.OptionalType)
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
)
def is_list_of_tensors_type(typ: torch._C.Type):
return isinstance(typ, torch.ListType) and is_tensor_type(typ.getElementType())
name_mangle_regex = re.compile("[^a-zA-Z0-9]")
def build_ts_script_function(
schema: torch._C.FunctionSchema, kwargs: Dict[str, Any]
) -> torch.jit.ScriptFunction:
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
Constants are inlined for the purposes of invalidating the compile cache when they change.
Parameters
----------
schema: torch._C.FunctionSchema
PyTorch's representation for ops, contains type information needed for inlining constants into the TS graph.
kwargs: Dict
A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
Returns
-------
torch.jit.ScriptFunction
Fully specialized (all constants) TS graph whose only arguments are tensors.
"""
# Creates empty TS graph.
graph = torch._C.Graph()
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
# Associate graph inputs/outputs with node inputs/outputs.
graph_inputs = []
for arg in schema.arguments:
arg_name = arg.name if arg.name != "self" else "input"
# If arg is a flattened list of tensors, such as in the case of torch.cat
# then add each element of the list to the graph corresponding to arg
# and insert a ListConstruct to function as input to the op.
if is_list_of_tensors_type(arg.type):
inps = []
for kwarg in [
kwarg for kwarg in kwargs if f"{arg_name}_flattened" in kwarg
]:
inp = graph.addInput()
el_typ = arg.type.getElementType()
if isinstance(el_typ, torch.OptionalType):
el_typ = el_typ.getElementType()
inp.setType(el_typ)
inp.setDebugName(kwarg)
inps.append(inp)
graph_inputs.append(kwarg)
list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
list_cons.moveBefore(node)
inp = list_cons.output()
inp.setType(torch.ListType.ofTensors())
# If arg is a tensor, then add input to the graph corresponding to arg.
elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
inp = graph.addInput()
if isinstance(arg.type, torch.OptionalType):
el_typ = arg.type.getElementType()
else:
el_typ = arg.type
inp.setType(el_typ)
inp.setDebugName(arg_name)
graph_inputs.append(arg_name)
# If arg is a constant, inline (at the top of the graph).
else:
val = kwargs[arg_name]
if val == []:
# Some ops have empty list default values for args
# (such as aten::max_pool2d_with_indices with int[2] stride=[]),
# but graph.insertConstant doesn't recognize [] as an empty list IValue.
# This might be an upstream bug but there doesn't seem to be a way to
# build a prim::ListConstruct list that's empty.
val = None
inp = graph.insertConstant(val)
inp.node().moveBefore(node)
node.addInput(inp)
# Reorder graph inputs to match kwargs.
permutes = [
{inp: i for i, inp in enumerate(graph_inputs)}[kwarg]
for kwarg in [kwarg for kwarg in kwargs if kwarg in graph_inputs]
]
graph.permuteInputs(permutes)
if node.hasMultipleOutputs():
for outp in node.outputs():
graph.registerOutput(outp)
else:
graph.registerOutput(node.output())
fn = torch._C._create_function_from_graph(
f"{name_mangle_regex.sub('', str(graph))}", graph
)
return fn
def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
"""Translate input function into an MLIR module in the `torch` dialect.
Parameters
----------
op: OpOverload
Callable from the torch.ops.aten module/namespace that has a _schema field.
kwargs: Dict
A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
Returns
-------
ir.Module
Translation of the input module into an MLIR module.
"""
# The assert here is to catch tensor shapes that have size 0 dimensions, such as those produced in
# the course of evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic.
# Such 0 size dimensions fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887
annotations = []
for arg_name, arg in kwargs.items():
if isinstance(arg, torch.Tensor):
assert np.prod(arg.shape) != 0, f"{arg_name} has invalid shape {arg.shape}"
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
annotations = tuple(annotations)
script_fun = build_ts_script_function(op._schema, kwargs)
assert len(annotations) == len(
list(script_fun.graph.inputs())
), "Number of annotations and number of graph inputs differs."
mb = ModuleBuilder()
mb.import_function(script_fun)
func_op = get_func_op_with_name(mb.module, script_fun.name)
assert (
func_op is not None
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
func_annotation = Annotation(annotations)
arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
func_op.attributes["arg_attrs"] = arg_attrs
return mb.module

@@ -1,111 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from __future__ import annotations
from typing import Any, Callable, Tuple
from typing import Dict
import torch
from torch.fx import immutable_collections
from torch.fx.operator_schemas import (
_torchscript_schema_to_signature,
_args_kwargs_to_normalized_args_kwargs,
)
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops
from torch_mlir.dialects import torch as torch_dialect
OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
SUPPORTED_OPS = frozenset(
[
member.OPERATION_NAME
for member in vars(torch_dialect).values()
if hasattr(member, "OPERATION_NAME")
]
)
class UnsupportedByTorchMlirEagerMode(Exception):
def __init__(self, value: str):
super().__init__()
self.value = value
def __str__(self) -> str:
return self.value
def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
"""Fill in default values for optional args, which are dependent on the schema."""
sig = _torchscript_schema_to_signature(target._schema)
_, new_kwargs = _args_kwargs_to_normalized_args_kwargs(
sig, args, kwargs, normalize_to_only_use_kwargs=True
)
if "self" in new_kwargs:
new_kwargs["input"] = new_kwargs.pop("self")
# Flatten lists of args for ops that takes lists, such as torch.cat.
to_remove = set()
to_add = {}
for k, v in new_kwargs.items():
if isinstance(v, (tuple, list)) and len(v) and isinstance(v[0], torch.Tensor):
to_remove.add(k)
for i, vv in enumerate(v):
to_add[f"{k}_flattened_{i}"] = vv
for rem in to_remove:
del new_kwargs[rem]
new_kwargs.update(**to_add)
# Sort here in order to have consistency across TS graph and
# MLIR module.
sorted_kwargs = dict(sorted(new_kwargs.items()))
return immutable_collections.immutable_dict(sorted_kwargs)
def get_registered_op(op):
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
return registered_op
def check_get_aliased_arg(func: Callable,):
"""Write back to mutable args that aren't properly handled otherwise.
Because of how we pass values to the backend we don't currently support ops that mutate operands.
That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
we can for these special cases. Hence, the solution is to manually write back to the same operand that the
conventional pytorch op variant would write to.
Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
mutate running_mean and running_var). We don't currently handle those.
"""
registered_op = get_registered_op(func)
if not registered_op["is_mutable"]:
return None
if len(registered_op["returns"]) > 1:
raise UnsupportedByTorchMlirEagerMode(
"TorchMLIR doesn't handle multiple aliased returns yet."
)
aliased_arg = next(
arg
for arg in registered_op["arguments"]
if "alias_info" in arg and arg["alias_info"]["is_write"]
)
assert (
"alias_info" in registered_op["returns"][0]
and registered_op["returns"][0]["alias_info"]["is_write"]
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
and registered_op["returns"][0]["alias_info"]["after"][0]
)
assert (
len(aliased_arg["alias_info"]["after"]) == 1
and aliased_arg["alias_info"]["after"][0]
== registered_op["returns"][0]["alias_info"]["after"][0]
)
return aliased_arg["name"] if aliased_arg["name"] != "self" else "input"

@@ -1,102 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import abc
from dataclasses import dataclass
from typing import TypeVar, Tuple, Callable, List, Dict, Any
import torch
from torch_mlir._mlir_libs._mlir.ir import Module
# TODO: This might need to be an ABC too, such as
# to support finding the backend that created the tensor.
DeviceTensor = TypeVar("DeviceTensor")
@dataclass(frozen=True)
class TensorMetaData:
"""A small container for metadata necessary for satisfying the pytorch dispatcher and other code (pytorch or
otherwise) that branches on these attributes.
There is a lot of code in the PyTorch codebase that branches based on these attributes; the obvious ones here
are dtype, device, and requires_grad (necessary for autograd itself). There is ample warning from PyTorch that,
in principle, these should be as close as possible to true; see
https://github.com/albanD/subclass_zoo/blob/1566e038f03cd89ab3cc37e670a44e3c2bbc1897/trivial_tensors.py#L90-L92
The defaults (properties) simplify the api and seem to work after some testing but
might malfunction in unexpected ways.
# TODO: revisit these assumptions
"""
size: Tuple[int]
dtype: torch.dtype
requires_grad: bool
strides: Tuple[int]
storage_offset: int = 0
layout: torch.layout = torch.strided
device: torch.device = torch.device("cpu")
def __init__(
self,
size,
dtype,
requires_grad,
strides=None,
storage_offset=None,
layout=None,
device=None,
):
super().__init__()
object.__setattr__(self, "size", size)
object.__setattr__(self, "dtype", dtype)
object.__setattr__(self, "requires_grad", requires_grad)
object.__setattr__(
self, "strides", strides if strides is not None else len(size) * [0]
)
object.__setattr__(
self, "storage_offset", storage_offset if storage_offset is not None else 0
)
object.__setattr__(
self, "layout", layout if layout is not None else torch.strided
)
object.__setattr__(
self, "device", device if device is not None else torch.device("cpu")
)
class TorchMLIREagerBackend(abc.ABC):
@abc.abstractmethod
def compile(
self, module: Module
) -> Callable[[List[DeviceTensor]], List[DeviceTensor]]:
raise NotImplementedError
@abc.abstractmethod
def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> DeviceTensor:
"""Unwrap the backend representation in order to build a torch.Tensor."""
raise NotImplementedError
@abc.abstractmethod
def get_torch_metadata(
self, tensor: DeviceTensor, kwargs: Dict[str, Any]
) -> TensorMetaData:
"""Parse relevant tensor metadata from backend device array (e.g., shape, stride, layout) in order to build
wrapper tensor."""
raise NotImplementedError
@abc.abstractmethod
def transfer_from_device_to_torch(self, tensor: DeviceTensor) -> torch.Tensor:
"""If compilation fails for some reason then device specific representations need to be munged into a
torch.Tensor representation.
"""
raise NotImplementedError
@abc.abstractmethod
def copy_into(self, dst: DeviceTensor, src: DeviceTensor):
"""This method is needed for things like handling aliased arguments."""
raise NotImplementedError

@@ -1,257 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import contextlib
import re
import traceback
import warnings
from typing import Any
import torch
from torch.utils._pytree import tree_map
from torch_mlir.eager_mode.ir_building import build_mlir_module
from torch_mlir.eager_mode.torch_mlir_dispatch import (
UnsupportedByTorchMlirEagerMode,
normalize_args_kwargs,
check_get_aliased_arg,
)
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
@contextlib.contextmanager
def no_dispatch():
"""Prevent infinite recursion in case accidentally calling a tensor method on a TorchMLIRTensor within
__torch_dispatch__."""
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
backend = EagerModeRefBackend()
UNSUPPORTED_OPS = re.compile(
"|".join([
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
"detach",
# We don't handle _local_scalar_dense because it's just a way to unwrap a tensor that wraps a number.
"_local_scalar_dense",
# https://github.com/llvm/torch-mlir/issues/878
"_unsafe_view",
"view",
])
)
class TorchMLIRTensor(torch.Tensor):
"""This class serves the role abstract class with common functionality for dispatching through Torch-MLIR instead of aten.
It defers device specific behavior to device specific implementations. The deriving classes use the
make_bare_wrapper_subclass convenience method, adjacent here, and override __torch_dispatch__ in order to dispatch
through Torch-MLIR instead of aten. Backends are free to choose whatever representation of the buffers (i.e., `elem`)
and are expected to provide conversion mechanisms between their representation and torch.Tensor.
Here we only verify that inputs abide by current supported features of Torch-MLIR (contiguous memory and
strided tensor layout) and build the mlir module. Importantly, we also recover from any malfunctions in the
deriving classes and dispatch back to conventional PyTorch.
More documentation on how the __torch_dispatch__ pattern works can be found in this forum post
https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
and this RFC
https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
and this repo with many examples
https://github.com/albanD/subclass_zoo
"""
elem: Any
__slots__ = ["elem"]
def __new__(cls, elem, **kwargs):
"""Wrap elem (which could be a torch.Tensor or otherwise) in a torch.Tensor subclass.
Critically, this method needs to parse relevant metadata from the device representation
(such as shape, striding, dtype, etc.) and translate it into torch conventions.
Deriving classes must provide a way to construct themselves from either their device specific representation
or torch.Tensor; the latter handles the case of dispatching back to PyTorch to recover from an error.
"""
if kwargs.get("constructing_from_device_tensor", False):
tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
r = make_bare_wrapper_subclass(
cls=cls,
size=tensor_meta_data.size,
strides=tensor_meta_data.strides,
storage_offset=tensor_meta_data.storage_offset,
dtype=tensor_meta_data.dtype,
layout=tensor_meta_data.layout,
device=tensor_meta_data.device,
requires_grad=tensor_meta_data.requires_grad,
)
r.elem = elem
elif isinstance(elem, torch.nn.Parameter):
r = make_wrapper_subclass_from_torch_tensor(cls, elem.data, **kwargs)
r.elem = backend.transfer_from_torch_to_device(elem.detach().data)
elif isinstance(elem, torch.Tensor):
r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
r.elem = backend.transfer_from_torch_to_device(elem)
# This branch handles the case when a python scalar is passed to some op
# or is returned from some aten op, such as _local_scalar_dense.
elif isinstance(elem, (int, float, bool)):
return elem
else:
raise ValueError(f"Unknown element type: {type(elem)}")
return r
def __repr__(self):
if self.grad_fn:
return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
else:
return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__})"
@classmethod
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
requires_grad = check_requires_grad(*args, **kwargs)
try:
with no_dispatch():
if hasattr(func, "op_name"):
op_name = func.op_name
elif hasattr(func, "__name__"):
# Handle builtin_function_or_method.
op_name = func.__name__
else:
raise RuntimeError(f"op {func} has no name")
requires_grad = requires_grad and "view" not in op_name
if UNSUPPORTED_OPS.match(op_name):
raise UnsupportedByTorchMlirEagerMode(op_name)
if not hasattr(func, "_schema"):
raise RuntimeError(f"op {func} has no schema.")
normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
if "layout" in normalized_kwargs and normalized_kwargs[
"layout"
] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{normalized_kwargs['layout']} layout not supported."
)
if "memory_format" in normalized_kwargs and normalized_kwargs[
"memory_format"
] not in {0, None}:
raise UnsupportedByTorchMlirEagerMode(
f"{normalized_kwargs['memory_format']} memory format not supported."
)
eager_module = build_mlir_module(func, normalized_kwargs)
device_tensor_args = [
kwarg.elem
for _, kwarg in normalized_kwargs.items()
if isinstance(kwarg, cls)
]
assert len(eager_module.body.operations[0].arguments) == len(
device_tensor_args
), "Number of parameters and number of arguments differs."
op_mlir_backend_callable = backend.compile(eager_module)
out = op_mlir_backend_callable(*device_tensor_args)
out = tree_map(
lambda x: cls(
x, requires_grad=requires_grad, constructing_from_device_tensor=True
),
out,
)
except Exception as e:
if EAGER_MODE_DEBUG:
warnings.warn(traceback.format_exc())
if isinstance(e, UnsupportedByTorchMlirEagerMode):
warnings.warn(
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
)
else:
warnings.warn(
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
)
with no_dispatch():
unwrapped_args = tree_map(cls.unwrap, args)
unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
out = func(*unwrapped_args, **unwrapped_kwargs)
out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
maybe_aliased_arg_name = check_get_aliased_arg(func)
if maybe_aliased_arg_name is not None:
backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)
return out
@classmethod
def unwrap(cls, e):
"""Unwrap the TorchMLIRTensor representation in order to access the actual device specific representation."""
if isinstance(e, cls):
return backend.transfer_from_device_to_torch(e.elem)
return e
def check_requires_grad(*args, **kwargs):
requires_grad = False
def check_grad(e):
nonlocal requires_grad
if isinstance(e, TorchMLIRTensor):
requires_grad |= e.requires_grad
tree_map(check_grad, args)
tree_map(check_grad, kwargs)
return requires_grad
def make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs):
"""Convenience method that parse out relevant metadata from a torch.Tensor, in order to produce
a wrapper subclass.
NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
of the device specific implementation.
"""
r = make_bare_wrapper_subclass(
cls=cls,
size=elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
# Only float tensors can have gradients.
requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
and (kwargs.get("requires_grad", False) or elem.requires_grad),
)
return r
def make_bare_wrapper_subclass(
*, cls, size, strides, storage_offset, dtype, layout, device, requires_grad
):
"""Convenience method that builds a wrapper subclass.
NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
of the device specific implementation.
"""
return torch.Tensor._make_wrapper_subclass(
cls,
size,
strides=strides,
storage_offset=storage_offset,
dtype=dtype,
layout=layout,
device=device,
requires_grad=requires_grad,
)

@@ -9,5 +9,4 @@ from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
from .mhlo_backend import MhloBackendTestConfig
from .tosa_backend import TosaBackendTestConfig
from .eager_mode import EagerModeTestConfig
from .torchdynamo import TorchDynamoTestConfig

@@ -1,65 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch.utils._pytree import tree_map
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
def wrap(e):
return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e
def unwrap(e):
return TorchMLIRTensor.unwrap(e) if isinstance(e, TorchMLIRTensor) else e
def to_tmt(m: torch.nn.Module):
for buf_name, buf in m.named_buffers(recurse=True):
if isinstance(buf, TorchMLIRTensor):
continue
m.register_buffer(buf_name, TorchMLIRTensor(buf))
for param_name, param in m.named_parameters(recurse=True):
if isinstance(param, TorchMLIRTensor):
continue
m.register_parameter(
param_name,
torch.nn.Parameter(
TorchMLIRTensor(param), requires_grad=param.requires_grad
),
)
for attr in dir(m):
field = getattr(m, attr)
if isinstance(field, torch.Tensor) and not isinstance(field, TorchMLIRTensor):
setattr(m, attr, TorchMLIRTensor(field))
class EagerModeTestConfig(TestConfig):
"""Trivial test config that exercises eager mode plumbing"""
def __init__(self):
super().__init__()
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
program.apply(to_tmt)
return program
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
attr = artifact
for part in item.symbol.split("."):
attr = getattr(attr, part)
inps = tree_map(wrap, item.inputs)
outps = attr(*inps)
output = tree_map(unwrap, outps)
result.append(
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
)
return result

@@ -1,90 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from __future__ import annotations
from typing import Dict, Any
import numpy as np
import torch
from torch_mlir.compiler_utils import (
get_module_name_for_debug_dump,
run_pipeline_with_repro_report,
)
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
TorchMLIREagerBackend,
TensorMetaData,
)
from torch_mlir.ir import Module
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
NUMPY_TO_TORCH_DTYPE_DICT = {
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
_ref_backend = RefBackendLinalgOnTensorsBackend()
class EagerModeRefBackend(TorchMLIREagerBackend):
"""Main entry-point for the reference backend for eager mode.
RefBackend uses numpy.ndarray representations of tensors and thus all of the wrapping and unwrapping
and munging here is done between torch.Tensor and numpy.ndarray.
"""
module_to_refbackend_invoker = {}
def get_torch_metadata(
self, tensor: np.ndarray, kwargs: Dict[str, Any]
) -> TensorMetaData:
return TensorMetaData(
size=tensor.shape,
dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
requires_grad=tensor.dtype in {np.float, np.float32, np.float64}
and kwargs.get("requires_grad", False),
)
def compile(self, imported_module: Module):
"""Lower the imported TS module to linalg and then further compile for the reference backend and then call."""
fn_name = get_module_name_for_debug_dump(imported_module)
module_hash = str(imported_module)
if module_hash not in self.module_to_refbackend_invoker:
run_pipeline_with_repro_report(
imported_module,
"builtin.module(torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline)",
"EagerMode",
)
self.module_to_refbackend_invoker[module_hash] = _ref_backend.load(
_ref_backend.compile(imported_module)
)
ref_backend_invoker = self.module_to_refbackend_invoker[module_hash]
op_mlir_backend_callable = getattr(ref_backend_invoker, fn_name)
assert (
op_mlir_backend_callable is not None
), f"Couldn't find function in module."
return op_mlir_backend_callable
def copy_into(self, dst: np.ndarray, src: np.ndarray):
np.copyto(dst, src)
def transfer_from_device_to_torch(self, e: np.ndarray):
return torch.from_numpy(e).clone()
def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> np.ndarray:
return tensor.detach().numpy()