mirror of https://github.com/llvm/torch-mlir
Remove eager_mode

This was an experimental attempt at rolling our own op-by-op executor with `__torch_dispatch__`, but it proved difficult to make robust. Op-by-op execution is now very easy to implement robustly with the PyTorch 2.0 stack, so we don't need eager_mode. Downstream users were using eager_mode to implement lockstep numerical accuracy debuggers. We implemented the same functionality with TorchDynamo in https://github.com/llvm/torch-mlir/pull/1681, so there is not much reason to continue maintaining it.

parent 109c91ae9b
commit 7731211d02
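
For context, here is a minimal runnable sketch of the lockstep op-by-op pattern that the TorchDynamo stack enables. The `lockstep_backend` name and the logging logic are illustrative assumptions, not the implementation landed in the PR above:

```python
import torch

def lockstep_backend(gm: torch.fx.GraphModule, example_inputs):
    class LockstepInterpreter(torch.fx.Interpreter):
        def run_node(self, node):
            result = super().run_node(node)
            # Inspect/compare each op's result here, e.g. against a
            # reference backend, to localize numerical divergence.
            print(node.op, node.target, getattr(result, "shape", None))
            return result
    return lambda *args: LockstepInterpreter(gm).run(*args)

compiled = torch.compile(torch.nn.Linear(4, 2), backend=lockstep_backend)
compiled(torch.randn(3, 4))
```

A real lockstep debugger would run each node on both a reference and a compiled backend and compare the results, but the hook point is the same.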
@@ -75,14 +75,6 @@ torch-mlir prediction
 
 View examples [here](docs/ltc_examples.md).
 
-### Eager Mode
-
-Eager mode with TorchMLIR is a very experimental eager mode backend for PyTorch through the torch-mlir framework.
-Effectively, this mode works by compiling operator by operator as the NN is eagerly executed by PyTorch.
-This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported operator).
-A simple example can be found at [eager_mode.py](examples/eager_mode.py).
-A ResNet18 example can be found at [eager_mode_resnet18.py](examples/eager_mode_resnet18.py).
-
 ## Repository Layout
 
 The project follows the conventions of typical MLIR-based projects:
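
The removed README text describes the core mechanism: compile each operator as it executes, and fall back to stock PyTorch on failure. A minimal sketch of that fallback shape, where `compile_op` is a hypothetical stand-in for the torch-mlir per-op compile step (shown always failing so the fallback path runs):

```python
import torch

def compile_op(op, args, kwargs):
    # Stand-in for per-op torch-mlir compilation; raising here models an
    # unsupported operator, which triggers the fallback described above.
    raise NotImplementedError(f"unsupported op: {op}")

def dispatch_op(op, *args, **kwargs):
    try:
        return compile_op(op, args, kwargs)(*args, **kwargs)
    except Exception:
        return op(*args, **kwargs)  # fall back to conventional PyTorch

print(dispatch_op(torch.add, torch.ones(2), torch.ones(2)))
```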
@@ -262,9 +262,6 @@ function test_in_tree() {
   echo ":::: Run Linalg e2e integration tests"
   python -m e2e_testing.main --config=linalg -v
 
-  echo ":::: Run eager_mode e2e integration tests"
-  python -m e2e_testing.main --config=eager_mode -v
-
   echo ":::: Run MHLO e2e integration tests"
   python -m e2e_testing.main --config=mhlo -v
 
@@ -20,7 +20,6 @@ from torch_mlir_e2e_test.configs import (
     NativeTorchTestConfig,
     TorchScriptTestConfig,
     TosaBackendTestConfig,
-    EagerModeTestConfig,
     TorchDynamoTestConfig,
 )
 
@@ -28,14 +27,14 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackend
 from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
 from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
 
-from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
+from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
 
 # Import tests to register them in the global registry.
 from torch_mlir_e2e_test.test_suite import register_all_tests
 register_all_tests()
 
 def _get_argparse():
-    config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
+    config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'lazy_tensor_core', 'torchdynamo']
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
     parser.add_argument('-c', '--config',
                         choices=config_choices,
@@ -47,7 +46,6 @@ Meaning of options:
 "tosa": run through torch-mlir's default TOSA backend.
 "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
-"eager_mode": run through torch-mlir's eager mode frontend, using Linalg-on-Tensors for execution.
 "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
 "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
 ''')
@@ -91,9 +89,6 @@ def main():
     elif args.config == 'torchscript':
        config = TorchScriptTestConfig()
        xfail_set = {}
-    elif args.config == 'eager_mode':
-       config = EagerModeTestConfig()
-       xfail_set = EAGER_MODE_XFAIL_SET
     elif args.config == 'lazy_tensor_core':
        config = LazyTensorCoreTestConfig()
        xfail_set = LTC_XFAIL_SET
@@ -14,13 +14,6 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
 
 LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
 
-EAGER_MODE_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
-    # RefBackend fails for some reason.
-    # These tests pass in the regular RefBackend flow, so it's unclear
-    # why they fail here.
-    "Matmul_vecmat",
-}
-
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
 
@@ -1,34 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-import torch
-
-from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
-
-torch_a = torch.randn(5, requires_grad=True)
-torch_b = torch.randn(5, requires_grad=True)
-
-torch_c = torch_a + torch_b
-torch_d = torch_a * torch_b
-torch_e = torch_c / torch_d
-torch_loss = torch_e.sum()
-print("PyTorch loss: ", torch_loss)
-
-torch_loss.backward()
-print("PyTorch grad a: ", torch_a.grad)
-print("PyTorch grad b: ", torch_b.grad)
-
-a = TorchMLIRTensor(torch_a)
-b = TorchMLIRTensor(torch_b)
-
-c = a + b
-d = a * b
-e = c / d
-loss = e.sum()
-print("Torch-MLIR loss: ", loss)
-
-loss.backward()
-print("Torch-MLIR grad a: ", a.grad)
-print("Torch-MLIR grad b: ", b.grad)
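
A rough present-day equivalent of the deleted example, assuming PyTorch >= 2.0, uses torch.compile instead of TorchMLIRTensor (default backend shown; a torch-mlir backend could be substituted where available):

```python
import torch

def f(a, b):
    return ((a + b) / (a * b)).sum()

a = torch.randn(5, requires_grad=True)
b = torch.randn(5, requires_grad=True)

eager_loss = f(a, b)
eager_loss.backward()
print("eager loss:", eager_loss, "grad a:", a.grad, "grad b:", b.grad)

# Reset gradients, then run the compiled version in lockstep for comparison.
a.grad = b.grad = None
compiled_loss = torch.compile(f)(a, b)
compiled_loss.backward()
print("compiled loss:", compiled_loss, "grad a:", a.grad, "grad b:", b.grad)
```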
@@ -1,89 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-import sys
-
-import requests
-import torch
-import torchvision.models as models
-from PIL import Image
-from torchvision import transforms
-
-from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
-
-
-def load_and_preprocess_image(url: str):
-    headers = {
-        'User-Agent':
-        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
-    }
-    img = Image.open(requests.get(url, headers=headers,
-                                  stream=True).raw).convert("RGB")
-    # preprocessing pipeline
-    preprocess = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                             std=[0.229, 0.224, 0.225]),
-    ])
-    img_preprocessed = preprocess(img)
-    return torch.unsqueeze(img_preprocessed, 0)
-
-
-def load_labels():
-    classes_text = requests.get(
-        "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
-        stream=True,
-    ).text
-    labels = [line.strip() for line in classes_text.splitlines()]
-    return labels
-
-
-def top3_possibilities(res):
-    _, indexes = torch.sort(res, descending=True)
-    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
-    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
-    return top3
-
-
-def predictions(torch_func, img, labels):
-    golden_prediction = top3_possibilities(torch_func(img))
-    print("PyTorch prediction")
-    print(golden_prediction)
-    prediction = top3_possibilities(torch_func(TorchMLIRTensor(img)))
-    print("torch-mlir prediction")
-    print(prediction)
-
-
-class ResNet18Module(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.resnet = models.resnet18(pretrained=True)
-        self.train(False)
-
-    def forward(self, img):
-        return self.resnet.forward(img)
-
-
-class TestModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.s = ResNet18Module()
-
-    def forward(self, x):
-        return self.s.forward(x)
-
-
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-
-print("load image from " + image_url, file=sys.stderr)
-img = load_and_preprocess_image(image_url)
-labels = load_labels()
-
-test_module = TestModule()
-predictions(test_module.forward, img, labels)
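
Similarly, the deleted ResNet18 comparison can be approximated with torch.compile, assuming torchvision >= 0.13 for the weights API (a random tensor stands in for the preprocessed image; a torch-mlir backend could be plugged in where available):

```python
import torch
import torchvision.models as models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
img = torch.randn(1, 3, 224, 224)  # stand-in for the preprocessed image
with torch.no_grad():
    golden = model(img)                  # reference PyTorch prediction
    compiled = torch.compile(model)(img)  # compiled prediction
print("max abs diff:", (golden - compiled).abs().max().item())
```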
@@ -102,12 +102,6 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
   add_subdirectory(torch_mlir_e2e_test)
 endif()
 
-################################################################################
-# Eager mode
-################################################################################
-
-add_subdirectory(torch_mlir/eager_mode)
-
 ################################################################################
 # Custom op example
 # Required for running the update_torch_ods.sh and update_shape_lib.sh scripts.
@@ -1,320 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-# RUN: %PYTHON %s | FileCheck %s
-
-
-import torch
-
-from framework import run_test
-from torch_mlir.eager_mode.ir_building import build_ts_script_function
-
-
-# CHECK: graph(%[[A1:.*]] : Tensor,
-# CHECK: %[[A2:.*]] : Tensor,
-# CHECK: %[[A3:.*]] : Tensor):
-# CHECK: %[[A4:.*]] : int = prim::Constant[value=1]()
-# CHECK: %[[A5:.*]] : int = prim::Constant[value=1]()
-# CHECK: %[[A0:.*]] : Tensor = aten::addmm(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]])
-# CHECK: return (%[[A0]])
-# -----
-# CHECK: PASS - simple
-@run_test
-def simple():
-    target = torch.ops.aten.addmm.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        mat1=torch.randn(1, 3, 32, 32),
-        mat2=torch.randn(1, 3, 32, 32),
-        beta=1,
-        alpha=1,
-    )
-
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: graph(%[[B1:.*]] : Tensor,
-# CHECK: %[[B2:.*]] : Tensor,
-# CHECK: %[[B3:.*]] : Tensor):
-# CHECK: %[[B4:.*]] : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %[[B5:.*]] : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %[[B6:.*]] : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %[[B7:.*]] : bool = prim::Constant[value=0]()
-# CHECK: %[[B8:.*]] : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %[[B9:.*]] : int = prim::Constant[value=1]()
-# CHECK: %[[B0:.*]] : Tensor = aten::convolution(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[B9]])
-# CHECK: return (%[[B0]])
-# -----
-# CHECK: PASS - handle_optional_tensor_input
-@run_test
-def handle_optional_tensor_input():
-    target = torch.ops.aten.convolution.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        weight=torch.randn(3, 3, 3, 3),
-        bias=torch.randn(3),
-        stride=[1, 1],
-        padding=[0, 0],
-        dilation=[1, 1],
-        transposed=False,
-        output_padding=[0, 0],
-        groups=1,
-    )
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: FAIL - fail_not_enough_args
-# CHECK: Errors: 'groups'
-@run_test
-def fail_not_enough_args():
-    target = torch.ops.aten.convolution.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        weight=torch.randn(3, 3, 3, 3),
-        bias=torch.randn(3),
-        stride=[1, 1],
-        padding=[0, 0],
-        dilation=[1, 1],
-        transposed=False,
-        output_padding=[0, 0],
-        # Missing groups=1,
-    )
-    build_ts_script_function(target._schema, kwargs)
-
-
-# CHECK: graph(%input : Tensor,
-# CHECK: %weight : Tensor,
-# CHECK: %bias : Tensor):
-# CHECK: %4 : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %5 : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %6 : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %7 : bool = prim::Constant[value=0]()
-# CHECK: %8 : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %9 : int = prim::Constant[value=1]()
-# CHECK: %0 : Tensor = aten::convolution(%input, %weight, %bias, %4, %5, %6, %7, %8, %9)
-# CHECK: return (%0)
-# -----
-# CHECK: PASS - simple_kwargs
-@run_test
-def simple_kwargs():
-    target = torch.ops.aten.convolution.default
-    script_fun1 = build_ts_script_function(
-        target._schema,
-        dict(
-            input=torch.randn(1, 3, 32, 32),
-            weight=torch.randn(3, 3, 3, 3),
-            bias=torch.randn(3),
-            stride=[1, 1],
-            padding=[0, 0],
-            dilation=[1, 1],
-            transposed=False,
-            output_padding=[0, 0],
-            groups=1,
-        ),
-    )
-
-    print(script_fun1.graph)
-
-
-# CHECK: graph(%[[C2:.*]] : Tensor):
-# CHECK: %[[C3:.*]] : int[] = prim::Constant[value=[3, 3]]()
-# CHECK: %[[C4:.*]] : NoneType = prim::Constant()
-# CHECK: %[[C5:.*]] : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %[[C6:.*]] : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %[[C7:.*]] : bool = prim::Constant[value=0]()
-# CHECK: %[[C0:.*]] : Tensor, %[[C1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]])
-# CHECK: return (%[[C0]], %[[C1]])
-# -----
-# CHECK: PASS - handle_empty_lists
-@run_test
-def handle_empty_lists():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    # print(target._schema)
-    input = torch.randn((1, 3, 32, 32))
-    kwargs = dict(
-        input=input,
-        kernel_size=[3, 3],
-        stride=[],
-        padding=[0, 0],
-        dilation=[1, 1],
-        ceil_mode=False,
-    )
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: graph(%[[D2:.*]] : Tensor):
-# CHECK: %[[D3:.*]] : int[] = prim::Constant[value=[3, 3]]()
-# CHECK: %[[D4:.*]] : NoneType = prim::Constant()
-# CHECK: %[[D5:.*]] : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %[[D6:.*]] : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %[[D7:.*]] : bool = prim::Constant[value=0]()
-# CHECK: %[[D0:.*]] : Tensor, %[[D1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[D2]], %[[D3]], %[[D4]], %[[D5]], %[[D6]], %[[D7]])
-# CHECK: return (%[[D0]], %[[D1]])
-# -----
-# CHECK: PASS - handle_nones
-@run_test
-def handle_nones():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    # print(target._schema)
-    kwargs = dict(
-        input=torch.randn((1, 3, 32, 32)),
-        kernel_size=[3, 3],
-        stride=None,
-        padding=[0, 0],
-        dilation=[1, 1],
-        ceil_mode=False,
-    )
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: graph(%[[E1:.*]] : Tensor,
-# CHECK: %[[E2:.*]] : Tensor,
-# CHECK: %[[E3:.*]] : Tensor):
-# CHECK: %[[E4:.*]] : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %[[E5:.*]] : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %[[E6:.*]] : int[] = prim::Constant[value=[1, 1]]()
-# CHECK: %[[E7:.*]] : bool = prim::Constant[value=0]()
-# CHECK: %[[E8:.*]] : int[] = prim::Constant[value=[0, 0]]()
-# CHECK: %[[E9:.*]] : int = prim::Constant[value=1]()
-# CHECK: %[[E0:.*]] : Tensor = aten::convolution(%[[E1]], %[[E2]], %[[E3]], %[[E4]], %[[E5]], %[[E6]], %[[E7]], %[[E8]], %[[E9]])
-# CHECK: return (%[[E0]])
-# -----
-# CHECK: PASS - handle_optional_tensors
-@run_test
-def handle_optional_tensors():
-    target = torch.ops.aten.convolution.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        weight=torch.randn(3, 3, 3, 3),
-        bias=torch.randn(3),
-        stride=[1, 1],
-        padding=[0, 0],
-        dilation=[1, 1],
-        transposed=False,
-        output_padding=[0, 0],
-        groups=1,
-    )
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: graph(%[[F1:.*]] : Tensor):
-# CHECK: %[[F2:.*]] : NoneType = prim::Constant()
-# CHECK: %[[F3:.*]] : NoneType = prim::Constant()
-# CHECK: %[[F4:.*]] : NoneType = prim::Constant()
-# CHECK: %[[F5:.*]] : NoneType = prim::Constant()
-# CHECK: %[[F6:.*]] : NoneType = prim::Constant()
-# CHECK: %[[F0:.*]] : Tensor = aten::ones_like(%[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], %[[F6]])
-# CHECK: return (%[[F0]])
-# -----
-# CHECK: PASS - handle_ones_like
-@run_test
-def handle_ones_like():
-    target = torch.ops.aten.ones_like.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        dtype=None,
-        layout=None,
-        device=None,
-        pin_memory=None,
-        memory_format=None,
-    )
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: graph(%[[G3:.*]] : Tensor,
-# CHECK: %[[G4:.*]] : Tensor,
-# CHECK: %[[G5:.*]] : Tensor):
-# CHECK: %[[G6:.*]] : NoneType = prim::Constant()
-# CHECK: %[[G7:.*]] : NoneType = prim::Constant()
-# CHECK: %[[G8:.*]] : bool = prim::Constant[value=0]()
-# CHECK: %[[G9:.*]] : float = prim::Constant[value=1.]()
-# CHECK: %[[G10:.*]] : float = prim::Constant[value=1.]()
-# CHECK: %[[G0:.*]] : Tensor, %[[G1:.*]] : Tensor, %[[G2:.*]] : Tensor = aten::native_batch_norm(%[[G3]], %[[G4]], %[[G5]], %[[G6]], %[[G7]], %[[G8]], %[[G9]], %[[G10]])
-# CHECK: return (%[[G0]], %[[G1]], %[[G2]])
-# -----
-# CHECK: PASS - handle_multiple_outputs
-@run_test
-def handle_multiple_outputs():
-    target = torch.ops.aten.native_batch_norm.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        weight=torch.randn(1, 3, 32, 32),
-        bias=torch.randn(1, 3, 32, 32),
-        running_mean=None,
-        running_var=None,
-        training=False,
-        momentum=1.0,
-        eps=1.0
-    )
-
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
-
-
-# CHECK: f
-# CHECK: PASS - check_legal_name
-@run_test
-def check_legal_name():
-    target = torch.ops.aten.native_batch_norm.default
-    kwargs = dict(
-        input=torch.randn(1, 3, 32, 32),
-        weight=torch.randn(1, 3, 32, 32),
-        bias=torch.randn(1, 3, 32, 32),
-        running_mean=None,
-        running_var=None,
-        training=False,
-        momentum=1.0,
-        eps=1.0
-    )
-
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.name)
-
-
-# CHECK: graph(%[[H3:.*]] : Tensor,
-# CHECK: %[[H4:.*]] : Tensor,
-# CHECK: %[[H5:.*]] : Tensor,
-# CHECK: %[[H6:.*]] : Tensor,
-# CHECK: %[[H7:.*]] : Tensor,
-# CHECK: %out : Tensor,
-# CHECK: %save_mean : Tensor,
-# CHECK: %save_invstd : Tensor):
-# CHECK: %[[H8:.*]] : bool = prim::Constant[value=0]()
-# CHECK: %[[H9:.*]] : float = prim::Constant[value=0.10000000000000001]()
-# CHECK: %[[H10:.*]] : float = prim::Constant[value=0.0001]()
-# CHECK: %[[H0:.*]] : Tensor, %[[H1:.*]] : Tensor, %[[H2:.*]] : Tensor = aten::native_batch_norm(%[[H3]], %[[H4]], %[[H5]], %[[H6]], %[[H7]], %[[H8]], %[[H9]], %[[H10]], %out, %save_mean, %save_invstd)
-# CHECK: return (%[[H0]], %[[H1]], %[[H2]])
-# -----
-# CHECK: PASS - correctly_order_kwargs
-@run_test
-def correctly_order_kwargs():
-    target = torch.ops.aten.native_batch_norm.out
-
-    input = torch.randn(2, 5, 2, 3)
-    running_mean = torch.randn(5)
-    running_var = torch.randn(5)
-
-    kwargs = dict(
-        input=torch.randn(2, 5, 2, 3),
-        weight=torch.randn(5),
-        bias=torch.randn(5),
-        running_mean=running_mean,
-        running_var=running_var,
-        training=False,
-        momentum=0.1,
-        eps=0.0001,
-        out=torch.empty_like(input),
-        save_mean=torch.empty_like(running_mean),
-        save_invstd=torch.empty_like(running_var),
-    )
-
-    script_fun = build_ts_script_function(target._schema, kwargs)
-    print(script_fun.graph)
@@ -1,23 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-# RUN: true
-
-
-def run_test(*args, XPASS=False, XFAIL=False):
-    def _run_test(test):
-        test_name = test.__name__
-        try:
-            test()
-            print(("X" if XPASS else "") + f"PASS - {test_name}")
-        except Exception as e:
-            print(("X" if XFAIL else "") + f"FAIL - {test_name}")
-            print("Errors: ", e)
-        print()
-
-    if len(args):
-        _run_test(args[0])
-    else:
-        return _run_test
@@ -1,69 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-# RUN: %PYTHON %s | FileCheck %s
-
-
-import torch
-
-from framework import run_test
-from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
-
-
-# CHECK: PASS - should_normalize
-@run_test
-def should_normalize():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    input = torch.randn((1, 3, 32, 32))
-    kwargs = {"kernel_size": [3, 3]}
-    golden = {
-        "kernel_size": [3, 3],
-        # This is due to the schema for max_pool2d_with_indices defining
-        # the stride arg as int[2] stride=[].
-        "stride": [],
-        "padding": [0, 0],
-        "dilation": [1, 1],
-        "ceil_mode": False,
-    }
-
-    new_kwargs = normalize_args_kwargs(target, (input,), kwargs)
-    assert torch.allclose(new_kwargs["input"], input)
-    for k, v in new_kwargs.items():
-        if k == "input": continue
-        assert v == golden[k]
-
-
-# CHECK: FAIL - shouldnt_normalize1
-# CHECK: Errors: missing a required argument: 'kernel_size'
-@run_test
-def shouldnt_normalize1():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {"stride": []}
-    normalize_args_kwargs(target, args, kwargs)
-
-
-# The next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
-# I.e., they should fail but in fact they pass because of the upstream bug.
-# The reason for the bug is a fast-path branch in operator_schemas.normalize_function
-# that doesn't do rigorous type checking, and hence lets type mismatches slip through.
-# TODO(max): change these to FAIL when the upstream bug is fixed.
-
-# CHECK: XPASS - shouldnt_normalize2
-@run_test(XPASS=True)
-def shouldnt_normalize2():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {"kernel_size": []}
-    normalize_args_kwargs(target, args, kwargs)
-
-
-# CHECK: XPASS - shouldnt_normalize3
-@run_test(XPASS=True)
-def shouldnt_normalize3():
-    target = torch.ops.aten.max_pool2d_with_indices.default
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {"kernel_size": [3, 3], "padding": None}
-    normalize_args_kwargs(target, args, kwargs)
@@ -1,11 +0,0 @@
-#-------------------------------------------------------------------------------
-# Subdirectories
-#-------------------------------------------------------------------------------
-
-## Declare the sources of the Python module.
-
-declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
-  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-  ADD_TO_PARENT TorchMLIRPythonSources
-  SOURCES_GLOB eager_mode/*.py lazytensor/*.py
-)
@@ -1,3 +0,0 @@
-import os
-
-EAGER_MODE_DEBUG = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')
@@ -1,359 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-"""
-Translator from torch.jit.ScriptFunction to MLIR.
-
-The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
-`torch` dialect.
-
-The expected use of this module is to create an instance of one of the classes below, and then calling the
-`to_mlir` method to generate the MLIR representation of the type.
-
-Information about what types are supported by each class can be found in docstrings of each of the classes.
-
-In addition this module defines a function that takes a torch.jit.ScriptFunction and converts it into an MLIR module.
-
-The expected use for this module is to use the function
-`build_module(jit_function: torch.jit.ScriptFunction, annotation: Annotation) -> ir.Module`
-to convert the TorchScript function into MLIR using the `torch` dialect.
-"""
-
-import abc
-import re
-from typing import Any, Optional, Iterable, Dict
-from typing import Union
-
-import numpy as np
-import torch
-import torch._C
-import torch.jit
-from torch._ops import OpOverload
-
-from torch_mlir import ir
-from torch_mlir.dialects.func import FuncOp
-from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
-
-
-class TorchMlirType(abc.ABC):
-    """
-    A `TorchMlirType` is an object that produces MLIR
-    types in the `torch` dialect. The only requirement
-    for a class to be a subclass of `TorchMlirType` is
-    to define a `to_mlir(self, ir.Context) -> ir.Type`.
-    Each class is allowed to have different types of
-    __init__ methods depending on the information they
-    require to produce the given MLIR representation.
-    """
-
-    @abc.abstractmethod
-    def to_mlir(self, context: ir.Context) -> ir.Type:
-        pass
-
-
-class TorchTensorTypeError(Exception):
-    def __init__(self, value: str):
-        super().__init__()
-        self.value = value
-
-    def __str__(self) -> str:
-        return self.value
-
-
-class TorchTensorType(TorchMlirType):
-    """
-    This class is used to generate types of the form
-    !torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
-    where SHAPE is a list representing the shape of the tensor,
-    and DTYPE is an MLIR data type.
-    """
-
-    def __init__(
-        self,
-        *,
-        shape: Optional[Iterable[Optional[int]]] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        self.shape = shape
-        self.dtype = dtype
-
-        if dtype is None and shape is not None:
-            err = "If shape is specified, dtype must also be specified"
-            raise TorchTensorTypeError(err)
-
-    def __str__(self):
-        return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
-
-    def to_mlir(self, context: ir.Context) -> ir.Type:
-        if self.dtype is None:
-            return ir.Type.parse("!torch.tensor", context=context)
-
-        shape_asm = self._shape_to_mlir_asm()
-        dtype_asm = self._dtype_to_mlir_asm()
-        return ir.Type.parse(
-            f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
-        )
-
-    def _shape_to_mlir_asm(self) -> str:
-        if self.shape is None:
-            return "*"
-
-        str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
-        return f'[{",".join(str_sizes)}]'
-
-    def _dtype_to_mlir_asm(self) -> str:
-        if self.dtype in [torch.float64]:
-            return "f64"
-        if self.dtype in [torch.float, torch.float32]:
-            return "f32"
-        if self.dtype in [torch.int, torch.int32]:
-            return "si32"
-        if self.dtype in [torch.int64]:
-            return "si64"
-        if self.dtype in [torch.bool]:
-            return "i1"
-
-        raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
-
-
-class TorchNnModuleType(TorchMlirType):
-    """This class is used to generate types for `!torch.nn.Module`s."""
-
-    def __init__(self, module_name: str):
-        self.module_name = module_name
-
-    def __str__(self):
-        return "torch.nn.Module"
-
-    def to_mlir(self, context: ir.Context) -> ir.Type:
-        return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)
-
-
-class PythonType(TorchMlirType):
-    """
-    This class is used to convert regular Python types
-    into their corresponding `torch` dialect representation.
-    The list of supported types can be found in the dictionary
-    `_type_to_asm_dict`.
-    """
-
-    _type_to_asm_dict = {
-        bool: "!torch.bool",
-        int: "!torch.int",
-        type(None): "!torch.none",
-    }
-
-    def __init__(self, type_: Any):
-        self.type_ = type_
-
-    def __str__(self):
-        return str(self.type_)
-
-    def to_mlir(self, context: ir.Context) -> ir.Type:
-        asm = self._type_to_asm_dict.get(self.type_)
-        if asm is None:
-            raise NotImplementedError(f"Unsupported type: {self.type_}")
-        return ir.Type.parse(asm, context=context)
-
-
-# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
-class Annotation:
-    def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
-        self.types = list(
-            map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
-        )
-
-    def __str__(self):
-        result = f"Annotation instance with {len(self.types)} types\n"
-        for e, type_ in enumerate(self.types):
-            result += f"    Type of argument {e + 1}: {str(type_)}\n"
-        return result
-
-    def __iter__(self):
-        return iter(self.types)
-
-
-class AnnotationConverter:
-    @staticmethod
-    def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
-        dict_attrs = []
-        for type_ in annotation:
-            if not isinstance(type_, TorchTensorType):
-                dict_attrs.append(ir.DictAttr.get({}, context=context))
-                continue
-
-            ir_type = type_.to_mlir(context)
-            with context:
-                type_attr = ir.TypeAttr.get(ir_type)
-                dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
-                dict_attrs.append(dict_attr)
-
-        return ir.ArrayAttr.get(dict_attrs, context=context)
-
-
-def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
-    with module.context:
-        name_attr = ir.StringAttr.get(name)
-        for op in module.body.operations:
-            if isinstance(op, FuncOp) and op.name == name_attr:
-                # Add name of torch op as debug_module_name so that
-                # run_pipeline_with_repro_report can use it.
-                module.operation.attributes["torch.debug_module_name"] = name_attr
-                return op
-
-    return None
-
-
-def is_tensor_type(typ: torch._C.Type):
-    return typ.isSubtypeOf(torch.TensorType.get()) or (
-        isinstance(typ, torch.OptionalType)
-        and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
-    )
-
-
-def is_list_of_tensors_type(typ: torch._C.Type):
-    return isinstance(typ, torch.ListType) and is_tensor_type(typ.getElementType())
-
-
-name_mangle_regex = re.compile("[^a-zA-Z0-9]")
-
-
-def build_ts_script_function(
-    schema: torch._C.FunctionSchema, kwargs: Dict[str, Any]
-) -> torch.jit.ScriptFunction:
-    """Build a torch.jit.ScriptFunction that corresponds to the schema.
-
-    Constants are inlined for the purposes of invalidating the compile cache when they change.
-
-    Parameters
-    ----------
-    schema: torch._C.FunctionSchema
-        PyTorch's representation for ops; contains type information needed for inlining constants into the TS graph.
-    kwargs: Dict
-        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
-
-    Returns
-    -------
-    torch.jit.ScriptFunction
-        Fully specialized (all constants) TS graph whose only arguments are tensors.
-    """
-
-    # Creates empty TS graph.
-    graph = torch._C.Graph()
-    # Creates and inserts node with identifier `schema.name`; NB the node has no inputs or outputs at this point.
-    node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
-    # Associate graph inputs/outputs with node inputs/outputs.
-    graph_inputs = []
-    for arg in schema.arguments:
-        arg_name = arg.name if arg.name != "self" else "input"
-
-        # If arg is a flattened list of tensors, such as in the case of torch.cat,
-        # then add each element of the list to the graph corresponding to arg
-        # and insert a ListConstruct to function as input to the op.
-        if is_list_of_tensors_type(arg.type):
-            inps = []
-            for kwarg in [
-                kwarg for kwarg in kwargs if f"{arg_name}_flattened" in kwarg
-            ]:
-                inp = graph.addInput()
-                el_typ = arg.type.getElementType()
-                if isinstance(el_typ, torch.OptionalType):
-                    el_typ = el_typ.getElementType()
-                inp.setType(el_typ)
-                inp.setDebugName(kwarg)
-                inps.append(inp)
-                graph_inputs.append(kwarg)
-            list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
-            list_cons.moveBefore(node)
-            inp = list_cons.output()
-            inp.setType(torch.ListType.ofTensors())
-        # If arg is a tensor, then add input to the graph corresponding to arg.
-        elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
-            inp = graph.addInput()
-            if isinstance(arg.type, torch.OptionalType):
-                el_typ = arg.type.getElementType()
-            else:
-                el_typ = arg.type
-            inp.setType(el_typ)
-            inp.setDebugName(arg_name)
-            graph_inputs.append(arg_name)
-        # If arg is a constant, inline (at the top of the graph).
-        else:
-            val = kwargs[arg_name]
-            if val == []:
-                # Some ops have empty list default values for args
-                # (such as aten::max_pool2d_with_indices with int[2] stride=[]),
-                # but graph.insertConstant doesn't recognize [] as an empty list IValue.
-                # This might be an upstream bug but there doesn't seem to be a way to
-                # build a prim::ListConstruct list that's empty.
-                val = None
-            inp = graph.insertConstant(val)
-            inp.node().moveBefore(node)
-
-        node.addInput(inp)
-
-    # Reorder graph inputs to match kwargs.
-    permutes = [
-        {inp: i for i, inp in enumerate(graph_inputs)}[kwarg]
-        for kwarg in [kwarg for kwarg in kwargs if kwarg in graph_inputs]
-    ]
-    graph.permuteInputs(permutes)
-
-    if node.hasMultipleOutputs():
-        for outp in node.outputs():
-            graph.registerOutput(outp)
-    else:
-        graph.registerOutput(node.output())
-
-    fn = torch._C._create_function_from_graph(
-        f"{name_mangle_regex.sub('', str(graph))}", graph
-    )
-    return fn
-
-
-def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
-    """Translate input function into an MLIR module in the `torch` dialect.
-
-    Parameters
-    ----------
-    op: OpOverload
-        Callable from the torch.ops.aten module/namespace that has a _schema field.
-    kwargs: Dict
-        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
-
-    Returns
-    -------
-    ir.Module
-        Translation of the input module into an MLIR module.
-    """
-
-    # The assert here is to catch tensor shapes that have size 0 dimensions, such as those produced in
-    # the course of evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic.
-    # Such 0 size dimensions fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887.
-    annotations = []
-    for arg_name, arg in kwargs.items():
-        if isinstance(arg, torch.Tensor):
-            assert np.prod(arg.shape) != 0, f"{arg_name} has invalid shape {arg.shape}"
-            annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
-    annotations = tuple(annotations)
-
-    script_fun = build_ts_script_function(op._schema, kwargs)
-    assert len(annotations) == len(
-        list(script_fun.graph.inputs())
-    ), "Number of annotations and number of graph inputs differs."
-
-    mb = ModuleBuilder()
-    mb.import_function(script_fun)
-
-    func_op = get_func_op_with_name(mb.module, script_fun.name)
-    assert (
-        func_op is not None
-    ), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder."
-
-    func_annotation = Annotation(annotations)
-    arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
-    func_op.attributes["arg_attrs"] = arg_attrs
-
-    return mb.module
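
For illustration, the shape/dtype to `!torch.vtensor` assembly that TorchTensorType performed can be sketched standalone (no MLIR context needed; the dtype table is abbreviated to mirror `_dtype_to_mlir_asm` above):

```python
import torch

_DTYPE_ASM = {torch.float32: "f32", torch.float64: "f64",
              torch.int32: "si32", torch.int64: "si64", torch.bool: "i1"}

def vtensor_asm(shape, dtype):
    # None in a shape position becomes the dynamic-dimension marker "?".
    sizes = ",".join("?" if s is None else str(s) for s in shape)
    return f"!torch.vtensor<[{sizes}],{_DTYPE_ASM[dtype]}>"

print(vtensor_asm((1, None, 32), torch.float32))  # !torch.vtensor<[1,?,32],f32>
```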
@@ -1,111 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-from __future__ import annotations
-
-from typing import Any, Callable, Tuple
-from typing import Dict
-
-import torch
-from torch.fx import immutable_collections
-from torch.fx.operator_schemas import (
-    _torchscript_schema_to_signature,
-    _args_kwargs_to_normalized_args_kwargs,
-)
-from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops
-
-from torch_mlir.dialects import torch as torch_dialect
-
-OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
-SUPPORTED_OPS = frozenset(
-    [
-        member.OPERATION_NAME
-        for member in vars(torch_dialect).values()
-        if hasattr(member, "OPERATION_NAME")
-    ]
-)
-
-
-class UnsupportedByTorchMlirEagerMode(Exception):
-    def __init__(self, value: str):
-        super().__init__()
-        self.value = value
-
-    def __str__(self) -> str:
-        return self.value
-
-
-def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
-    """Fill in default values for optional args, which are dependent on the schema."""
-    sig = _torchscript_schema_to_signature(target._schema)
-    _, new_kwargs = _args_kwargs_to_normalized_args_kwargs(
-        sig, args, kwargs, normalize_to_only_use_kwargs=True
-    )
-    if "self" in new_kwargs:
-        new_kwargs["input"] = new_kwargs.pop("self")
-
-    # Flatten lists of args for ops that take lists, such as torch.cat.
-    to_remove = set()
-    to_add = {}
-    for k, v in new_kwargs.items():
-        if isinstance(v, (tuple, list)) and len(v) and isinstance(v[0], torch.Tensor):
-            to_remove.add(k)
-            for i, vv in enumerate(v):
-                to_add[f"{k}_flattened_{i}"] = vv
-
-    for rem in to_remove:
-        del new_kwargs[rem]
-    new_kwargs.update(**to_add)
-
-    # Sort here in order to have consistency across the TS graph and
-    # MLIR module.
-    sorted_kwargs = dict(sorted(new_kwargs.items()))
-    return immutable_collections.immutable_dict(sorted_kwargs)
-
-
-def get_registered_op(op):
-    registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
-    return registered_op
-
-
-def check_get_aliased_arg(func: Callable):
-    """Write back to mutable args that aren't properly handled otherwise.
-
-    Because of how we pass values to the backend we don't currently support ops that mutate operands.
-    That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
-    as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
-    we can for these special cases. Hence, the solution is to manually write back to the same operand that the
-    conventional pytorch op variant would write to.
-
-    Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
-    mutate running_mean and running_var). We don't currently handle those.
-    """
-
-    registered_op = get_registered_op(func)
-    if not registered_op["is_mutable"]:
-        return None
-
-    if len(registered_op["returns"]) > 1:
-        raise UnsupportedByTorchMlirEagerMode(
-            "TorchMLIR doesn't handle multiple aliased returns yet."
-        )
-
-    aliased_arg = next(
-        arg
-        for arg in registered_op["arguments"]
-        if "alias_info" in arg and arg["alias_info"]["is_write"]
-    )
-    assert (
-        "alias_info" in registered_op["returns"][0]
-        and registered_op["returns"][0]["alias_info"]["is_write"]
-        and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
-        and registered_op["returns"][0]["alias_info"]["after"][0]
-    )
-    assert (
-        len(aliased_arg["alias_info"]["after"]) == 1
-        and aliased_arg["alias_info"]["after"][0]
-        == registered_op["returns"][0]["alias_info"]["after"][0]
-    )
-
-    return aliased_arg["name"] if aliased_arg["name"] != "self" else "input"
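
The tensor-list flattening performed by `normalize_args_kwargs` can be illustrated in isolation (a standalone sketch, not the deleted API itself):

```python
import torch

def flatten_tensor_lists(kwargs):
    # Replace each list-of-tensors kwarg with k_flattened_0, k_flattened_1, ...
    # and sort for a consistent ordering, as the deleted code did.
    flat = {}
    for k, v in kwargs.items():
        if isinstance(v, (tuple, list)) and v and isinstance(v[0], torch.Tensor):
            for i, t in enumerate(v):
                flat[f"{k}_flattened_{i}"] = t
        else:
            flat[k] = v
    return dict(sorted(flat.items()))

print(flatten_tensor_lists({"tensors": [torch.ones(2), torch.zeros(2)], "dim": 0}).keys())
```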
@@ -1,102 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-import abc
-from dataclasses import dataclass
-from typing import TypeVar, Tuple, Callable, List, Dict, Any
-
-import torch
-
-from torch_mlir._mlir_libs._mlir.ir import Module
-
-# TODO: This might need to be an ABC too, such as
-# to support finding the backend that created the tensor.
-DeviceTensor = TypeVar("DeviceTensor")
-
-
-@dataclass(frozen=True)
-class TensorMetaData:
-    """A small container for metadata necessary for satisfying the pytorch dispatcher and other code (pytorch or
-    otherwise) that branches on these attributes.
-
-    There is a lot of code in the PyTorch codebase that branches based on these attributes; the obvious ones here
-    are dtype, device, and requires_grad (necessary for autograd itself). There is ample warning from PyTorch that,
-    in principle, these should be as close as possible to true; see
-    https://github.com/albanD/subclass_zoo/blob/1566e038f03cd89ab3cc37e670a44e3c2bbc1897/trivial_tensors.py#L90-L92
-
-    The defaults (properties) simplify the API and seem to work after some testing but
-    might malfunction in unexpected ways.
-    # TODO: revisit these assumptions
-    """
-
-    size: Tuple[int]
-    dtype: torch.dtype
-    requires_grad: bool
-
-    strides: Tuple[int]
-    storage_offset: int = 0
-    layout: torch.layout = torch.strided
-    device: torch.device = torch.device("cpu")
-
-    def __init__(
-        self,
-        size,
-        dtype,
-        requires_grad,
-        strides=None,
-        storage_offset=None,
-        layout=None,
-        device=None,
-    ):
-        super().__init__()
-        object.__setattr__(self, "size", size)
-        object.__setattr__(self, "dtype", dtype)
-        object.__setattr__(self, "requires_grad", requires_grad)
-
-        object.__setattr__(
-            self, "strides", strides if strides is not None else len(size) * [0]
-        )
-        object.__setattr__(
-            self, "storage_offset", storage_offset if storage_offset is not None else 0
-        )
-        object.__setattr__(
-            self, "layout", layout if layout is not None else torch.strided
-        )
-        object.__setattr__(
-            self, "device", device if device is not None else torch.device("cpu")
-        )
-
-
-class TorchMLIREagerBackend(abc.ABC):
-    @abc.abstractmethod
-    def compile(
-        self, module: Module
-    ) -> Callable[[List[DeviceTensor]], List[DeviceTensor]]:
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> DeviceTensor:
-        """Convert a torch.Tensor into the backend's device representation."""
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def get_torch_metadata(
-        self, tensor: DeviceTensor, kwargs: Dict[str, Any]
-    ) -> TensorMetaData:
-        """Parse relevant tensor metadata from a backend device array (e.g., shape, stride, layout) in order to build
-        the wrapper tensor."""
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def transfer_from_device_to_torch(self, tensor: DeviceTensor) -> torch.Tensor:
-        """If compilation fails for some reason then device-specific representations need to be munged into a
-        torch.Tensor representation.
-        """
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def copy_into(self, dst: DeviceTensor, src: DeviceTensor):
-        """This method is needed for things like handling aliased arguments."""
-        raise NotImplementedError
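
The next file is built around the `__torch_dispatch__` wrapper-subclass pattern. Here is a minimal runnable sketch of that pattern (logging only, no compilation; `torch.Tensor._make_wrapper_subclass` is a private but long-standing PyTorch API):

```python
import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    elem: torch.Tensor
    __slots__ = ["elem"]

    @staticmethod
    def __new__(cls, elem):
        # Build a wrapper subclass whose metadata mirrors the wrapped tensor.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda x: x.elem if isinstance(x, LoggingTensor) else x
        rewrap = lambda x: LoggingTensor(x) if isinstance(x, torch.Tensor) else x
        # Run the real aten op on the unwrapped tensors, then rewrap the result.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        print("dispatched:", func)
        return tree_map(rewrap, out)

print((LoggingTensor(torch.ones(2)) + 1).elem)
```

TorchMLIRTensor followed the same unwrap/dispatch/rewrap shape, but routed the op through torch-mlir compilation instead of calling `func` directly.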
@ -1,257 +0,0 @@
|
||||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
|
||||||
import contextlib
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
import warnings
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils._pytree import tree_map
|
|
||||||
|
|
||||||
from torch_mlir.eager_mode.ir_building import build_mlir_module
|
|
||||||
from torch_mlir.eager_mode.torch_mlir_dispatch import (
|
|
||||||
UnsupportedByTorchMlirEagerMode,
|
|
||||||
normalize_args_kwargs,
|
|
||||||
check_get_aliased_arg,
|
|
||||||
)
|
|
||||||
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
|
|
||||||
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def no_dispatch():
|
|
||||||
"""Prevent infinite recursion in case accidentally calling a tensor method on a TorchMLIRTensor within
|
|
||||||
__torch_dispatch__."""
|
|
||||||
|
|
||||||
guard = torch._C._DisableTorchDispatch()
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
del guard
|
|
||||||
|
|
||||||
|
|
||||||
backend = EagerModeRefBackend()
|
|
||||||
|
|
||||||
UNSUPPORTED_OPS = re.compile(
|
|
||||||
"|".join([
|
|
||||||
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
|
|
||||||
"detach",
|
|
||||||
# We don't handle _local_scalar_dense because it's just a way to unwrap a tensor that wraps a number.
|
|
||||||
"_local_scalar_dense",
|
|
||||||
# https://github.com/llvm/torch-mlir/issues/878
|
|
||||||
"_unsafe_view",
|
|
||||||
"view",
|
|
||||||
])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TorchMLIRTensor(torch.Tensor):
    """This class serves the role of an abstract class with common functionality for dispatching through Torch-MLIR instead of aten.

    It defers device-specific behavior to device-specific implementations. The deriving classes use the
    make_bare_wrapper_subclass convenience method, adjacent here, and override __torch_dispatch__ in order to dispatch
    through Torch-MLIR instead of aten. Backends are free to choose whatever representation of the buffers (i.e., `elem`)
    they like, and are expected to provide conversion mechanisms between their representation and torch.Tensor.

    Here we only verify that inputs abide by the currently supported features of Torch-MLIR (contiguous memory and
    strided tensor layout) and build the MLIR module. Importantly, we also recover from any malfunctions in the
    deriving classes and dispatch back to conventional PyTorch.

    More documentation on how the __torch_dispatch__ pattern works can be found in this forum post
    https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
    and this RFC
    https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
    and this repo with many examples
    https://github.com/albanD/subclass_zoo
    """

    elem: Any

    __slots__ = ["elem"]

    def __new__(cls, elem, **kwargs):
        """Wrap elem (which could be a torch.Tensor or otherwise) in a torch.Tensor subclass.

        Critically, this method needs to parse relevant metadata from the device representation
        (such as shape, striding, dtype, etc.) and translate it into torch conventions.

        Deriving classes must provide a way to construct themselves from either their device-specific representation
        or torch.Tensor; the latter handles the case where we dispatch back to PyTorch to recover from an error.
        """
        if kwargs.get("constructing_from_device_tensor", False):
            tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
            r = make_bare_wrapper_subclass(
                cls=cls,
                size=tensor_meta_data.size,
                strides=tensor_meta_data.strides,
                storage_offset=tensor_meta_data.storage_offset,
                dtype=tensor_meta_data.dtype,
                layout=tensor_meta_data.layout,
                device=tensor_meta_data.device,
                requires_grad=tensor_meta_data.requires_grad,
            )
            r.elem = elem
        elif isinstance(elem, torch.nn.Parameter):
            r = make_wrapper_subclass_from_torch_tensor(cls, elem.data, **kwargs)
            r.elem = backend.transfer_from_torch_to_device(elem.detach().data)
        elif isinstance(elem, torch.Tensor):
            r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
            r.elem = backend.transfer_from_torch_to_device(elem)
        # This branch handles the case when a python scalar is passed to some op
        # or is returned from some aten op, such as _local_scalar_dense.
        elif isinstance(elem, (int, float, bool)):
            return elem
        else:
            raise ValueError(f"Unknown element type: {type(elem)}")

        return r

    def __repr__(self):
        if self.grad_fn:
            return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
        else:
            return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__})"

    @classmethod
    def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
        requires_grad = check_requires_grad(*args, **kwargs)
        try:
            with no_dispatch():
                if hasattr(func, "op_name"):
                    op_name = func.op_name
                elif hasattr(func, "__name__"):
                    # Handle builtin_function_or_method.
                    op_name = func.__name__
                else:
                    raise RuntimeError(f"op {func} has no name")

            requires_grad = requires_grad and "view" not in op_name

            if UNSUPPORTED_OPS.match(op_name):
                raise UnsupportedByTorchMlirEagerMode(op_name)

            if not hasattr(func, "_schema"):
                raise RuntimeError(f"op {func} has no schema.")

            normalized_kwargs = normalize_args_kwargs(func, args, kwargs)

            if "layout" in normalized_kwargs and normalized_kwargs[
                "layout"
            ] not in {0, None}:
                raise UnsupportedByTorchMlirEagerMode(
                    f"{normalized_kwargs['layout']} layout not supported."
                )
            if "memory_format" in normalized_kwargs and normalized_kwargs[
                "memory_format"
            ] not in {0, None}:
                raise UnsupportedByTorchMlirEagerMode(
                    f"{normalized_kwargs['memory_format']} memory format not supported."
                )

            eager_module = build_mlir_module(func, normalized_kwargs)
            device_tensor_args = [
                kwarg.elem
                for _, kwarg in normalized_kwargs.items()
                if isinstance(kwarg, cls)
            ]
            assert len(eager_module.body.operations[0].arguments) == len(
                device_tensor_args
            ), "Number of parameters and number of arguments differs."
            op_mlir_backend_callable = backend.compile(eager_module)
            out = op_mlir_backend_callable(*device_tensor_args)
            out = tree_map(
                lambda x: cls(
                    x, requires_grad=requires_grad, constructing_from_device_tensor=True
                ),
                out,
            )
        except Exception as e:
            if EAGER_MODE_DEBUG:
                warnings.warn(traceback.format_exc())
            if isinstance(e, UnsupportedByTorchMlirEagerMode):
                warnings.warn(
                    f"Couldn't use TorchMLIR eager because of a current incompatibility: *{str(e)}*; running through PyTorch eager."
                )
            else:
                warnings.warn(
                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
                    f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
                )

            with no_dispatch():
                unwrapped_args = tree_map(cls.unwrap, args)
                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
                out = func(*unwrapped_args, **unwrapped_kwargs)

            out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)

        maybe_aliased_arg_name = check_get_aliased_arg(func)
        if maybe_aliased_arg_name is not None:
            backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)

        return out

    @classmethod
    def unwrap(cls, e):
        """Unwrap the TorchMLIRTensor representation in order to access the actual device-specific representation."""
        if isinstance(e, cls):
            return backend.transfer_from_device_to_torch(e.elem)
        return e

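# NB (illustrative, not part of the original file): wrapping a plain tensor reroutes
# subsequent ops through Torch-MLIR, with automatic fallback to PyTorch eager:
#
#     t = TorchMLIRTensor(torch.randn(2, 3))
#     u = t * 2                          # compiled and run op-by-op
#     plain = TorchMLIRTensor.unwrap(u)  # back to an ordinary torch.Tensor
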
def check_requires_grad(*args, **kwargs):
    requires_grad = False

    def check_grad(e):
        nonlocal requires_grad
        if isinstance(e, TorchMLIRTensor):
            requires_grad |= e.requires_grad

    tree_map(check_grad, args)
    tree_map(check_grad, kwargs)

    return requires_grad

def make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs):
    """Convenience method that parses out relevant metadata from a torch.Tensor in order to produce
    a wrapper subclass.

    NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
    of the device-specific implementation.
    """
    r = make_bare_wrapper_subclass(
        cls=cls,
        size=elem.size(),
        strides=elem.stride(),
        storage_offset=elem.storage_offset(),
        dtype=elem.dtype,
        layout=elem.layout,
        device=elem.device,
        # Only float tensors can have gradients.
        requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
        and (kwargs.get("requires_grad", False) or elem.requires_grad),
    )
    return r

def make_bare_wrapper_subclass(
    *, cls, size, strides, storage_offset, dtype, layout, device, requires_grad
):
    """Convenience method that builds a bare wrapper subclass.

    NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
    of the device-specific implementation.
    """
    return torch.Tensor._make_wrapper_subclass(
        cls,
        size,
        strides=strides,
        storage_offset=storage_offset,
        dtype=dtype,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )

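The `_make_wrapper_subclass` plus `__torch_dispatch__` combination above is the standard wrapper-tensor recipe from the subclass_zoo repo linked in the class docstring. A minimal self-contained sketch of the same pattern, independent of Torch-MLIR (the LoggingTensor name and behavior are illustrative, not from this repository):

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # Build a shell tensor carrying elem's metadata; keep the real data in .elem.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad,
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"dispatching {func}")  # intercept every aten call
        unwrap = lambda e: e.elem if isinstance(e, cls) else e
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # Re-wrap tensor outputs so subsequent ops are intercepted too.
        return tree_map(lambda e: cls(e) if isinstance(e, torch.Tensor) else e, out)

x = LoggingTensor(torch.randn(2, 2))
y = x + x  # prints the aten op, then computes on the underlying tensors
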
@@ -9,5 +9,4 @@ from .native_torch import NativeTorchTestConfig
 from .torchscript import TorchScriptTestConfig
 from .mhlo_backend import MhloBackendTestConfig
 from .tosa_backend import TosaBackendTestConfig
-from .eager_mode import EagerModeTestConfig
 from .torchdynamo import TorchDynamoTestConfig

@@ -1,65 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch
from torch.utils._pytree import tree_map

from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem


def wrap(e):
    return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e


def unwrap(e):
    return TorchMLIRTensor.unwrap(e) if isinstance(e, TorchMLIRTensor) else e


def to_tmt(m: torch.nn.Module):
    for buf_name, buf in m.named_buffers(recurse=True):
        if isinstance(buf, TorchMLIRTensor):
            continue
        m.register_buffer(buf_name, TorchMLIRTensor(buf))
    for param_name, param in m.named_parameters(recurse=True):
        if isinstance(param, TorchMLIRTensor):
            continue
        m.register_parameter(
            param_name,
            torch.nn.Parameter(
                TorchMLIRTensor(param), requires_grad=param.requires_grad
            ),
        )
    for attr in dir(m):
        field = getattr(m, attr)
        if isinstance(field, torch.Tensor) and not isinstance(field, TorchMLIRTensor):
            setattr(m, attr, TorchMLIRTensor(field))


class EagerModeTestConfig(TestConfig):
    """Trivial test config that exercises eager mode plumbing"""

    def __init__(self):
        super().__init__()

    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
        program.apply(to_tmt)
        return program

    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
        result: Trace = []
        for item in trace:
            attr = artifact
            for part in item.symbol.split("."):
                attr = getattr(attr, part)

            inps = tree_map(wrap, item.inputs)
            outps = attr(*inps)
            output = tree_map(unwrap, outps)

            result.append(
                TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
            )
        return result

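A sketch of driving this config by hand (the e2e harness normally constructs the trace; the Mul2 module and the output=None placeholder are made up for illustration):

import torch
from torch_mlir_e2e_test.framework import TraceItem

class Mul2(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

config = EagerModeTestConfig()
compiled = config.compile(Mul2())  # swaps tensors for TorchMLIRTensors
trace = [TraceItem(symbol="forward", inputs=[torch.randn(3)], output=None)]
result = config.run(compiled, trace)
print(result[0].output)  # a plain torch.Tensor, unwrapped by `unwrap`
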
@@ -1,90 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from __future__ import annotations

from typing import Dict, Any

import numpy as np
import torch

from torch_mlir.compiler_utils import (
    get_module_name_for_debug_dump,
    run_pipeline_with_repro_report,
)
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
    TorchMLIREagerBackend,
    TensorMetaData,
)
from torch_mlir.ir import Module
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
    RefBackendLinalgOnTensorsBackend,
)

NUMPY_TO_TORCH_DTYPE_DICT = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}

_ref_backend = RefBackendLinalgOnTensorsBackend()


class EagerModeRefBackend(TorchMLIREagerBackend):
    """Main entry point for the reference backend for eager mode.

    RefBackend uses numpy.ndarray representations of tensors, so all of the wrapping, unwrapping,
    and munging here converts between torch.Tensor and numpy.ndarray.
    """

    module_to_refbackend_invoker = {}

    def get_torch_metadata(
        self, tensor: np.ndarray, kwargs: Dict[str, Any]
    ) -> TensorMetaData:
        return TensorMetaData(
            size=tensor.shape,
            dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
            # Only float tensors can have gradients.
            requires_grad=tensor.dtype.type in {np.float32, np.float64}
            and kwargs.get("requires_grad", False),
        )

    def compile(self, imported_module: Module):
        """Lower the imported TS module to linalg and further compile it for the reference backend,
        returning the compiled callable."""
        fn_name = get_module_name_for_debug_dump(imported_module)
        module_hash = str(imported_module)
        if module_hash not in self.module_to_refbackend_invoker:
            run_pipeline_with_repro_report(
                imported_module,
                "builtin.module(torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline)",
                "EagerMode",
            )
            self.module_to_refbackend_invoker[module_hash] = _ref_backend.load(
                _ref_backend.compile(imported_module)
            )

        ref_backend_invoker = self.module_to_refbackend_invoker[module_hash]
        op_mlir_backend_callable = getattr(ref_backend_invoker, fn_name)
        assert (
            op_mlir_backend_callable is not None
        ), "Couldn't find function in module."
        return op_mlir_backend_callable

    def copy_into(self, dst: np.ndarray, src: np.ndarray):
        np.copyto(dst, src)

    def transfer_from_device_to_torch(self, e: np.ndarray):
        return torch.from_numpy(e).clone()

    def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> np.ndarray:
        return tensor.detach().numpy()

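One detail of the two transfer methods above worth spelling out (an illustrative aside, not from the repository): `detach().numpy()` shares storage with the torch tensor, while the return trip `.clone()`s precisely so results don't alias backend-owned memory:

import torch

t = torch.ones(3)
a = t.detach().numpy()              # shares memory with t
a[0] = 7.0
assert t[0].item() == 7.0           # the torch tensor sees the write

back = torch.from_numpy(a).clone()  # clone() breaks the aliasing
a[1] = 9.0
assert back[1].item() == 1.0        # the cloned tensor does not
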