diff --git a/README.md b/README.md index 22e07dffe..051456829 100644 --- a/README.md +++ b/README.md @@ -75,14 +75,6 @@ torch-mlir prediction View examples [here](docs/ltc_examples.md). -### Eager Mode - -Eager mode with TorchMLIR is a very experimental eager mode backend for PyTorch through the torch-mlir framework. -Effectively, this mode works by compiling operator by operator as the NN is eagerly executed by PyTorch. -This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported operator). -A simple example can be found at [eager_mode.py](examples/eager_mode.py). -A ResNet18 example can be found at [eager_mode_resnet18.py](examples/eager_mode_resnet18.py). - ## Repository Layout The project follows the conventions of typical MLIR-based projects: diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 4cc7535f1..ff9d5aacf 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -262,9 +262,6 @@ function test_in_tree() { echo ":::: Run Linalg e2e integration tests" python -m e2e_testing.main --config=linalg -v - echo ":::: Run eager_mode e2e integration tests" - python -m e2e_testing.main --config=eager_mode -v - echo ":::: Run MHLO e2e integration tests" python -m e2e_testing.main --config=mhlo -v diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 0649562fc..40a164007 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -20,7 +20,6 @@ from torch_mlir_e2e_test.configs import ( NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, - EagerModeTestConfig, TorchDynamoTestConfig, ) @@ -28,14 +27,14 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackend from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET +from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo'] + config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'lazy_tensor_core', 'torchdynamo'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -47,7 +46,6 @@ Meaning of options: "tosa": run through torch-mlir's default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). -"eager_mode": run through torch-mlir's eager mode frontend, using Linalg-on-Tensors for execution. "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors. 
''') @@ -91,9 +89,6 @@ def main(): elif args.config == 'torchscript': config = TorchScriptTestConfig() xfail_set = {} - elif args.config == 'eager_mode': - config = EagerModeTestConfig() - xfail_set = EAGER_MODE_XFAIL_SET elif args.config == 'lazy_tensor_core': config = LazyTensorCoreTestConfig() xfail_set = LTC_XFAIL_SET diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 14b5a9af4..5336581bd 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -14,13 +14,6 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS -EAGER_MODE_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { - # RefBackend fails for some reason. - # These tests pass in the regular RefBackend flow, so it's unclear - # why they fail here. - "Matmul_vecmat", -} - TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors diff --git a/examples/eager_mode.py b/examples/eager_mode.py deleted file mode 100644 index 7382dbfcf..000000000 --- a/examples/eager_mode.py +++ /dev/null @@ -1,34 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import torch - -from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor - -torch_a = torch.randn(5, requires_grad=True) -torch_b = torch.randn(5, requires_grad=True) - -torch_c = torch_a + torch_b -torch_d = torch_a * torch_b -torch_e = torch_c / torch_d -torch_loss = torch_e.sum() -print("PyTorch loss: ", torch_loss) - -torch_loss.backward() -print("PyTorch grad a: ", torch_a.grad) -print("PyTorch grad b: ", torch_b.grad) - -a = TorchMLIRTensor(torch_a) -b = TorchMLIRTensor(torch_b) - -c = a + b -d = a * b -e = c / d -loss = e.sum() -print("Torch-MLIR loss: ", loss) - -loss.backward() -print("Torch-MLIR grad a: ", a.grad) -print("Torch-MLIR grad b: ", b.grad) diff --git a/examples/eager_mode_resnet18.py b/examples/eager_mode_resnet18.py deleted file mode 100644 index e08eb4935..000000000 --- a/examples/eager_mode_resnet18.py +++ /dev/null @@ -1,89 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. 
- -import sys - -import requests -import torch -import torchvision.models as models -from PIL import Image -from torchvision import transforms - -from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor - - -def load_and_preprocess_image(url: str): - headers = { - 'User-Agent': - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36' - } - img = Image.open(requests.get(url, headers=headers, - stream=True).raw).convert("RGB") - # preprocessing pipeline - preprocess = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) - img_preprocessed = preprocess(img) - return torch.unsqueeze(img_preprocessed, 0) - - -def load_labels(): - classes_text = requests.get( - "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", - stream=True, - ).text - labels = [line.strip() for line in classes_text.splitlines()] - return labels - - -def top3_possibilities(res): - _, indexes = torch.sort(res, descending=True) - percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 - top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] - return top3 - - -def predictions(torch_func, img, labels): - golden_prediction = top3_possibilities(torch_func(img)) - print("PyTorch prediction") - print(golden_prediction) - prediction = top3_possibilities(torch_func(TorchMLIRTensor(img))) - print("torch-mlir prediction") - print(prediction) - - -class ResNet18Module(torch.nn.Module): - def __init__(self): - super().__init__() - self.resnet = models.resnet18(pretrained=True) - self.train(False) - - def forward(self, img): - return self.resnet.forward(img) - - -class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.s = ResNet18Module() - - def forward(self, x): - return self.s.forward(x) - - -image_url = ( - "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" -) - -print("load image from " + image_url, file=sys.stderr) -img = load_and_preprocess_image(image_url) -labels = load_labels() - -test_module = TestModule() -predictions(test_module.forward, img, labels) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 2167abd71..fbc29968f 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -102,12 +102,6 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) add_subdirectory(torch_mlir_e2e_test) endif() -################################################################################ -# Eager mode -################################################################################ - -add_subdirectory(torch_mlir/eager_mode) - ################################################################################ # Custom op example # Required for running the update_torch_ods.sh and update_shape_lib.sh scripts. diff --git a/python/test/eager_mode/build_script_function.py b/python/test/eager_mode/build_script_function.py deleted file mode 100644 index ea25f49e6..000000000 --- a/python/test/eager_mode/build_script_function.py +++ /dev/null @@ -1,320 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. 
- -# RUN: %PYTHON %s | FileCheck %s - - -import torch - -from framework import run_test -from torch_mlir.eager_mode.ir_building import build_ts_script_function - - -# CHECK: graph(%[[A1:.*]] : Tensor, -# CHECK: %[[A2:.*]] : Tensor, -# CHECK: %[[A3:.*]] : Tensor): -# CHECK: %[[A4:.*]] : int = prim::Constant[value=1]() -# CHECK: %[[A5:.*]] : int = prim::Constant[value=1]() -# CHECK: %[[A0:.*]] : Tensor = aten::addmm(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]]) -# CHECK: return (%[[A0]]) -# ----- -# CHECK: PASS - simple -@run_test -def simple(): - target = torch.ops.aten.addmm.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - mat1=torch.randn(1, 3, 32, 32), - mat2=torch.randn(1, 3, 32, 32), - beta=1, - alpha=1, - ) - - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: graph(%[[B1:.*]] : Tensor, -# CHECK: %[[B2:.*]] : Tensor, -# CHECK: %[[B3:.*]] : Tensor): -# CHECK: %[[B4:.*]] : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %[[B5:.*]] : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %[[B6:.*]] : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %[[B7:.*]] : bool = prim::Constant[value=0]() -# CHECK: %[[B8:.*]] : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %[[B9:.*]] : int = prim::Constant[value=1]() -# CHECK: %[[B0:.*]] : Tensor = aten::convolution(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[B9]]) -# CHECK: return (%[[B0]]) -# ----- -# CHECK: PASS - handle_optional_tensor_input -@run_test -def handle_optional_tensor_input(): - target = torch.ops.aten.convolution.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - weight=torch.randn(3, 3, 3, 3), - bias=torch.randn(3), - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - ) - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: FAIL - fail_not_enough_args -# CHECK: Errors: 'groups' -@run_test -def fail_not_enough_args(): - target = torch.ops.aten.convolution.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - weight=torch.randn(3, 3, 3, 3), - bias=torch.randn(3), - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - # Missing groups=1, - ) - build_ts_script_function(target._schema, kwargs) - - -# CHECK: graph(%input : Tensor, -# CHECK: %weight : Tensor, -# CHECK: %bias : Tensor): -# CHECK: %4 : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %5 : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %6 : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %7 : bool = prim::Constant[value=0]() -# CHECK: %8 : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %9 : int = prim::Constant[value=1]() -# CHECK: %0 : Tensor = aten::convolution(%input, %weight, %bias, %4, %5, %6, %7, %8, %9) -# CHECK: return (%0) -# ----- -# CHECK: PASS - simple_kwargs -@run_test -def simple_kwargs(): - target = torch.ops.aten.convolution.default - script_fun1 = build_ts_script_function( - target._schema, - dict( - input=torch.randn(1, 3, 32, 32), - weight=torch.randn(3, 3, 3, 3), - bias=torch.randn(3), - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - ), - ) - - print(script_fun1.graph) - - -# CHECK: graph(%[[C2:.*]] : Tensor): -# CHECK: %[[C3:.*]] : int[] = prim::Constant[value=[3, 3]]() -# CHECK: %[[C4:.*]] : NoneType = prim::Constant() -# CHECK: %[[C5:.*]] : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %[[C6:.*]] : int[] = 
prim::Constant[value=[1, 1]]() -# CHECK: %[[C7:.*]] : bool = prim::Constant[value=0]() -# CHECK: %[[C0:.*]] : Tensor, %[[C1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]]) -# CHECK: return (%[[C0]], %[[C1]]) -# ----- -# CHECK: PASS - handle_empty_lists -@run_test -def handle_empty_lists(): - target = torch.ops.aten.max_pool2d_with_indices.default - # print(target._schema) - input = torch.randn((1, 3, 32, 32)) - kwargs = dict( - input=input, - kernel_size=[3, 3], - stride=[], - padding=[0, 0], - dilation=[1, 1], - ceil_mode=False, - ) - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: graph(%[[D2:.*]] : Tensor): -# CHECK: %[[D3:.*]] : int[] = prim::Constant[value=[3, 3]]() -# CHECK: %[[D4:.*]] : NoneType = prim::Constant() -# CHECK: %[[D5:.*]] : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %[[D6:.*]] : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %[[D7:.*]] : bool = prim::Constant[value=0]() -# CHECK: %[[D0:.*]] : Tensor, %[[D1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[D2]], %[[D3]], %[[D4]], %[[D5]], %[[D6]], %[[D7]]) -# CHECK: return (%[[D0]], %[[D1]]) -# ----- -# CHECK: PASS - handle_nones -@run_test -def handle_nones(): - target = torch.ops.aten.max_pool2d_with_indices.default - # print(target._schema) - kwargs = dict( - input=torch.randn((1, 3, 32, 32)), - kernel_size=[3, 3], - stride=None, - padding=[0, 0], - dilation=[1, 1], - ceil_mode=False, - ) - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: graph(%[[E1:.*]] : Tensor, -# CHECK: %[[E2:.*]] : Tensor, -# CHECK: %[[E3:.*]] : Tensor): -# CHECK: %[[E4:.*]] : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %[[E5:.*]] : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %[[E6:.*]] : int[] = prim::Constant[value=[1, 1]]() -# CHECK: %[[E7:.*]] : bool = prim::Constant[value=0]() -# CHECK: %[[E8:.*]] : int[] = prim::Constant[value=[0, 0]]() -# CHECK: %[[E9:.*]] : int = prim::Constant[value=1]() -# CHECK: %[[E0:.*]] : Tensor = aten::convolution(%[[E1]], %[[E2]], %[[E3]], %[[E4]], %[[E5]], %[[E6]], %[[E7]], %[[E8]], %[[E9]]) -# CHECK: return (%[[E0]]) -# ----- -# CHECK: PASS - handle_optional_tensors -@run_test -def handle_optional_tensors(): - target = torch.ops.aten.convolution.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - weight=torch.randn(3, 3, 3, 3), - bias=torch.randn(3), - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - transposed=False, - output_padding=[0, 0], - groups=1, - ) - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: graph(%[[F1:.*]] : Tensor): -# CHECK: %[[F2:.*]] : NoneType = prim::Constant() -# CHECK: %[[F3:.*]] : NoneType = prim::Constant() -# CHECK: %[[F4:.*]] : NoneType = prim::Constant() -# CHECK: %[[F5:.*]] : NoneType = prim::Constant() -# CHECK: %[[F6:.*]] : NoneType = prim::Constant() -# CHECK: %[[F0:.*]] : Tensor = aten::ones_like(%[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], %[[F6]]) -# CHECK: return (%[[F0]]) -# ----- -# CHECK: PASS - handle_ones_like -@run_test -def handle_ones_like(): - target = torch.ops.aten.ones_like.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - dtype=None, - layout=None, - device=None, - pin_memory=None, - memory_format=None, - ) - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: graph(%[[G3:.*]] : Tensor, -# CHECK: %[[G4:.*]] : Tensor, -# CHECK: %[[G5:.*]] : Tensor): -# CHECK: 
%[[G6:.*]] : NoneType = prim::Constant() -# CHECK: %[[G7:.*]] : NoneType = prim::Constant() -# CHECK: %[[G8:.*]] : bool = prim::Constant[value=0]() -# CHECK: %[[G9:.*]] : float = prim::Constant[value=1.]() -# CHECK: %[[G10:.*]] : float = prim::Constant[value=1.]() -# CHECK: %[[G0:.*]] : Tensor, %[[G1:.*]] : Tensor, %[[G2:.*]] : Tensor = aten::native_batch_norm(%[[G3]], %[[G4]], %[[G5]], %[[G6]], %[[G7]], %[[G8]], %[[G9]], %[[G10]]) -# CHECK: return (%[[G0]], %[[G1]], %[[G2]]) -# ----- -# CHECK: PASS - handle_multiple_outputs -@run_test -def handle_multiple_outputs(): - target = torch.ops.aten.native_batch_norm.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - weight=torch.randn(1, 3, 32, 32), - bias=torch.randn(1, 3, 32, 32), - running_mean=None, - running_var=None, - training=False, - momentum=1.0, - eps=1.0 - ) - - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) - - -# CHECK: f -# CHECK: PASS - check_legal_name -@run_test -def check_legal_name(): - target = torch.ops.aten.native_batch_norm.default - kwargs = dict( - input=torch.randn(1, 3, 32, 32), - weight=torch.randn(1, 3, 32, 32), - bias=torch.randn(1, 3, 32, 32), - running_mean=None, - running_var=None, - training=False, - momentum=1.0, - eps=1.0 - ) - - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.name) - - -# CHECK: graph(%[[H3:.*]] : Tensor, -# CHECK: %[[H4:.*]] : Tensor, -# CHECK: %[[H5:.*]] : Tensor, -# CHECK: %[[H6:.*]] : Tensor, -# CHECK: %[[H7:.*]] : Tensor, -# CHECK: %out : Tensor, -# CHECK: %save_mean : Tensor, -# CHECK: %save_invstd : Tensor): -# CHECK: %[[H8:.*]] : bool = prim::Constant[value=0]() -# CHECK: %[[H9:.*]] : float = prim::Constant[value=0.10000000000000001]() -# CHECK: %[[H10:.*]] : float = prim::Constant[value=0.0001]() -# CHECK: %[[H0:.*]] : Tensor, %[[H1:.*]] : Tensor, %[[H2:.*]] : Tensor = aten::native_batch_norm(%[[H3]], %[[H4]], %[[H5]], %[[H6]], %[[H7]], %[[H8]], %[[H9]], %[[H10]], %out, %save_mean, %save_invstd) -# CHECK: return (%[[H0]], %[[H1]], %[[H2]]) -# ----- -# CHECK: PASS - correctly_order_kwargs -@run_test -def correctly_order_kwargs(): - target = torch.ops.aten.native_batch_norm.out - - input = torch.randn(2, 5, 2, 3) - running_mean = torch.randn(5) - running_var = torch.randn(5) - - kwargs = dict( - input=torch.randn(2, 5, 2, 3), - weight=torch.randn(5), - bias=torch.randn(5), - running_mean=running_mean, - running_var=running_var, - training=False, - momentum=0.1, - eps=0.0001, - out=torch.empty_like(input), - save_mean=torch.empty_like(running_mean), - save_invstd=torch.empty_like(running_var), - ) - - script_fun = build_ts_script_function(target._schema, kwargs) - print(script_fun.graph) diff --git a/python/test/eager_mode/framework.py b/python/test/eager_mode/framework.py deleted file mode 100644 index 4395ce14c..000000000 --- a/python/test/eager_mode/framework.py +++ /dev/null @@ -1,23 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. 
-
-# RUN: true
-
-
-def run_test(*args, XPASS=False, XFAIL=False):
-    def _run_test(test):
-        test_name = test.__name__
-        try:
-            test()
-            print(("X" if XPASS else "") + f"PASS - {test_name}")
-        except Exception as e:
-            print(("X" if XFAIL else "") + f"FAIL - {test_name}")
-            print("Errors: ", e)
-        print()
-
-    if len(args):
-        _run_test(args[0])
-    else:
-        return _run_test
diff --git a/python/test/eager_mode/normalize_args_kwargs.py b/python/test/eager_mode/normalize_args_kwargs.py
deleted file mode 100644
index feea0104a..000000000
--- a/python/test/eager_mode/normalize_args_kwargs.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-# RUN: %PYTHON %s | FileCheck %s
-
-
-import torch
-
-from framework import run_test
-from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
-
-
-# CHECK: PASS - should_normalize
-@run_test
-def should_normalize():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    input = torch.randn((1, 3, 32, 32))
-    kwargs = {"kernel_size": [3, 3]}
-    golden = {
-        "kernel_size": [3, 3],
-        # This is due to the schema for max_pool2d_with_indices defining
-        # the stride arg as int[2] stride=[].
-        "stride": [],
-        "padding": [0, 0],
-        "dilation": [1, 1],
-        "ceil_mode": False,
-    }
-
-    new_kwargs = normalize_args_kwargs(target, (input,), kwargs)
-    assert torch.allclose(new_kwargs["input"], input)
-    for k, v in new_kwargs.items():
-        if k == "input": continue
-        assert v == golden[k]
-
-
-# CHECK: FAIL - shouldnt_normalize1
-# CHECK: Errors: missing a required argument: 'kernel_size'
-@run_test
-def shouldnt_normalize1():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {"stride": []}
-    normalize_args_kwargs(target, args, kwargs)
-
-
-# These next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
-# I.e., they should fail but in fact they pass because of the upstream bug.
-# The reason for the bug is a fast path branch in operator_schemas.normalize_function
-# that doesn't do rigorous type checking, and hence lets type mismatches slip through.
-# TODO(max): change these to FAIL when the upstream bug is fixed.
-
-# CHECK: XPASS - shouldnt_normalize2
-@run_test(XPASS=True)
-def shouldnt_normalize2():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {"kernel_size": []}
-    normalize_args_kwargs(target, args, kwargs)
-
-
-# CHECK: XPASS - shouldnt_normalize3
-@run_test(XPASS=True)
def shouldnt_normalize3():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {"kernel_size": [3, 3], "padding": None}
-    normalize_args_kwargs(target, args, kwargs)
diff --git a/python/torch_mlir/eager_mode/CMakeLists.txt b/python/torch_mlir/eager_mode/CMakeLists.txt
deleted file mode 100644
index 2b773a09f..000000000
--- a/python/torch_mlir/eager_mode/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-#-------------------------------------------------------------------------------
-# Subdirectories
-#-------------------------------------------------------------------------------
-
-## Declare the sources of the Python module.
-
-declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
-  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-  ADD_TO_PARENT TorchMLIRPythonSources
-  SOURCES_GLOB eager_mode/*.py lazytensor/*.py
-)
diff --git a/python/torch_mlir/eager_mode/__init__.py b/python/torch_mlir/eager_mode/__init__.py
deleted file mode 100644
index 96929dbf8..000000000
--- a/python/torch_mlir/eager_mode/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import os
-
-EAGER_MODE_DEBUG = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')
diff --git a/python/torch_mlir/eager_mode/ir_building.py b/python/torch_mlir/eager_mode/ir_building.py
deleted file mode 100644
index 8f5ec7747..000000000
--- a/python/torch_mlir/eager_mode/ir_building.py
+++ /dev/null
@@ -1,359 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-"""
-Translator from torch.jit.ScriptFunction to MLIR.
-
-
-The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
-`torch` dialect.
-
-The expected use of this module is to create an instance of one of the classes below, and then call the
-`to_mlir` method to generate the MLIR representation of the type.
-
-Information about what types are supported by each class can be found in docstrings of each of the classes.
-
-In addition this module defines a function that takes a torch.jit.ScriptFunction and converts it into an MLIR module.
-
-The expected use for this module is to use the function
-`build_module(jit_function: torch.jit.ScriptFunction, annotation: Annotation) -> ir.Module`
-to convert the TorchScript function into MLIR using the `torch` dialect.
-"""
-
-import abc
-import re
-from typing import Any, Optional, Iterable, Dict
-from typing import Union
-
-import numpy as np
-import torch
-import torch._C
-import torch.jit
-from torch._ops import OpOverload
-
-from torch_mlir import ir
-from torch_mlir.dialects.func import FuncOp
-from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
-
-
-class TorchMlirType(abc.ABC):
-    """
-    A `TorchMlirType` is an object that produces MLIR
-    types in the `torch` dialect. The only requirement
-    for a class to be a subclass of `TorchMlirType` is
-    to define a `to_mlir(self, ir.Context) -> ir.Type`.
-    Each class is allowed to have different types of
-    __init__ methods depending on the information they
-    require to produce the given MLIR representation.
-    """
-
-    @abc.abstractmethod
-    def to_mlir(self, context: ir.Context) -> ir.Type:
-        pass
-
-
-class TorchTensorTypeError(Exception):
-    def __init__(self, value: str):
-        super().__init__()
-        self.value = value
-
-    def __str__(self) -> str:
-        return self.value
-
-
-class TorchTensorType(TorchMlirType):
-    """
-    This class is used to generate types of the form
-    !torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
-    where SHAPE is a list representing the shape of the tensor,
-    and DTYPE is an MLIR data type.
- """ - - def __init__( - self, - *, - shape: Optional[Iterable[Optional[int]]] = None, - dtype: Optional[torch.dtype] = None, - ): - self.shape = shape - self.dtype = dtype - - if dtype is None and shape is not None: - err = "If shape is specified, dtype must also be specified" - raise TorchTensorTypeError(err) - - def __str__(self): - return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})" - - def to_mlir(self, context: ir.Context) -> ir.Type: - if self.dtype is None: - return ir.Type.parse("!torch.tensor", context=context) - - shape_asm = self._shape_to_mlir_asm() - dtype_asm = self._dtype_to_mlir_asm() - return ir.Type.parse( - f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context - ) - - def _shape_to_mlir_asm(self) -> str: - if self.shape is None: - return "*" - - str_sizes = map(lambda x: "?" if x is None else str(x), self.shape) - return f'[{",".join(str_sizes)}]' - - def _dtype_to_mlir_asm(self) -> str: - if self.dtype in [torch.float64]: - return "f64" - if self.dtype in [torch.float, torch.float32]: - return "f32" - if self.dtype in [torch.int, torch.int32]: - return "si32" - if self.dtype in [torch.int64]: - return "si64" - if self.dtype in [torch.bool]: - return "i1" - - raise NotImplementedError(f"Unsupported dtype: {self.dtype}") - - -class TorchNnModuleType(TorchMlirType): - """This class is used to generate types for `!torch.nn.Module`s.""" - - def __init__(self, module_name: str): - self.module_name = module_name - - def __str__(self): - return "torch.nn.Module" - - def to_mlir(self, context: ir.Context) -> ir.Type: - return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context) - - -class PythonType(TorchMlirType): - """ - This class is used to convert regular Python types - into their corresponding `torch` dialect representation. - The list of supported types can be found in the dictionary - `_type_to_asm_dict`. - """ - - _type_to_asm_dict = { - bool: "!torch.bool", - int: "!torch.int", - type(None): "!torch.none", - } - - def __init__(self, type_: Any): - self.type_ = type_ - - def __str__(self): - return str(self.type_) - - def to_mlir(self, context: ir.Context) -> ir.Type: - asm = self._type_to_asm_dict.get(self.type_) - if asm is None: - raise NotImplementedError(f"Unsupported type: {self.type_}") - return ir.Type.parse(asm, context=context) - - -# TODO: This functionality should be incorporated into ModuleBuilder.import_function. 
-class Annotation: - def __init__(self, types: Iterable[Union[TorchTensorType, type]]): - self.types = list( - map(lambda t: PythonType(t) if isinstance(t, type) else t, types) - ) - - def __str__(self): - result = f"Annotation instance with {len(self.types)} types\n" - for e, type_ in enumerate(self.types): - result += f" Type of argument {e + 1}: {str(type_)}\n" - return result - - def __iter__(self): - return iter(self.types) - - -class AnnotationConverter: - @staticmethod - def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr: - dict_attrs = [] - for type_ in annotation: - if not isinstance(type_, TorchTensorType): - dict_attrs.append(ir.DictAttr.get({}, context=context)) - continue - - ir_type = type_.to_mlir(context) - with context: - type_attr = ir.TypeAttr.get(ir_type) - dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr}) - dict_attrs.append(dict_attr) - - return ir.ArrayAttr.get(dict_attrs, context=context) - - -def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]: - with module.context: - name_attr = ir.StringAttr.get(name) - for op in module.body.operations: - if isinstance(op, FuncOp) and op.name == name_attr: - # Add name of torch op as debug_module_name so that - # run_pipeline_with_repro_report can use it. - module.operation.attributes["torch.debug_module_name"] = name_attr - return op - - return None - - -def is_tensor_type(typ: torch._C.Type): - return typ.isSubtypeOf(torch.TensorType.get()) or ( - isinstance(typ, torch.OptionalType) - and typ.getElementType().isSubtypeOf(torch._C.TensorType.get()) - ) - - -def is_list_of_tensors_type(typ: torch._C.Type): - return isinstance(typ, torch.ListType) and is_tensor_type(typ.getElementType()) - - -name_mangle_regex = re.compile("[^a-zA-Z0-9]") - - -def build_ts_script_function( - schema: torch._C.FunctionSchema, kwargs: Dict[str, Any] -) -> torch.jit.ScriptFunction: - """Build a torch.jit.ScriptFunction that corresponds to the schema. - - Constants are inlined for the purposes of invalidating the compile cache when they change. - - Parameters - ---------- - schema: torch._C.FunctionSchema - PyTorch's representation for ops, contains type information needed for inlining constants into the TS graph. - kwargs: Dict - A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params). - - Returns - ------- - torch.jit.ScriptFunction - Fully specialized (all constants) TS graph whose only arguments are tensors. - """ - - # Creates empty TS graph. - graph = torch._C.Graph() - # Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point. - node = graph.insertNode(graph.create(schema.name, len(schema.returns))) - # Associate graph inputs/outputs with node inputs/outputs. - graph_inputs = [] - for arg in schema.arguments: - arg_name = arg.name if arg.name != "self" else "input" - - # If arg is a flattened list of tensors, such as in the case of torch.cat - # then add each element of the list to the graph corresponding to arg - # and insert a ListConstruct to function as input to the op. 
-        if is_list_of_tensors_type(arg.type):
-            inps = []
-            for kwarg in [
-                kwarg for kwarg in kwargs if f"{arg_name}_flattened" in kwarg
-            ]:
-                inp = graph.addInput()
-                el_typ = arg.type.getElementType()
-                if isinstance(el_typ, torch.OptionalType):
-                    el_typ = el_typ.getElementType()
-                inp.setType(el_typ)
-                inp.setDebugName(kwarg)
-                inps.append(inp)
-                graph_inputs.append(kwarg)
-            list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
-            list_cons.moveBefore(node)
-            inp = list_cons.output()
-            inp.setType(torch.ListType.ofTensors())
-        # If arg is a tensor, then add input to the graph corresponding to arg.
-        elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
-            inp = graph.addInput()
-            if isinstance(arg.type, torch.OptionalType):
-                el_typ = arg.type.getElementType()
-            else:
-                el_typ = arg.type
-            inp.setType(el_typ)
-            inp.setDebugName(arg_name)
-            graph_inputs.append(arg_name)
-        # If arg is a constant, inline (at the top of the graph).
-        else:
-            val = kwargs[arg_name]
-            if val == []:
-                # Some ops have empty list default values for args
-                # (such as aten::max_pool2d_with_indices with int[2] stride=[]),
-                # but graph.insertConstant doesn't recognize [] as an empty list IValue.
-                # This might be an upstream bug but there doesn't seem to be a way to
-                # build a prim::ListConstruct list that's empty.
-                val = None
-            inp = graph.insertConstant(val)
-            inp.node().moveBefore(node)
-
-        node.addInput(inp)
-
-    # Reorder graph inputs to match kwargs.
-    permutes = [
-        {inp: i for i, inp in enumerate(graph_inputs)}[kwarg]
-        for kwarg in [kwarg for kwarg in kwargs if kwarg in graph_inputs]
-    ]
-    graph.permuteInputs(permutes)
-
-    if node.hasMultipleOutputs():
-        for outp in node.outputs():
-            graph.registerOutput(outp)
-    else:
-        graph.registerOutput(node.output())
-
-    fn = torch._C._create_function_from_graph(
-        f"{name_mangle_regex.sub('', str(graph))}", graph
-    )
-    return fn
-
-
-def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
-    """Translate input function into an MLIR module in the `torch` dialect.
-
-    Parameters
-    ----------
-    op: OpOverload
-        Callable from the torch.ops.aten module/namespace that has a _schema field.
-    kwargs: Dict
-        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
-
-    Returns
-    -------
-    ir.Module
-        Translation of the input module into an MLIR module.
-    """
-
-    # The assert here is to catch tensor shapes that have size 0 dimensions, such as those produced in
-    # the course of evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic.
-    # Such 0 size dimensions fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887
-    annotations = []
-    for arg_name, arg in kwargs.items():
-        if isinstance(arg, torch.Tensor):
-            assert np.prod(arg.shape) != 0, f"{arg_name} has invalid shape {arg.shape}"
-            annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
-    annotations = tuple(annotations)
-
-    script_fun = build_ts_script_function(op._schema, kwargs)
-    assert len(annotations) == len(
-        list(script_fun.graph.inputs())
-    ), "Number of annotations and number of graph inputs differs."
-
-    mb = ModuleBuilder()
-    mb.import_function(script_fun)
-
-    func_op = get_func_op_with_name(mb.module, script_fun.name)
-    assert (
-        func_op is not None
-    ), "Unable to find FuncOp in new module. 
Make sure function was imported correctly into ModuleBuilder" - - func_annotation = Annotation(annotations) - arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context) - func_op.attributes["arg_attrs"] = arg_attrs - - return mb.module diff --git a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py deleted file mode 100644 index cc7484b93..000000000 --- a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py +++ /dev/null @@ -1,111 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. -from __future__ import annotations - -from typing import Any, Callable, Tuple -from typing import Dict - -import torch -from torch.fx import immutable_collections -from torch.fx.operator_schemas import ( - _torchscript_schema_to_signature, - _args_kwargs_to_normalized_args_kwargs, -) -from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops - -from torch_mlir.dialects import torch as torch_dialect - -OP_REGISTRY = {op["name"]: op for op in get_registered_ops()} -SUPPORTED_OPS = frozenset( - [ - member.OPERATION_NAME - for member in vars(torch_dialect).values() - if hasattr(member, "OPERATION_NAME") - ] -) - - -class UnsupportedByTorchMlirEagerMode(Exception): - def __init__(self, value: str): - super().__init__() - self.value = value - - def __str__(self) -> str: - return self.value - - -def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]): - """Fill in default values for optional args, which are dependent on the schema.""" - sig = _torchscript_schema_to_signature(target._schema) - _, new_kwargs = _args_kwargs_to_normalized_args_kwargs( - sig, args, kwargs, normalize_to_only_use_kwargs=True - ) - if "self" in new_kwargs: - new_kwargs["input"] = new_kwargs.pop("self") - - # Flatten lists of args for ops that takes lists, such as torch.cat. - to_remove = set() - to_add = {} - for k, v in new_kwargs.items(): - if isinstance(v, (tuple, list)) and len(v) and isinstance(v[0], torch.Tensor): - to_remove.add(k) - for i, vv in enumerate(v): - to_add[f"{k}_flattened_{i}"] = vv - - for rem in to_remove: - del new_kwargs[rem] - new_kwargs.update(**to_add) - - # Sort here in order to have consistency across TS graph and - # MLIR module. - sorted_kwargs = dict(sorted(new_kwargs.items())) - return immutable_collections.immutable_dict(sorted_kwargs) - - -def get_registered_op(op): - registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)] - return registered_op - - -def check_get_aliased_arg(func: Callable,): - """Write back to mutable args that aren't properly handled otherwise. - - Because of how we pass values to the backend we don't currently support ops that mutate operands. - That includes both inplace variants and outplace variants. Additionally, Torch-MLIR, - as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which - we can for these special cases. Hence, the solution is to manually write back to the same operand that the - conventional pytorch op variant would write to. - - Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that - mutate running_mean and running_var). We don't currently handle those. 
- """ - - registered_op = get_registered_op(func) - if not registered_op["is_mutable"]: - return None - - if len(registered_op["returns"]) > 1: - raise UnsupportedByTorchMlirEagerMode( - "TorchMLIR doesn't handle multiple aliased returns yet." - ) - - aliased_arg = next( - arg - for arg in registered_op["arguments"] - if "alias_info" in arg and arg["alias_info"]["is_write"] - ) - assert ( - "alias_info" in registered_op["returns"][0] - and registered_op["returns"][0]["alias_info"]["is_write"] - and len(registered_op["returns"][0]["alias_info"]["after"]) == 1 - and registered_op["returns"][0]["alias_info"]["after"][0] - ) - assert ( - len(aliased_arg["alias_info"]["after"]) == 1 - and aliased_arg["alias_info"]["after"][0] - == registered_op["returns"][0]["alias_info"]["after"][0] - ) - - return aliased_arg["name"] if aliased_arg["name"] != "self" else "input" diff --git a/python/torch_mlir/eager_mode/torch_mlir_eager_backend.py b/python/torch_mlir/eager_mode/torch_mlir_eager_backend.py deleted file mode 100644 index 0cac7cc46..000000000 --- a/python/torch_mlir/eager_mode/torch_mlir_eager_backend.py +++ /dev/null @@ -1,102 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import abc -from dataclasses import dataclass -from typing import TypeVar, Tuple, Callable, List, Dict, Any - -import torch - -from torch_mlir._mlir_libs._mlir.ir import Module - -# TODO: This might need to be an ABC too, such as -# to support finding the backend that created the tensor. -DeviceTensor = TypeVar("DeviceTensor") - - -@dataclass(frozen=True) -class TensorMetaData: - """A small container for metadata necessary for satisfying the pytorch dispatcher and other code (pytorch or - otherwise) that branches on these attributes. - - There is a lot of code in the PyTorch codebase that branches based on these attributes; the obvious ones here - are dtype, device, and requires_grad (necessary for autograd itself). There is ample warning from PyTorch that, - in principle, these should be as close as possible to true; see - https://github.com/albanD/subclass_zoo/blob/1566e038f03cd89ab3cc37e670a44e3c2bbc1897/trivial_tensors.py#L90-L92 - - The defaults (properties) simplify the api and seem to work after some testing but - might malfunction in unexpected ways. 
- # TODO: revisit these assumptions - """ - - size: Tuple[int] - dtype: torch.dtype - requires_grad: bool - - strides: Tuple[int] - storage_offset: int = 0 - layout: torch.layout = torch.strided - device: torch.device = torch.device("cpu") - - def __init__( - self, - size, - dtype, - requires_grad, - strides=None, - storage_offset=None, - layout=None, - device=None, - ): - super().__init__() - object.__setattr__(self, "size", size) - object.__setattr__(self, "dtype", dtype) - object.__setattr__(self, "requires_grad", requires_grad) - - object.__setattr__( - self, "strides", strides if strides is not None else len(size) * [0] - ) - object.__setattr__( - self, "storage_offset", storage_offset if storage_offset is not None else 0 - ) - object.__setattr__( - self, "layout", layout if layout is not None else torch.strided - ) - object.__setattr__( - self, "device", device if device is not None else torch.device("cpu") - ) - - -class TorchMLIREagerBackend(abc.ABC): - @abc.abstractmethod - def compile( - self, module: Module - ) -> Callable[[List[DeviceTensor]], List[DeviceTensor]]: - raise NotImplementedError - - @abc.abstractmethod - def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> DeviceTensor: - """Unwrap the backend representation in order to build a torch.Tensor.""" - raise NotImplementedError - - @abc.abstractmethod - def get_torch_metadata( - self, tensor: DeviceTensor, kwargs: Dict[str, Any] - ) -> TensorMetaData: - """Parse relevant tensor metadata from backend device array (e.g., shape, stride, layout) in order to build - wrapper tensor.""" - raise NotImplementedError - - @abc.abstractmethod - def transfer_from_device_to_torch(self, tensor: DeviceTensor) -> torch.Tensor: - """If compilation fails for some reason then device specific representations need to be munged into a - torch.Tensor representation. - """ - raise NotImplementedError - - @abc.abstractmethod - def copy_into(self, dst: DeviceTensor, src: DeviceTensor): - """This method is needed for things like handling aliased arguments.""" - raise NotImplementedError diff --git a/python/torch_mlir/eager_mode/torch_mlir_tensor.py b/python/torch_mlir/eager_mode/torch_mlir_tensor.py deleted file mode 100644 index a30f47a9e..000000000 --- a/python/torch_mlir/eager_mode/torch_mlir_tensor.py +++ /dev/null @@ -1,257 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. -import contextlib -import re -import traceback -import warnings -from typing import Any - -import torch -from torch.utils._pytree import tree_map - -from torch_mlir.eager_mode.ir_building import build_mlir_module -from torch_mlir.eager_mode.torch_mlir_dispatch import ( - UnsupportedByTorchMlirEagerMode, - normalize_args_kwargs, - check_get_aliased_arg, -) -from torch_mlir.eager_mode import EAGER_MODE_DEBUG -from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend - - -@contextlib.contextmanager -def no_dispatch(): - """Prevent infinite recursion in case accidentally calling a tensor method on a TorchMLIRTensor within - __torch_dispatch__.""" - - guard = torch._C._DisableTorchDispatch() - try: - yield - finally: - del guard - - -backend = EagerModeRefBackend() - -UNSUPPORTED_OPS = re.compile( - "|".join([ - # We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch. 
-        "detach",
-        # We don't handle _local_scalar_dense because it's just a way to unwrap a tensor that wraps a number.
-        "_local_scalar_dense",
-        # https://github.com/llvm/torch-mlir/issues/878
-        "_unsafe_view",
-        "view",
-    ])
-)
-
-
-class TorchMLIRTensor(torch.Tensor):
-    """This class serves the role of an abstract class with common functionality for dispatching through Torch-MLIR instead of aten.
-
-    It defers device specific behavior to device specific implementations. The deriving classes use the
-    make_bare_wrapper_subclass convenience method, adjacent here, and override __torch_dispatch__ in order to dispatch
-    through Torch-MLIR instead of aten. Backends are free to choose whatever representation of the buffers (i.e., `elem`)
-    and are expected to provide conversion mechanisms between their representation and torch.Tensor.
-
-    Here we only verify that inputs abide by currently supported features of Torch-MLIR (contiguous memory and
-    strided tensor layout) and build the mlir module. Importantly, we also recover from any malfunctions in the
-    deriving classes and dispatch back to conventional PyTorch.
-
-    More documentation on how the __torch_dispatch__ pattern works can be found in this forum post
-    https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
-    and this RFC
-    https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
-    and this repo with many examples
-    https://github.com/albanD/subclass_zoo
-    """
-
-    elem: Any
-
-    __slots__ = ["elem"]
-
-    def __new__(cls, elem, **kwargs):
-        """Wrap elem (which could be a torch.Tensor or otherwise) in a torch.Tensor subclass.
-
-        Critically, this method needs to parse relevant metadata from the device representation
-        (such as shape, striding, dtype, etc.) and translate it into torch conventions.
-
-        Deriving classes must provide a way to construct themselves from either their device specific representation
-        or torch.Tensor; the latter is to handle the case where we dispatch to PyTorch to recover from an error.
-        """
-        if kwargs.get("constructing_from_device_tensor", False):
-            tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
-            r = make_bare_wrapper_subclass(
-                cls=cls,
-                size=tensor_meta_data.size,
-                strides=tensor_meta_data.strides,
-                storage_offset=tensor_meta_data.storage_offset,
-                dtype=tensor_meta_data.dtype,
-                layout=tensor_meta_data.layout,
-                device=tensor_meta_data.device,
-                requires_grad=tensor_meta_data.requires_grad,
-            )
-            r.elem = elem
-        elif isinstance(elem, torch.nn.Parameter):
-            r = make_wrapper_subclass_from_torch_tensor(cls, elem.data, **kwargs)
-            r.elem = backend.transfer_from_torch_to_device(elem.detach().data)
-        elif isinstance(elem, torch.Tensor):
-            r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
-            r.elem = backend.transfer_from_torch_to_device(elem)
-        # This branch handles the case when a python scalar is passed to some op
-        # or is returned from some aten op, such as _local_scalar_dense.
-        elif isinstance(elem, (int, float, bool)):
-            return elem
-        else:
-            raise ValueError(f"Unknown element type: {type(elem)}")
-
-        return r
-
-    def __repr__(self):
-        if self.grad_fn:
-            return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
-        else:
-            return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__})"
-
-    @classmethod
-    def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
-        requires_grad = check_requires_grad(*args, **kwargs)
-        try:
-            with no_dispatch():
-                if hasattr(func, "op_name"):
-                    op_name = func.op_name
-                elif hasattr(func, "__name__"):
-                    # Handle builtin_function_or_method.
-                    op_name = func.__name__
-                else:
-                    raise RuntimeError(f"op {func} has no name")
-
-            requires_grad = requires_grad and "view" not in op_name
-
-            if UNSUPPORTED_OPS.match(op_name):
-                raise UnsupportedByTorchMlirEagerMode(op_name)
-
-            if not hasattr(func, "_schema"):
-                raise RuntimeError(f"op {func} has no schema.")
-
-            normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
-
-            if "layout" in normalized_kwargs and normalized_kwargs[
-                "layout"
-            ] not in {0, None}:
-                raise UnsupportedByTorchMlirEagerMode(
-                    f"{normalized_kwargs['layout']} layout not supported."
-                )
-            if "memory_format" in normalized_kwargs and normalized_kwargs[
-                "memory_format"
-            ] not in {0, None}:
-                raise UnsupportedByTorchMlirEagerMode(
-                    f"{normalized_kwargs['memory_format']} memory format not supported."
-                )
-            eager_module = build_mlir_module(func, normalized_kwargs)
-            device_tensor_args = [
-                kwarg.elem
-                for _, kwarg in normalized_kwargs.items()
-                if isinstance(kwarg, cls)
-            ]
-            assert len(eager_module.body.operations[0].arguments) == len(
-                device_tensor_args
-            ), "Number of parameters and number of arguments differs."
-            op_mlir_backend_callable = backend.compile(eager_module)
-            out = op_mlir_backend_callable(*device_tensor_args)
-            out = tree_map(
-                lambda x: cls(
-                    x, requires_grad=requires_grad, constructing_from_device_tensor=True
-                ),
-                out,
-            )
-        except Exception as e:
-            if EAGER_MODE_DEBUG:
-                warnings.warn(traceback.format_exc())
-            if isinstance(e, UnsupportedByTorchMlirEagerMode):
-                warnings.warn(
-                    f"Couldn't use TorchMLIR eager because of a current incompatibility: *{str(e)}*; running through PyTorch eager."
-                )
-            else:
-                warnings.warn(
-                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
-                    f"running through PyTorch eager. 
Please file an issue at https://github.com/llvm/torch-mlir/issues" - ) - - with no_dispatch(): - unwrapped_args = tree_map(cls.unwrap, args) - unwrapped_kwargs = tree_map(cls.unwrap, kwargs) - out = func(*unwrapped_args, **unwrapped_kwargs) - - out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out) - - maybe_aliased_arg_name = check_get_aliased_arg(func) - if maybe_aliased_arg_name is not None: - backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem) - - return out - - @classmethod - def unwrap(cls, e): - """Unwrap the TorchMLIRTensor representation in order to access the actual device specific representation.""" - if isinstance(e, cls): - return backend.transfer_from_device_to_torch(e.elem) - return e - - -def check_requires_grad(*args, **kwargs): - requires_grad = False - - def check_grad(e): - nonlocal requires_grad - if isinstance(e, TorchMLIRTensor): - requires_grad |= e.requires_grad - - tree_map(check_grad, args) - tree_map(check_grad, kwargs) - - return requires_grad - - -def make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs): - """Convenience method that parse out relevant metadata from a torch.Tensor, in order to produce - a wrapper subclass. - - NB: this convenience method does not set that `elem` attribute of the subclass, as that is the responsibility - of the device specific implementation. - """ - r = make_bare_wrapper_subclass( - cls=cls, - size=elem.size(), - strides=elem.stride(), - storage_offset=elem.storage_offset(), - dtype=elem.dtype, - layout=elem.layout, - device=elem.device, - # Only float tensors can have gradients. - requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64} - and (kwargs.get("requires_grad", False) or elem.requires_grad), - ) - return r - - -def make_bare_wrapper_subclass( - *, cls, size, strides, storage_offset, dtype, layout, device, requires_grad -): - """Convenience method that builds a wrapper subclass. - - NB: this convenience method does not set that `elem` attribute of the subclass, as that is the responsibility - of the device specific implementation. - """ - return torch.Tensor._make_wrapper_subclass( - cls, - size, - strides=strides, - storage_offset=storage_offset, - dtype=dtype, - layout=layout, - device=device, - requires_grad=requires_grad, - ) diff --git a/python/torch_mlir_e2e_test/configs/__init__.py b/python/torch_mlir_e2e_test/configs/__init__.py index 1f57a029b..36fab40bd 100644 --- a/python/torch_mlir_e2e_test/configs/__init__.py +++ b/python/torch_mlir_e2e_test/configs/__init__.py @@ -9,5 +9,4 @@ from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig from .mhlo_backend import MhloBackendTestConfig from .tosa_backend import TosaBackendTestConfig -from .eager_mode import EagerModeTestConfig from .torchdynamo import TorchDynamoTestConfig diff --git a/python/torch_mlir_e2e_test/configs/eager_mode.py b/python/torch_mlir_e2e_test/configs/eager_mode.py deleted file mode 100644 index 157ef0f36..000000000 --- a/python/torch_mlir_e2e_test/configs/eager_mode.py +++ /dev/null @@ -1,65 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. 
-
-import torch
-from torch.utils._pytree import tree_map
-
-from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
-from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
-
-
-def wrap(e):
-    return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e
-
-
-def unwrap(e):
-    return TorchMLIRTensor.unwrap(e) if isinstance(e, TorchMLIRTensor) else e
-
-
-def to_tmt(m: torch.nn.Module):
-    for buf_name, buf in m.named_buffers(recurse=True):
-        if isinstance(buf, TorchMLIRTensor):
-            continue
-        m.register_buffer(buf_name, TorchMLIRTensor(buf))
-    for param_name, param in m.named_parameters(recurse=True):
-        if isinstance(param, TorchMLIRTensor):
-            continue
-        m.register_parameter(
-            param_name,
-            torch.nn.Parameter(
-                TorchMLIRTensor(param), requires_grad=param.requires_grad
-            ),
-        )
-    for attr in dir(m):
-        field = getattr(m, attr)
-        if isinstance(field, torch.Tensor) and not isinstance(field, TorchMLIRTensor):
-            setattr(m, attr, TorchMLIRTensor(field))
-
-
-class EagerModeTestConfig(TestConfig):
-    """Trivial test config that exercises eager mode plumbing"""
-
-    def __init__(self):
-        super().__init__()
-
-    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
-        program.apply(to_tmt)
-        return program
-
-    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
-        result: Trace = []
-        for item in trace:
-            attr = artifact
-            for part in item.symbol.split("."):
-                attr = getattr(attr, part)
-
-            inps = tree_map(wrap, item.inputs)
-            outps = attr(*inps)
-            output = tree_map(unwrap, outps)
-
-            result.append(
-                TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
-            )
-        return result
diff --git a/python/torch_mlir_e2e_test/eager_backends/refbackend.py b/python/torch_mlir_e2e_test/eager_backends/refbackend.py
deleted file mode 100644
index 70b9e5911..000000000
--- a/python/torch_mlir_e2e_test/eager_backends/refbackend.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-from __future__ import annotations
-
-from typing import Dict, Any
-
-import numpy as np
-import torch
-
-from torch_mlir.compiler_utils import (
-    get_module_name_for_debug_dump,
-    run_pipeline_with_repro_report,
-)
-from torch_mlir.eager_mode.torch_mlir_eager_backend import (
-    TorchMLIREagerBackend,
-    TensorMetaData,
-)
-from torch_mlir.ir import Module
-from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
-    RefBackendLinalgOnTensorsBackend,
-)
-
-NUMPY_TO_TORCH_DTYPE_DICT = {
-    np.bool_: torch.bool,
-    np.uint8: torch.uint8,
-    np.int8: torch.int8,
-    np.int16: torch.int16,
-    np.int32: torch.int32,
-    np.int64: torch.int64,
-    np.float16: torch.float16,
-    np.float32: torch.float32,
-    np.float64: torch.float64,
-    np.complex64: torch.complex64,
-    np.complex128: torch.complex128,
-}
-
-_ref_backend = RefBackendLinalgOnTensorsBackend()
-
-
-class EagerModeRefBackend(TorchMLIREagerBackend):
-    """Main entry-point for the reference backend for eager mode.
-
-    RefBackend uses numpy.ndarray representations of tensors and thus all of the wrapping and unwrapping
-    and munging here is done to convert between torch.Tensor and numpy.ndarray.
- """ - - module_to_refbackend_invoker = {} - - def get_torch_metadata( - self, tensor: np.ndarray, kwargs: Dict[str, Any] - ) -> TensorMetaData: - return TensorMetaData( - size=tensor.shape, - dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type], - requires_grad=tensor.dtype in {np.float, np.float32, np.float64} - and kwargs.get("requires_grad", False), - ) - - def compile(self, imported_module: Module): - """Lower the imported TS module to linalg and then further compile for the reference backend and then call.""" - fn_name = get_module_name_for_debug_dump(imported_module) - module_hash = str(imported_module) - if module_hash not in self.module_to_refbackend_invoker: - run_pipeline_with_repro_report( - imported_module, - "builtin.module(torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline)", - "EagerMode", - ) - self.module_to_refbackend_invoker[module_hash] = _ref_backend.load( - _ref_backend.compile(imported_module) - ) - - ref_backend_invoker = self.module_to_refbackend_invoker[module_hash] - op_mlir_backend_callable = getattr(ref_backend_invoker, fn_name) - assert ( - op_mlir_backend_callable is not None - ), f"Couldn't find function in module." - return op_mlir_backend_callable - - def copy_into(self, dst: np.ndarray, src: np.ndarray): - np.copyto(dst, src) - - def transfer_from_device_to_torch(self, e: np.ndarray): - return torch.from_numpy(e).clone() - - def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> np.ndarray: - return tensor.detach().numpy()
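
For reference, the core mechanism this patch removes is the `__torch_dispatch__` wrapper-subclass pattern: every aten call on the wrapper tensor is intercepted, compiled op-by-op through Torch-MLIR, and falls back to plain PyTorch on failure. Below is a minimal, hypothetical sketch of that pattern, not code from this repository; the class name `InterceptingTensor` and the print-based logging are illustrative, and the plain-PyTorch fallback stands in for the Torch-MLIR compile step performed by the deleted `build_mlir_module`/`backend.compile` path. It uses only the PyTorch APIs the removed code itself relied on (`torch.Tensor._make_wrapper_subclass`, a private API, and `torch.utils._pytree.tree_map`).

import torch
from torch.utils._pytree import tree_map


class InterceptingTensor(torch.Tensor):
    """Wrapper subclass: every aten op on it funnels through __torch_dispatch__."""

    elem: torch.Tensor
    __slots__ = ["elem"]

    def __new__(cls, elem):
        # Build a tensor "shell" carrying torch metadata; the payload lives in `elem`.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # The removed TorchMLIRTensor built and compiled an MLIR module for
        # `func` at this point; this sketch just records the op and runs the
        # unwrapped arguments through ordinary PyTorch eager execution.
        print(f"intercepted: {func}")
        unwrap = lambda e: e.elem if isinstance(e, cls) else e
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # Re-wrap tensor outputs so subsequent ops are intercepted too.
        rewrap = lambda e: cls(e) if isinstance(e, torch.Tensor) else e
        return tree_map(rewrap, out)


a = InterceptingTensor(torch.randn(5))
b = InterceptingTensor(torch.randn(5))
c = a + b  # prints something like "intercepted: aten.add.Tensor"
print(c.elem)

Each intercepted op sees fully unwrapped arguments, so re-invoking `func` inside the handler dispatches normally and cannot recurse; the deleted implementation additionally guarded tensor-method calls on the wrapper with `no_dispatch()` and handed the mutable-operand write-back to `check_get_aliased_arg`, as seen above.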