mirror of https://github.com/llvm/torch-mlir
Remove eager_mode

This was an experimental attempt at rolling our own op-by-op executor with `__torch_dispatch__`, but it proved difficult to make robust. Op-by-op execution is now easy to implement robustly with the PyTorch 2.0 stack, so we don't need eager_mode. Downstream users were using eager_mode to implement lockstep numerical accuracy debuggers; we implemented the same functionality with TorchDynamo in https://github.com/llvm/torch-mlir/pull/1681, so there is not much reason to continue maintaining it.

parent 109c91ae9b
commit 7731211d02
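As a rough illustration of the claim above (not part of this commit, and only a hedged sketch): with the PyTorch 2.0 stack, op-by-op interception falls out of writing a custom TorchDynamo backend, which receives an FX graph whose nodes can be executed or checked one at a time, the role eager_mode previously needed `__torch_dispatch__` for. The backend and function names below are hypothetical; the actual lockstep debugging support lives in the TorchDynamo-based work referenced above.

```python
import torch

def op_by_op_backend(gm: torch.fx.GraphModule, example_inputs):
    # Each captured op is a node in the FX graph and can be inspected,
    # executed individually, or compared against a reference here.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print("captured op:", node.target)
    # Fall back to running the captured graph eagerly.
    return gm.forward

@torch.compile(backend=op_by_op_backend)
def f(x, y):
    return torch.relu(x @ y)

f(torch.randn(4, 4), torch.randn(4, 4))
```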
@@ -75,14 +75,6 @@ torch-mlir prediction

View examples [here](docs/ltc_examples.md).

### Eager Mode

Eager mode with TorchMLIR is a very experimental eager mode backend for PyTorch through the torch-mlir framework.
Effectively, this mode works by compiling operator by operator as the NN is eagerly executed by PyTorch.
This mode includes a fallback to conventional PyTorch if anything in the torch-mlir compilation process fails (e.g., unsupported operator).
A simple example can be found at [eager_mode.py](examples/eager_mode.py).
A ResNet18 example can be found at [eager_mode_resnet18.py](examples/eager_mode_resnet18.py).

## Repository Layout

The project follows the conventions of typical MLIR-based projects:
@@ -262,9 +262,6 @@ function test_in_tree() {
  echo ":::: Run Linalg e2e integration tests"
  python -m e2e_testing.main --config=linalg -v

  echo ":::: Run eager_mode e2e integration tests"
  python -m e2e_testing.main --config=eager_mode -v

  echo ":::: Run MHLO e2e integration tests"
  python -m e2e_testing.main --config=mhlo -v
@@ -20,7 +20,6 @@ from torch_mlir_e2e_test.configs import (
    NativeTorchTestConfig,
    TorchScriptTestConfig,
    TosaBackendTestConfig,
    EagerModeTestConfig,
    TorchDynamoTestConfig,
)

@@ -28,14 +27,14 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
    config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'lazy_tensor_core', 'torchdynamo']
    parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
    parser.add_argument('-c', '--config',
                        choices=config_choices,

@@ -47,7 +46,6 @@ Meaning of options:
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
"eager_mode": run through torch-mlir's eager mode frontend, using Linalg-on-Tensors for execution.
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
"torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
''')

@@ -91,9 +89,6 @@ def main():
    elif args.config == 'torchscript':
        config = TorchScriptTestConfig()
        xfail_set = {}
    elif args.config == 'eager_mode':
        config = EagerModeTestConfig()
        xfail_set = EAGER_MODE_XFAIL_SET
    elif args.config == 'lazy_tensor_core':
        config = LazyTensorCoreTestConfig()
        xfail_set = LTC_XFAIL_SET
@@ -14,13 +14,6 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS

LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS

EAGER_MODE_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
    # RefBackend fails for some reason.
    # These tests pass in the regular RefBackend flow, so it's unclear
    # why they fail here.
    "Matmul_vecmat",
}

TORCHDYNAMO_XFAIL_SET = {
    #### General TorchDynamo/PyTorch errors
@@ -1,34 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor

torch_a = torch.randn(5, requires_grad=True)
torch_b = torch.randn(5, requires_grad=True)

torch_c = torch_a + torch_b
torch_d = torch_a * torch_b
torch_e = torch_c / torch_d
torch_loss = torch_e.sum()
print("PyTorch loss: ", torch_loss)

torch_loss.backward()
print("PyTorch grad a: ", torch_a.grad)
print("PyTorch grad b: ", torch_b.grad)

a = TorchMLIRTensor(torch_a)
b = TorchMLIRTensor(torch_b)

c = a + b
d = a * b
e = c / d
loss = e.sum()
print("Torch-MLIR loss: ", loss)

loss.backward()
print("Torch-MLIR grad a: ", a.grad)
print("Torch-MLIR grad b: ", b.grad)
@ -1,89 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import sys
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||
|
||||
|
||||
def load_and_preprocess_image(url: str):
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
|
||||
}
|
||||
img = Image.open(requests.get(url, headers=headers,
|
||||
stream=True).raw).convert("RGB")
|
||||
# preprocessing pipeline
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
img_preprocessed = preprocess(img)
|
||||
return torch.unsqueeze(img_preprocessed, 0)
|
||||
|
||||
|
||||
def load_labels():
|
||||
classes_text = requests.get(
|
||||
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
|
||||
stream=True,
|
||||
).text
|
||||
labels = [line.strip() for line in classes_text.splitlines()]
|
||||
return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
|
||||
_, indexes = torch.sort(res, descending=True)
|
||||
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
|
||||
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
|
||||
return top3
|
||||
|
||||
|
||||
def predictions(torch_func, img, labels):
|
||||
golden_prediction = top3_possibilities(torch_func(img))
|
||||
print("PyTorch prediction")
|
||||
print(golden_prediction)
|
||||
prediction = top3_possibilities(torch_func(TorchMLIRTensor(img)))
|
||||
print("torch-mlir prediction")
|
||||
print(prediction)
|
||||
|
||||
|
||||
class ResNet18Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.resnet = models.resnet18(pretrained=True)
|
||||
self.train(False)
|
||||
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.s = ResNet18Module()
|
||||
|
||||
def forward(self, x):
|
||||
return self.s.forward(x)
|
||||
|
||||
|
||||
image_url = (
|
||||
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
)
|
||||
|
||||
print("load image from " + image_url, file=sys.stderr)
|
||||
img = load_and_preprocess_image(image_url)
|
||||
labels = load_labels()
|
||||
|
||||
test_module = TestModule()
|
||||
predictions(test_module.forward, img, labels)
|
|
@@ -102,12 +102,6 @@ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
  add_subdirectory(torch_mlir_e2e_test)
endif()

################################################################################
# Eager mode
################################################################################

add_subdirectory(torch_mlir/eager_mode)

################################################################################
# Custom op example
# Required for running the update_torch_ods.sh and update_shape_lib.sh scripts.
@ -1,320 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from framework import run_test
|
||||
from torch_mlir.eager_mode.ir_building import build_ts_script_function
|
||||
|
||||
|
||||
# CHECK: graph(%[[A1:.*]] : Tensor,
|
||||
# CHECK: %[[A2:.*]] : Tensor,
|
||||
# CHECK: %[[A3:.*]] : Tensor):
|
||||
# CHECK: %[[A4:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[A5:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[A0:.*]] : Tensor = aten::addmm(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]])
|
||||
# CHECK: return (%[[A0]])
|
||||
# -----
|
||||
# CHECK: PASS - simple
|
||||
@run_test
|
||||
def simple():
|
||||
target = torch.ops.aten.addmm.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
mat1=torch.randn(1, 3, 32, 32),
|
||||
mat2=torch.randn(1, 3, 32, 32),
|
||||
beta=1,
|
||||
alpha=1,
|
||||
)
|
||||
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[B1:.*]] : Tensor,
|
||||
# CHECK: %[[B2:.*]] : Tensor,
|
||||
# CHECK: %[[B3:.*]] : Tensor):
|
||||
# CHECK: %[[B4:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[B5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[B6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[B7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[B8:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[B9:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[B0:.*]] : Tensor = aten::convolution(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[B9]])
|
||||
# CHECK: return (%[[B0]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_optional_tensor_input
|
||||
@run_test
|
||||
def handle_optional_tensor_input():
|
||||
target = torch.ops.aten.convolution.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
weight=torch.randn(3, 3, 3, 3),
|
||||
bias=torch.randn(3),
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
)
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: FAIL - fail_not_enough_args
|
||||
# CHECK: Errors: 'groups'
|
||||
@run_test
|
||||
def fail_not_enough_args():
|
||||
target = torch.ops.aten.convolution.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
weight=torch.randn(3, 3, 3, 3),
|
||||
bias=torch.randn(3),
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
# Missing groups=1,
|
||||
)
|
||||
build_ts_script_function(target._schema, kwargs)
|
||||
|
||||
|
||||
# CHECK: graph(%input : Tensor,
|
||||
# CHECK: %weight : Tensor,
|
||||
# CHECK: %bias : Tensor):
|
||||
# CHECK: %4 : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %5 : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %6 : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %7 : bool = prim::Constant[value=0]()
|
||||
# CHECK: %8 : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %9 : int = prim::Constant[value=1]()
|
||||
# CHECK: %0 : Tensor = aten::convolution(%input, %weight, %bias, %4, %5, %6, %7, %8, %9)
|
||||
# CHECK: return (%0)
|
||||
# -----
|
||||
# CHECK: PASS - simple_kwargs
|
||||
@run_test
|
||||
def simple_kwargs():
|
||||
target = torch.ops.aten.convolution.default
|
||||
script_fun1 = build_ts_script_function(
|
||||
target._schema,
|
||||
dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
weight=torch.randn(3, 3, 3, 3),
|
||||
bias=torch.randn(3),
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
print(script_fun1.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[C2:.*]] : Tensor):
|
||||
# CHECK: %[[C3:.*]] : int[] = prim::Constant[value=[3, 3]]()
|
||||
# CHECK: %[[C4:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[C5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[C6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[C7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[C0:.*]] : Tensor, %[[C1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]])
|
||||
# CHECK: return (%[[C0]], %[[C1]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_empty_lists
|
||||
@run_test
|
||||
def handle_empty_lists():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
# print(target._schema)
|
||||
input = torch.randn((1, 3, 32, 32))
|
||||
kwargs = dict(
|
||||
input=input,
|
||||
kernel_size=[3, 3],
|
||||
stride=[],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
ceil_mode=False,
|
||||
)
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[D2:.*]] : Tensor):
|
||||
# CHECK: %[[D3:.*]] : int[] = prim::Constant[value=[3, 3]]()
|
||||
# CHECK: %[[D4:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[D5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[D6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[D7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[D0:.*]] : Tensor, %[[D1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[D2]], %[[D3]], %[[D4]], %[[D5]], %[[D6]], %[[D7]])
|
||||
# CHECK: return (%[[D0]], %[[D1]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_nones
|
||||
@run_test
|
||||
def handle_nones():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
# print(target._schema)
|
||||
kwargs = dict(
|
||||
input=torch.randn((1, 3, 32, 32)),
|
||||
kernel_size=[3, 3],
|
||||
stride=None,
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
ceil_mode=False,
|
||||
)
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[E1:.*]] : Tensor,
|
||||
# CHECK: %[[E2:.*]] : Tensor,
|
||||
# CHECK: %[[E3:.*]] : Tensor):
|
||||
# CHECK: %[[E4:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[E5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[E6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[E7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[E8:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[E9:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[E0:.*]] : Tensor = aten::convolution(%[[E1]], %[[E2]], %[[E3]], %[[E4]], %[[E5]], %[[E6]], %[[E7]], %[[E8]], %[[E9]])
|
||||
# CHECK: return (%[[E0]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_optional_tensors
|
||||
@run_test
|
||||
def handle_optional_tensors():
|
||||
target = torch.ops.aten.convolution.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
weight=torch.randn(3, 3, 3, 3),
|
||||
bias=torch.randn(3),
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
)
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[F1:.*]] : Tensor):
|
||||
# CHECK: %[[F2:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F3:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F4:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F5:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F6:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F0:.*]] : Tensor = aten::ones_like(%[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], %[[F6]])
|
||||
# CHECK: return (%[[F0]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_ones_like
|
||||
@run_test
|
||||
def handle_ones_like():
|
||||
target = torch.ops.aten.ones_like.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
dtype=None,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
memory_format=None,
|
||||
)
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[G3:.*]] : Tensor,
|
||||
# CHECK: %[[G4:.*]] : Tensor,
|
||||
# CHECK: %[[G5:.*]] : Tensor):
|
||||
# CHECK: %[[G6:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[G7:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[G8:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[G9:.*]] : float = prim::Constant[value=1.]()
|
||||
# CHECK: %[[G10:.*]] : float = prim::Constant[value=1.]()
|
||||
# CHECK: %[[G0:.*]] : Tensor, %[[G1:.*]] : Tensor, %[[G2:.*]] : Tensor = aten::native_batch_norm(%[[G3]], %[[G4]], %[[G5]], %[[G6]], %[[G7]], %[[G8]], %[[G9]], %[[G10]])
|
||||
# CHECK: return (%[[G0]], %[[G1]], %[[G2]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_multiple_outputs
|
||||
@run_test
|
||||
def handle_multiple_outputs():
|
||||
target = torch.ops.aten.native_batch_norm.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
weight=torch.randn(1, 3, 32, 32),
|
||||
bias=torch.randn(1, 3, 32, 32),
|
||||
running_mean=None,
|
||||
running_var=None,
|
||||
training=False,
|
||||
momentum=1.0,
|
||||
eps=1.0
|
||||
)
|
||||
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: f
|
||||
# CHECK: PASS - check_legal_name
|
||||
@run_test
|
||||
def check_legal_name():
|
||||
target = torch.ops.aten.native_batch_norm.default
|
||||
kwargs = dict(
|
||||
input=torch.randn(1, 3, 32, 32),
|
||||
weight=torch.randn(1, 3, 32, 32),
|
||||
bias=torch.randn(1, 3, 32, 32),
|
||||
running_mean=None,
|
||||
running_var=None,
|
||||
training=False,
|
||||
momentum=1.0,
|
||||
eps=1.0
|
||||
)
|
||||
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.name)
|
||||
|
||||
|
||||
# CHECK: graph(%[[H3:.*]] : Tensor,
|
||||
# CHECK: %[[H4:.*]] : Tensor,
|
||||
# CHECK: %[[H5:.*]] : Tensor,
|
||||
# CHECK: %[[H6:.*]] : Tensor,
|
||||
# CHECK: %[[H7:.*]] : Tensor,
|
||||
# CHECK: %out : Tensor,
|
||||
# CHECK: %save_mean : Tensor,
|
||||
# CHECK: %save_invstd : Tensor):
|
||||
# CHECK: %[[H8:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[H9:.*]] : float = prim::Constant[value=0.10000000000000001]()
|
||||
# CHECK: %[[H10:.*]] : float = prim::Constant[value=0.0001]()
|
||||
# CHECK: %[[H0:.*]] : Tensor, %[[H1:.*]] : Tensor, %[[H2:.*]] : Tensor = aten::native_batch_norm(%[[H3]], %[[H4]], %[[H5]], %[[H6]], %[[H7]], %[[H8]], %[[H9]], %[[H10]], %out, %save_mean, %save_invstd)
|
||||
# CHECK: return (%[[H0]], %[[H1]], %[[H2]])
|
||||
# -----
|
||||
# CHECK: PASS - correctly_order_kwargs
|
||||
@run_test
|
||||
def correctly_order_kwargs():
|
||||
target = torch.ops.aten.native_batch_norm.out
|
||||
|
||||
input = torch.randn(2, 5, 2, 3)
|
||||
running_mean = torch.randn(5)
|
||||
running_var = torch.randn(5)
|
||||
|
||||
kwargs = dict(
|
||||
input=torch.randn(2, 5, 2, 3),
|
||||
weight=torch.randn(5),
|
||||
bias=torch.randn(5),
|
||||
running_mean=running_mean,
|
||||
running_var=running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.0001,
|
||||
out=torch.empty_like(input),
|
||||
save_mean=torch.empty_like(running_mean),
|
||||
save_invstd=torch.empty_like(running_var),
|
||||
)
|
||||
|
||||
script_fun = build_ts_script_function(target._schema, kwargs)
|
||||
print(script_fun.graph)
|
|
@@ -1,23 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: true


def run_test(*args, XPASS=False, XFAIL=False):
    def _run_test(test):
        test_name = test.__name__
        try:
            test()
            print(("X" if XPASS else "") + f"PASS - {test_name}")
        except Exception as e:
            print(("X" if XFAIL else "") + f"FAIL - {test_name}")
            print("Errors: ", e)
        print()

    if len(args):
        _run_test(args[0])
    else:
        return _run_test
@ -1,69 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from framework import run_test
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
|
||||
|
||||
|
||||
# CHECK: PASS - should_normalize
|
||||
@run_test
|
||||
def should_normalize():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
input = torch.randn((1, 3, 32, 32))
|
||||
kwargs = {"kernel_size": [3, 3]}
|
||||
golden = {
|
||||
"kernel_size": [3, 3],
|
||||
# This is due to the schema for max_pool2d_with_indices defining
|
||||
# the stride arg as int[2] stride=[].
|
||||
"stride": [],
|
||||
"padding": [0, 0],
|
||||
"dilation": [1, 1],
|
||||
"ceil_mode": False,
|
||||
}
|
||||
|
||||
new_kwargs = normalize_args_kwargs(target, (input,), kwargs)
|
||||
assert torch.allclose(new_kwargs["input"], input)
|
||||
for k, v in new_kwargs.items():
|
||||
if k == "input": continue
|
||||
assert v == golden[k]
|
||||
|
||||
|
||||
# CHECK: FAIL - shouldnt_normalize1
|
||||
# CHECK: Errors: missing a required argument: 'kernel_size'
|
||||
@run_test
|
||||
def shouldnt_normalize1():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"stride": []}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
||||
|
||||
|
||||
# This next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
|
||||
# I.e., they should fail but in fact they pass because of the upstream bug.
|
||||
# The reason for the bug is a fast path branch in operator_schemas.normalize_function
|
||||
# that doesn't do rigorous type checking, and hence lets type mistmatches slip through.
|
||||
# TODO(max): change these to FAIL when the upstream bug is fixed.
|
||||
|
||||
# CHECK: XPASS - shouldnt_normalize2
|
||||
@run_test(XPASS=True)
|
||||
def shouldnt_normalize2():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"kernel_size": []}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
||||
|
||||
|
||||
# CHECK: XPASS - shouldnt_normalize3
|
||||
@run_test(XPASS=True)
|
||||
def shouldnt_normalize3():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"kernel_size": [3, 3], "padding": None}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
|
@@ -1,11 +0,0 @@
#-------------------------------------------------------------------------------
# Subdirectories
#-------------------------------------------------------------------------------

## Declare the sources of the Python module.

declare_mlir_python_sources(TorchMLIRPythonSources.EagerMode
  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
  ADD_TO_PARENT TorchMLIRPythonSources
  SOURCES_GLOB eager_mode/*.py lazytensor/*.py
)
@@ -1,3 +0,0 @@
import os

EAGER_MODE_DEBUG = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')
@ -1,359 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
"""
|
||||
Translator from torch.jit.ScriptFunction to MLIR.
|
||||
|
||||
|
||||
The following defines a set of classes for converting types used by Python and PyTorch into MLIR types from the
|
||||
`torch` dialect.
|
||||
|
||||
The expected use of this module is to create an instance of one of the classes below, and then calling the
|
||||
`to_mlir` method to generate the MLIR representation of the type.
|
||||
|
||||
Information about what types are supported by each class can be found in docstrings of each of the classes.
|
||||
|
||||
In addition this module defines a function that take a torch.jit.ScriptFunction and converts it into an MLIR module.
|
||||
|
||||
The expected use for this module is to use the function
|
||||
`build_module(jit_function: torch.jit.ScriptFunction annotation: Annotation) -> ir.Module`
|
||||
to convert the TorchScript function into MLIR using the `torch` dialect.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import re
|
||||
from typing import Any, Optional, Iterable, Dict
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.jit
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects.func import FuncOp
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
|
||||
class TorchMlirType(abc.ABC):
|
||||
"""
|
||||
A `TorchMlirType` is an object that produces MLIR
|
||||
types in the `torch` dialect. The only requirement
|
||||
for a class to be a subclass of `TorchMlirType` is
|
||||
to define a `to_mlir(self, ir.Context) -> ir.Type`.
|
||||
Each class is allowed to have different types of
|
||||
__init__ methods depending on the information they
|
||||
require to produce the given MLIR representation.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
pass
|
||||
|
||||
|
||||
class TorchTensorTypeError(Exception):
|
||||
def __init__(self, value: str):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class TorchTensorType(TorchMlirType):
|
||||
"""
|
||||
This class is used to generate types of the form
|
||||
!torch.tensor and !torch.vtensor<SHAPE, DTYPE>,
|
||||
where SHAPE is a list representing the shape of the tensor,
|
||||
and DTYPE is an MLIR data type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
shape: Optional[Iterable[Optional[int]]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
if dtype is None and shape is not None:
|
||||
err = "If shape is specified, dtype must also be specified"
|
||||
raise TorchTensorTypeError(err)
|
||||
|
||||
def __str__(self):
|
||||
return f"Torch Tensor (shape={self.shape}, dtype={self.dtype})"
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
if self.dtype is None:
|
||||
return ir.Type.parse("!torch.tensor", context=context)
|
||||
|
||||
shape_asm = self._shape_to_mlir_asm()
|
||||
dtype_asm = self._dtype_to_mlir_asm()
|
||||
return ir.Type.parse(
|
||||
f"!torch.vtensor<{shape_asm},{dtype_asm}>", context=context
|
||||
)
|
||||
|
||||
def _shape_to_mlir_asm(self) -> str:
|
||||
if self.shape is None:
|
||||
return "*"
|
||||
|
||||
str_sizes = map(lambda x: "?" if x is None else str(x), self.shape)
|
||||
return f'[{",".join(str_sizes)}]'
|
||||
|
||||
def _dtype_to_mlir_asm(self) -> str:
|
||||
if self.dtype in [torch.float64]:
|
||||
return "f64"
|
||||
if self.dtype in [torch.float, torch.float32]:
|
||||
return "f32"
|
||||
if self.dtype in [torch.int, torch.int32]:
|
||||
return "si32"
|
||||
if self.dtype in [torch.int64]:
|
||||
return "si64"
|
||||
if self.dtype in [torch.bool]:
|
||||
return "i1"
|
||||
|
||||
raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
|
||||
|
||||
|
||||
class TorchNnModuleType(TorchMlirType):
|
||||
"""This class is used to generate types for `!torch.nn.Module`s."""
|
||||
|
||||
def __init__(self, module_name: str):
|
||||
self.module_name = module_name
|
||||
|
||||
def __str__(self):
|
||||
return "torch.nn.Module"
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', context=context)
|
||||
|
||||
|
||||
class PythonType(TorchMlirType):
|
||||
"""
|
||||
This class is used to convert regular Python types
|
||||
into their corresponding `torch` dialect representation.
|
||||
The list of supported types can be found in the dictionary
|
||||
`_type_to_asm_dict`.
|
||||
"""
|
||||
|
||||
_type_to_asm_dict = {
|
||||
bool: "!torch.bool",
|
||||
int: "!torch.int",
|
||||
type(None): "!torch.none",
|
||||
}
|
||||
|
||||
def __init__(self, type_: Any):
|
||||
self.type_ = type_
|
||||
|
||||
def __str__(self):
|
||||
return str(self.type_)
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
asm = self._type_to_asm_dict.get(self.type_)
|
||||
if asm is None:
|
||||
raise NotImplementedError(f"Unsupported type: {self.type_}")
|
||||
return ir.Type.parse(asm, context=context)
|
||||
|
||||
|
||||
# TODO: This functionality should be incorporated into ModuleBuilder.import_function.
|
||||
class Annotation:
|
||||
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
|
||||
self.types = list(
|
||||
map(lambda t: PythonType(t) if isinstance(t, type) else t, types)
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
result = f"Annotation instance with {len(self.types)} types\n"
|
||||
for e, type_ in enumerate(self.types):
|
||||
result += f" Type of argument {e + 1}: {str(type_)}\n"
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.types)
|
||||
|
||||
|
||||
class AnnotationConverter:
|
||||
@staticmethod
|
||||
def to_mlir_array_attr(annotation: Annotation, context: ir.Context) -> ir.ArrayAttr:
|
||||
dict_attrs = []
|
||||
for type_ in annotation:
|
||||
if not isinstance(type_, TorchTensorType):
|
||||
dict_attrs.append(ir.DictAttr.get({}, context=context))
|
||||
continue
|
||||
|
||||
ir_type = type_.to_mlir(context)
|
||||
with context:
|
||||
type_attr = ir.TypeAttr.get(ir_type)
|
||||
dict_attr = ir.DictAttr.get({"torch.type_bound": type_attr})
|
||||
dict_attrs.append(dict_attr)
|
||||
|
||||
return ir.ArrayAttr.get(dict_attrs, context=context)
|
||||
|
||||
|
||||
def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
|
||||
with module.context:
|
||||
name_attr = ir.StringAttr.get(name)
|
||||
for op in module.body.operations:
|
||||
if isinstance(op, FuncOp) and op.name == name_attr:
|
||||
# Add name of torch op as debug_module_name so that
|
||||
# run_pipeline_with_repro_report can use it.
|
||||
module.operation.attributes["torch.debug_module_name"] = name_attr
|
||||
return op
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_tensor_type(typ: torch._C.Type):
|
||||
return typ.isSubtypeOf(torch.TensorType.get()) or (
|
||||
isinstance(typ, torch.OptionalType)
|
||||
and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
|
||||
)
|
||||
|
||||
|
||||
def is_list_of_tensors_type(typ: torch._C.Type):
|
||||
return isinstance(typ, torch.ListType) and is_tensor_type(typ.getElementType())
|
||||
|
||||
|
||||
name_mangle_regex = re.compile("[^a-zA-Z0-9]")
|
||||
|
||||
|
||||
def build_ts_script_function(
|
||||
schema: torch._C.FunctionSchema, kwargs: Dict[str, Any]
|
||||
) -> torch.jit.ScriptFunction:
|
||||
"""Build a torch.jit.ScriptFunction that corresponds to the schema.
|
||||
|
||||
Constants are inlined for the purposes of invalidating the compile cache when they change.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema: torch._C.FunctionSchema
|
||||
PyTorch's representation for ops, contains type information needed for inlining constants into the TS graph.
|
||||
kwargs: Dict
|
||||
A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.jit.ScriptFunction
|
||||
Fully specialized (all constants) TS graph whose only arguments are tensors.
|
||||
"""
|
||||
|
||||
# Creates empty TS graph.
|
||||
graph = torch._C.Graph()
|
||||
# Creates and inserts node with identifier `schema.name`; NB node has no inputs or outputs at this point.
|
||||
node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
|
||||
# Associate graph inputs/outputs with node inputs/outputs.
|
||||
graph_inputs = []
|
||||
for arg in schema.arguments:
|
||||
arg_name = arg.name if arg.name != "self" else "input"
|
||||
|
||||
# If arg is a flattened list of tensors, such as in the case of torch.cat
|
||||
# then add each element of the list to the graph corresponding to arg
|
||||
# and insert a ListConstruct to function as input to the op.
|
||||
if is_list_of_tensors_type(arg.type):
|
||||
inps = []
|
||||
for kwarg in [
|
||||
kwarg for kwarg in kwargs if f"{arg_name}_flattened" in kwarg
|
||||
]:
|
||||
inp = graph.addInput()
|
||||
el_typ = arg.type.getElementType()
|
||||
if isinstance(el_typ, torch.OptionalType):
|
||||
el_typ = el_typ.getElementType()
|
||||
inp.setType(el_typ)
|
||||
inp.setDebugName(kwarg)
|
||||
inps.append(inp)
|
||||
graph_inputs.append(kwarg)
|
||||
list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
|
||||
list_cons.moveBefore(node)
|
||||
inp = list_cons.output()
|
||||
inp.setType(torch.ListType.ofTensors())
|
||||
# If arg is a tensor, then add input to the graph corresponding to arg.
|
||||
elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
|
||||
inp = graph.addInput()
|
||||
if isinstance(arg.type, torch.OptionalType):
|
||||
el_typ = arg.type.getElementType()
|
||||
else:
|
||||
el_typ = arg.type
|
||||
inp.setType(el_typ)
|
||||
inp.setDebugName(arg_name)
|
||||
graph_inputs.append(arg_name)
|
||||
# If arg is a constant, inline (at the top of the graph).
|
||||
else:
|
||||
val = kwargs[arg_name]
|
||||
if val == []:
|
||||
# Some ops have empty list default values for args
|
||||
# (such as aten::max_pool2d_with_indices with int[2] stride=[]
|
||||
# but graph.insertConstant doesnt' recognize [] as an empty list IValue.
|
||||
# This might be an upstream bug but there doesn't seem to be a way to
|
||||
# build a prim::ListConstruct list that's empty.
|
||||
val = None
|
||||
inp = graph.insertConstant(val)
|
||||
inp.node().moveBefore(node)
|
||||
|
||||
node.addInput(inp)
|
||||
|
||||
# Reorder graph inputs to match kwargs.
|
||||
permutes = [
|
||||
{inp: i for i, inp in enumerate(graph_inputs)}[kwarg]
|
||||
for kwarg in [kwarg for kwarg in kwargs if kwarg in graph_inputs]
|
||||
]
|
||||
graph.permuteInputs(permutes)
|
||||
|
||||
if node.hasMultipleOutputs():
|
||||
for outp in node.outputs():
|
||||
graph.registerOutput(outp)
|
||||
else:
|
||||
graph.registerOutput(node.output())
|
||||
|
||||
fn = torch._C._create_function_from_graph(
|
||||
f"{name_mangle_regex.sub('', str(graph))}", graph
|
||||
)
|
||||
return fn
|
||||
|
||||
|
||||
def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
|
||||
"""Translate input function into an MLIR module in the `torch` dialect.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op: OpOverload
|
||||
Callable from the torch.ops.aten module/namespace that has a _schema field.
|
||||
kwargs: Dict
|
||||
A dictionary with all arguments passed in through __torch_dispatch__ (including int/float,bool params).
|
||||
|
||||
Returns
|
||||
-------
|
||||
ir.Module
|
||||
Translation of the input module into an MLIR module.
|
||||
"""
|
||||
|
||||
# The assert here is to catch tensor shapes that have size 0 dimensions, such as those produced in
|
||||
# the course of evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic.
|
||||
# Such 0 size dimensions fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887
|
||||
annotations = []
|
||||
for arg_name, arg in kwargs.items():
|
||||
if isinstance(arg, torch.Tensor):
|
||||
assert np.prod(arg.shape) != 0, f"{arg_name} has invalid shape {arg.shape}"
|
||||
annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
|
||||
annotations = tuple(annotations)
|
||||
|
||||
script_fun = build_ts_script_function(op._schema, kwargs)
|
||||
assert len(annotations) == len(
|
||||
list(script_fun.graph.inputs())
|
||||
), "Number of annotations and number of graph inputs differs."
|
||||
|
||||
mb = ModuleBuilder()
|
||||
mb.import_function(script_fun)
|
||||
|
||||
func_op = get_func_op_with_name(mb.module, script_fun.name)
|
||||
assert (
|
||||
func_op is not None
|
||||
), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
|
||||
|
||||
func_annotation = Annotation(annotations)
|
||||
arg_attrs = AnnotationConverter.to_mlir_array_attr(func_annotation, mb.context)
|
||||
func_op.attributes["arg_attrs"] = arg_attrs
|
||||
|
||||
return mb.module
|
|
@ -1,111 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Tuple
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.fx import immutable_collections
|
||||
from torch.fx.operator_schemas import (
|
||||
_torchscript_schema_to_signature,
|
||||
_args_kwargs_to_normalized_args_kwargs,
|
||||
)
|
||||
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops
|
||||
|
||||
from torch_mlir.dialects import torch as torch_dialect
|
||||
|
||||
OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
|
||||
SUPPORTED_OPS = frozenset(
|
||||
[
|
||||
member.OPERATION_NAME
|
||||
for member in vars(torch_dialect).values()
|
||||
if hasattr(member, "OPERATION_NAME")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedByTorchMlirEagerMode(Exception):
|
||||
def __init__(self, value: str):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
|
||||
"""Fill in default values for optional args, which are dependent on the schema."""
|
||||
sig = _torchscript_schema_to_signature(target._schema)
|
||||
_, new_kwargs = _args_kwargs_to_normalized_args_kwargs(
|
||||
sig, args, kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
if "self" in new_kwargs:
|
||||
new_kwargs["input"] = new_kwargs.pop("self")
|
||||
|
||||
# Flatten lists of args for ops that takes lists, such as torch.cat.
|
||||
to_remove = set()
|
||||
to_add = {}
|
||||
for k, v in new_kwargs.items():
|
||||
if isinstance(v, (tuple, list)) and len(v) and isinstance(v[0], torch.Tensor):
|
||||
to_remove.add(k)
|
||||
for i, vv in enumerate(v):
|
||||
to_add[f"{k}_flattened_{i}"] = vv
|
||||
|
||||
for rem in to_remove:
|
||||
del new_kwargs[rem]
|
||||
new_kwargs.update(**to_add)
|
||||
|
||||
# Sort here in order to have consistency across TS graph and
|
||||
# MLIR module.
|
||||
sorted_kwargs = dict(sorted(new_kwargs.items()))
|
||||
return immutable_collections.immutable_dict(sorted_kwargs)
|
||||
|
||||
|
||||
def get_registered_op(op):
|
||||
registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
|
||||
return registered_op
|
||||
|
||||
|
||||
def check_get_aliased_arg(func: Callable,):
|
||||
"""Write back to mutable args that aren't properly handled otherwise.
|
||||
|
||||
Because of how we pass values to the backend we don't currently support ops that mutate operands.
|
||||
That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
|
||||
as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
|
||||
we can for these special cases. Hence, the solution is to manually write back to the same operand that the
|
||||
conventional pytorch op variant would write to.
|
||||
|
||||
Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
|
||||
mutate running_mean and running_var). We don't currently handle those.
|
||||
"""
|
||||
|
||||
registered_op = get_registered_op(func)
|
||||
if not registered_op["is_mutable"]:
|
||||
return None
|
||||
|
||||
if len(registered_op["returns"]) > 1:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
"TorchMLIR doesn't handle multiple aliased returns yet."
|
||||
)
|
||||
|
||||
aliased_arg = next(
|
||||
arg
|
||||
for arg in registered_op["arguments"]
|
||||
if "alias_info" in arg and arg["alias_info"]["is_write"]
|
||||
)
|
||||
assert (
|
||||
"alias_info" in registered_op["returns"][0]
|
||||
and registered_op["returns"][0]["alias_info"]["is_write"]
|
||||
and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
|
||||
and registered_op["returns"][0]["alias_info"]["after"][0]
|
||||
)
|
||||
assert (
|
||||
len(aliased_arg["alias_info"]["after"]) == 1
|
||||
and aliased_arg["alias_info"]["after"][0]
|
||||
== registered_op["returns"][0]["alias_info"]["after"][0]
|
||||
)
|
||||
|
||||
return aliased_arg["name"] if aliased_arg["name"] != "self" else "input"
|
|
@ -1,102 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Tuple, Callable, List, Dict, Any
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir._mlir_libs._mlir.ir import Module
|
||||
|
||||
# TODO: This might need to be an ABC too, such as
|
||||
# to support finding the backend that created the tensor.
|
||||
DeviceTensor = TypeVar("DeviceTensor")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorMetaData:
|
||||
"""A small container for metadata necessary for satisfying the pytorch dispatcher and other code (pytorch or
|
||||
otherwise) that branches on these attributes.
|
||||
|
||||
There is a lot of code in the PyTorch codebase that branches based on these attributes; the obvious ones here
|
||||
are dtype, device, and requires_grad (necessary for autograd itself). There is ample warning from PyTorch that,
|
||||
in principle, these should be as close as possible to true; see
|
||||
https://github.com/albanD/subclass_zoo/blob/1566e038f03cd89ab3cc37e670a44e3c2bbc1897/trivial_tensors.py#L90-L92
|
||||
|
||||
The defaults (properties) simplify the api and seem to work after some testing but
|
||||
might malfunction in unexpected ways.
|
||||
# TODO: revisit these assumptions
|
||||
"""
|
||||
|
||||
size: Tuple[int]
|
||||
dtype: torch.dtype
|
||||
requires_grad: bool
|
||||
|
||||
strides: Tuple[int]
|
||||
storage_offset: int = 0
|
||||
layout: torch.layout = torch.strided
|
||||
device: torch.device = torch.device("cpu")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
dtype,
|
||||
requires_grad,
|
||||
strides=None,
|
||||
storage_offset=None,
|
||||
layout=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
object.__setattr__(self, "size", size)
|
||||
object.__setattr__(self, "dtype", dtype)
|
||||
object.__setattr__(self, "requires_grad", requires_grad)
|
||||
|
||||
object.__setattr__(
|
||||
self, "strides", strides if strides is not None else len(size) * [0]
|
||||
)
|
||||
object.__setattr__(
|
||||
self, "storage_offset", storage_offset if storage_offset is not None else 0
|
||||
)
|
||||
object.__setattr__(
|
||||
self, "layout", layout if layout is not None else torch.strided
|
||||
)
|
||||
object.__setattr__(
|
||||
self, "device", device if device is not None else torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
class TorchMLIREagerBackend(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def compile(
|
||||
self, module: Module
|
||||
) -> Callable[[List[DeviceTensor]], List[DeviceTensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> DeviceTensor:
|
||||
"""Unwrap the backend representation in order to build a torch.Tensor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_torch_metadata(
|
||||
self, tensor: DeviceTensor, kwargs: Dict[str, Any]
|
||||
) -> TensorMetaData:
|
||||
"""Parse relevant tensor metadata from backend device array (e.g., shape, stride, layout) in order to build
|
||||
wrapper tensor."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def transfer_from_device_to_torch(self, tensor: DeviceTensor) -> torch.Tensor:
|
||||
"""If compilation fails for some reason then device specific representations need to be munged into a
|
||||
torch.Tensor representation.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def copy_into(self, dst: DeviceTensor, src: DeviceTensor):
|
||||
"""This method is needed for things like handling aliased arguments."""
|
||||
raise NotImplementedError
|
|
@ -1,257 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
import contextlib
|
||||
import re
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from torch_mlir.eager_mode.ir_building import build_mlir_module
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import (
|
||||
UnsupportedByTorchMlirEagerMode,
|
||||
normalize_args_kwargs,
|
||||
check_get_aliased_arg,
|
||||
)
|
||||
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
|
||||
from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_dispatch():
|
||||
"""Prevent infinite recursion in case accidentally calling a tensor method on a TorchMLIRTensor within
|
||||
__torch_dispatch__."""
|
||||
|
||||
guard = torch._C._DisableTorchDispatch()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
|
||||
|
||||
backend = EagerModeRefBackend()
|
||||
|
||||
UNSUPPORTED_OPS = re.compile(
|
||||
"|".join([
|
||||
# We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
|
||||
"detach",
|
||||
# We don't handle _local_scalar_dense because it's just a way to unwrap a tensor that wraps a number.
|
||||
"_local_scalar_dense",
|
||||
# https://github.com/llvm/torch-mlir/issues/878
|
||||
"_unsafe_view",
|
||||
"view",
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
class TorchMLIRTensor(torch.Tensor):
|
||||
"""This class serves the role abstract class with common functionality for dispatching through Torch-MLIR instead of aten.
|
||||
|
||||
It defers device specific behavior to device specific implementations. The deriving classes use the
|
||||
make_bare_wrapper_subclass convenience method, adjacent here, and override __torch_dispatch__ in order to dispatch
|
||||
through Torch-MLIR instead of aten. Backends are free to choose whatever representation of the buffers (i.e., `elem`)
|
||||
and are expected to provide conversion mechanisms between their representation and torch.Tensor.
|
||||
|
||||
Here we only verify that inputs abide by current supported features of Torch-MLIR (contiguous memory and
|
||||
strided tensor layout) and build the mlir module. Importantly, we also recover from any malfunctions in the
|
||||
deriving classes and dispatch back to conventional PyTorch.
|
||||
|
||||
More documentation on how the __torch_dispatch__ pattern works can be found in this forum post
|
||||
https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
|
||||
and this RFC
|
||||
https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
|
||||
and this repo with many examples
|
||||
https://github.com/albanD/subclass_zoo
|
||||
"""
|
||||
|
||||
elem: Any
|
||||
|
||||
__slots__ = ["elem"]
|
||||
|
||||
def __new__(cls, elem, **kwargs):
|
||||
"""Wrap elem (which could be a torch.Tensor or otherwise) in a torch.Tensor subclass.
|
||||
|
||||
Critically, this method needs to parse relevant metadata from the device representation
|
||||
(such as shape, striding, dtype, etc.) and translate it into torch conventions.
|
||||
|
||||
Deriving classes must provide a way to construct themselves from either their device specific representation
|
||||
or torch.Tensor; the latter is to handle the case that dispatch to PyTorch to recover from an error.
|
||||
"""
|
||||
if kwargs.get("constructing_from_device_tensor", False):
|
||||
tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
|
||||
r = make_bare_wrapper_subclass(
|
||||
cls=cls,
|
||||
size=tensor_meta_data.size,
|
||||
strides=tensor_meta_data.strides,
|
||||
storage_offset=tensor_meta_data.storage_offset,
|
||||
dtype=tensor_meta_data.dtype,
|
||||
layout=tensor_meta_data.layout,
|
||||
device=tensor_meta_data.device,
|
||||
requires_grad=tensor_meta_data.requires_grad,
|
||||
)
|
||||
r.elem = elem
|
||||
elif isinstance(elem, torch.nn.Parameter):
|
||||
r = make_wrapper_subclass_from_torch_tensor(cls, elem.data, **kwargs)
|
||||
r.elem = backend.transfer_from_torch_to_device(elem.detach().data)
|
||||
elif isinstance(elem, torch.Tensor):
|
||||
r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
|
||||
r.elem = backend.transfer_from_torch_to_device(elem)
|
||||
# This branch handles the case when a python scalar is passed to some op
|
||||
# or is returned from some aten op, such as _local_scalar_dense.
|
||||
elif isinstance(elem, (int, float, bool)):
|
||||
return elem
|
||||
else:
|
||||
raise ValueError(f"Unknown element type: {type(elem)}")
|
||||
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
|
||||
else:
|
||||
return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
|
||||
requires_grad = check_requires_grad(*args, **kwargs)
|
||||
try:
|
||||
with no_dispatch():
|
||||
if hasattr(func, "op_name"):
|
||||
op_name = func.op_name
|
||||
elif hasattr(func, "__name__"):
|
||||
# Handle builtin_function_or_method.
|
||||
op_name = func.__name__
|
||||
else:
|
||||
raise RuntimeError(f"op {func} has no name")
|
||||
|
||||
requires_grad = requires_grad and "view" not in op_name
|
||||
|
||||
if UNSUPPORTED_OPS.match(op_name):
|
||||
raise UnsupportedByTorchMlirEagerMode(op_name)
|
||||
|
||||
if not hasattr(func, "_schema"):
|
||||
raise RuntimeError(f"op {func} has no schema.")
|
||||
|
||||
normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
|
||||
|
||||
if "layout" in normalized_kwargs and normalized_kwargs[
|
||||
"layout"
|
||||
] not in {0, None}:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
f"{normalized_kwargs['layout']} layout not supported."
|
||||
)
|
||||
if "memory_format" in normalized_kwargs and normalized_kwargs[
|
||||
"memory_format"
|
||||
] not in {0, None}:
|
||||
raise UnsupportedByTorchMlirEagerMode(
|
||||
f"{normalized_kwargs['memory_format']} memory format not supported."
|
||||
)
|
||||
eager_module = build_mlir_module(func, normalized_kwargs)
|
||||
device_tensor_args = [
|
||||
kwarg.elem
|
||||
for _, kwarg in normalized_kwargs.items()
|
||||
if isinstance(kwarg, cls)
|
||||
]
|
||||
assert len(eager_module.body.operations[0].arguments) == len(
|
||||
device_tensor_args
|
||||
), "Number of parameters and number of arguments differs."
|
||||
op_mlir_backend_callable = backend.compile(eager_module)
|
||||
out = op_mlir_backend_callable(*device_tensor_args)
|
||||
out = tree_map(
|
||||
lambda x: cls(
|
||||
x, requires_grad=requires_grad, constructing_from_device_tensor=True
|
||||
),
|
||||
out,
|
||||
)
|
||||
except Exception as e:
|
||||
if EAGER_MODE_DEBUG:
|
||||
warnings.warn(traceback.format_exc())
|
||||
if isinstance(e, UnsupportedByTorchMlirEagerMode):
|
||||
warnings.warn(
|
||||
f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
|
||||
f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
|
||||
)
|
||||
|
||||
with no_dispatch():
|
||||
unwrapped_args = tree_map(cls.unwrap, args)
|
||||
unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
|
||||
out = func(*unwrapped_args, **unwrapped_kwargs)
|
||||
|
||||
out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
|
||||
|
||||
maybe_aliased_arg_name = check_get_aliased_arg(func)
|
||||
if maybe_aliased_arg_name is not None:
|
||||
backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)
|
||||
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def unwrap(cls, e):
|
||||
"""Unwrap the TorchMLIRTensor representation in order to access the actual device specific representation."""
|
||||
if isinstance(e, cls):
|
||||
return backend.transfer_from_device_to_torch(e.elem)
|
||||
return e
|
||||
|
||||
|
||||
def check_requires_grad(*args, **kwargs):
|
||||
requires_grad = False
|
||||
|
||||
def check_grad(e):
|
||||
nonlocal requires_grad
|
||||
if isinstance(e, TorchMLIRTensor):
|
||||
requires_grad |= e.requires_grad
|
||||
|
||||
tree_map(check_grad, args)
|
||||
tree_map(check_grad, kwargs)
|
||||
|
||||
return requires_grad
|
||||
|
||||
|
||||
def make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs):
|
||||
"""Convenience method that parse out relevant metadata from a torch.Tensor, in order to produce
|
||||
a wrapper subclass.
|
||||
|
||||
NB: this convenience method does not set that `elem` attribute of the subclass, as that is the responsibility
|
||||
of the device specific implementation.
|
||||
"""
|
||||
r = make_bare_wrapper_subclass(
|
||||
cls=cls,
|
||||
size=elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
# Only float tensors can have gradients.
|
||||
requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
|
||||
and (kwargs.get("requires_grad", False) or elem.requires_grad),
|
||||
)
|
||||
return r
|
||||
|
||||
|
||||
def make_bare_wrapper_subclass(
|
||||
*, cls, size, strides, storage_offset, dtype, layout, device, requires_grad
|
||||
):
|
||||
"""Convenience method that builds a wrapper subclass.
|
||||
|
||||
NB: this convenience method does not set that `elem` attribute of the subclass, as that is the responsibility
|
||||
of the device specific implementation.
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
size,
|
||||
strides=strides,
|
||||
storage_offset=storage_offset,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
requires_grad=requires_grad,
|
||||
)
|
|
@@ -9,5 +9,4 @@ from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
from .mhlo_backend import MhloBackendTestConfig
from .tosa_backend import TosaBackendTestConfig
from .eager_mode import EagerModeTestConfig
from .torchdynamo import TorchDynamoTestConfig
@ -1,65 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
def wrap(e):
|
||||
return TorchMLIRTensor(e.detach().clone()) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
|
||||
def unwrap(e):
|
||||
return TorchMLIRTensor.unwrap(e) if isinstance(e, TorchMLIRTensor) else e
|
||||
|
||||
|
||||
def to_tmt(m: torch.nn.Module):
|
||||
for buf_name, buf in m.named_buffers(recurse=True):
|
||||
if isinstance(buf, TorchMLIRTensor):
|
||||
continue
|
||||
m.register_buffer(buf_name, TorchMLIRTensor(buf))
|
||||
for param_name, param in m.named_parameters(recurse=True):
|
||||
if isinstance(param, TorchMLIRTensor):
|
||||
continue
|
||||
m.register_parameter(
|
||||
param_name,
|
||||
torch.nn.Parameter(
|
||||
TorchMLIRTensor(param), requires_grad=param.requires_grad
|
||||
),
|
||||
)
|
||||
for attr in dir(m):
|
||||
field = getattr(m, attr)
|
||||
if isinstance(field, torch.Tensor) and not isinstance(field, TorchMLIRTensor):
|
||||
setattr(m, attr, TorchMLIRTensor(field))
|
||||
|
||||
|
||||
class EagerModeTestConfig(TestConfig):
|
||||
"""Trivial test config that exercises eager mode plumbing"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||
program.apply(to_tmt)
|
||||
return program
|
||||
|
||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
attr = artifact
|
||||
for part in item.symbol.split("."):
|
||||
attr = getattr(attr, part)
|
||||
|
||||
inps = tree_map(wrap, item.inputs)
|
||||
outps = attr(*inps)
|
||||
output = tree_map(unwrap, outps)
|
||||
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
|
||||
)
|
||||
return result
|
|
@ -1,90 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch_mlir.compiler_utils import (
|
||||
get_module_name_for_debug_dump,
|
||||
run_pipeline_with_repro_report,
|
||||
)
|
||||
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
|
||||
TorchMLIREagerBackend,
|
||||
TensorMetaData,
|
||||
)
|
||||
from torch_mlir.ir import Module
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
|
||||
RefBackendLinalgOnTensorsBackend,
|
||||
)
|
||||
|
||||
NUMPY_TO_TORCH_DTYPE_DICT = {
|
||||
np.bool_: torch.bool,
|
||||
np.uint8: torch.uint8,
|
||||
np.int8: torch.int8,
|
||||
np.int16: torch.int16,
|
||||
np.int32: torch.int32,
|
||||
np.int64: torch.int64,
|
||||
np.float16: torch.float16,
|
||||
np.float32: torch.float32,
|
||||
np.float64: torch.float64,
|
||||
np.complex64: torch.complex64,
|
||||
np.complex128: torch.complex128,
|
||||
}
|
||||
|
||||
_ref_backend = RefBackendLinalgOnTensorsBackend()
|
||||
|
||||
|
||||
class EagerModeRefBackend(TorchMLIREagerBackend):
|
||||
"""Main entry-point for the reference backend for eager mode.
|
||||
|
||||
RefBackend uses numpy.ndarray representations of tensors and thus all of the wrapping and unwrapping
|
||||
and munging here is done to between torch.Tensor and numpy.ndarray.
|
||||
"""
|
||||
|
||||
module_to_refbackend_invoker = {}
|
||||
|
||||
def get_torch_metadata(
|
||||
self, tensor: np.ndarray, kwargs: Dict[str, Any]
|
||||
) -> TensorMetaData:
|
||||
return TensorMetaData(
|
||||
size=tensor.shape,
|
||||
dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
|
||||
requires_grad=tensor.dtype in {np.float, np.float32, np.float64}
|
||||
and kwargs.get("requires_grad", False),
|
||||
)
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Lower the imported TS module to linalg and then further compile for the reference backend and then call."""
|
||||
fn_name = get_module_name_for_debug_dump(imported_module)
|
||||
module_hash = str(imported_module)
|
||||
if module_hash not in self.module_to_refbackend_invoker:
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"builtin.module(torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"EagerMode",
|
||||
)
|
||||
self.module_to_refbackend_invoker[module_hash] = _ref_backend.load(
|
||||
_ref_backend.compile(imported_module)
|
||||
)
|
||||
|
||||
ref_backend_invoker = self.module_to_refbackend_invoker[module_hash]
|
||||
op_mlir_backend_callable = getattr(ref_backend_invoker, fn_name)
|
||||
assert (
|
||||
op_mlir_backend_callable is not None
|
||||
), f"Couldn't find function in module."
|
||||
return op_mlir_backend_callable
|
||||
|
||||
def copy_into(self, dst: np.ndarray, src: np.ndarray):
|
||||
np.copyto(dst, src)
|
||||
|
||||
def transfer_from_device_to_torch(self, e: np.ndarray):
|
||||
return torch.from_numpy(e).clone()
|
||||
|
||||
def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> np.ndarray:
|
||||
return tensor.detach().numpy()
|