mirror of https://github.com/llvm/torch-mlir
add ci tests (#754)
parent 24e04d5729
commit cec5aeedb0
@@ -44,6 +44,11 @@ jobs:
           cd $GITHUB_WORKSPACE
           export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
           python -m e2e_testing.torchscript.main --config=refbackend -v
+      - name: EagerMode - TorchScript end-to-end tests
+        run: |
+          cd $GITHUB_WORKSPACE
+          export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
+          python -m e2e_testing.torchscript.main --config=eager_mode -v
       - name: TOSA backend - TorchScript end-to-end tests
         run: |
           cd $GITHUB_WORKSPACE
@@ -14,12 +14,11 @@ from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS

 REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS

-EAGER_MODE_XFAIL_SET = REFBACKEND_XFAIL_SET.union({
-    # These fail because of an upstream pytorch bug; more information at the following issue
-    # https://github.com/pytorch/pytorch/issues/74400
-    "ElementwiseMulScalarModule_basic",
-    "ElementwiseSubScalarIntModule_basic",
-})
+EAGER_MODE_XFAIL_SET = {
+    # RefBackend fails
+    "TableBatchEmbeddingModule_basic",
+    "QuantizedMLP_basic"
+}

 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
@@ -1,99 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-# RUN: %PYTHON %s | FileCheck %s
-
-
-import torch
-
-from framework import run_test
-from torch_mlir.eager_mode.torch_mlir_dispatch import (
-    annotate_args_kwargs,
-    normalize_args_kwargs,
-    build_script_function,
-)
-
-
-# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
-# -----
-# CHECK: PASS - simple
-@run_test
-def simple():
-    target = torch.ops.aten.addmm.default
-    A = torch.randn(1, 3, 32, 32)
-    B = torch.randn(1, 3, 32, 32)
-    C = torch.randn(1, 3, 32, 32)
-    args = (A, B, C)
-    kwargs = dict(beta=1, alpha=1)
-
-    new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
-    script_fun = build_script_function(target._schema, new_args, new_kwargs)
-    annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
-    for annot in annotations:
-        print(annot)
-
-
-# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
-# -----
-# CHECK: PASS - handle_zero_dim
-@run_test
-def handle_zero_dim():
-    target = torch.ops.aten.addmm.default
-    A = torch.randn(0, 3, 32, 32)
-    B = torch.randn(0, 3, 32, 32)
-    C = torch.randn(0, 3, 32, 32)
-    args = (A, B, C)
-    kwargs = dict(beta=1, alpha=1)
-
-    new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
-    script_fun = build_script_function(target._schema, new_args, new_kwargs)
-    annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
-    for annot in annotations:
-        print(annot)
-
-
-# CHECK: Torch Tensor (shape=(2, 5, 2, 3), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(2, 5, 2, 3), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
-# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
-# -----
-# CHECK: PASS - correctly_order_kwargs
-@run_test
-def correctly_order_kwargs():
-    target = torch.ops.aten.native_batch_norm.out
-
-    input = torch.randn(2, 5, 2, 3)
-    weight = torch.randn(5)
-    bias = torch.randn(5)
-    running_mean = torch.randn(5)
-    running_var = torch.randn(5)
-    args = (input, weight, bias, running_mean, running_var)
-
-    out = torch.empty_like(input)
-    save_mean = torch.empty_like(running_mean)
-    save_invstd = torch.empty_like(running_var)
-
-    kwargs = dict(
-        training=False,
-        momentum=0.1,
-        eps=0.0001,
-        out=out,
-        save_mean=save_mean,
-        save_invstd=save_invstd,
-    )
-
-    new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
-    script_fun = build_script_function(target._schema, new_args, new_kwargs)
-    annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
-    for annot in annotations:
-        print(annot)
@@ -9,7 +9,7 @@
 import torch

 from framework import run_test
-from torch_mlir.eager_mode.torch_mlir_dispatch import build_script_function
+from torch_mlir.eager_mode.ir_building import build_ts_script_function


 # CHECK: graph(%[[A1:.*]] : Tensor,
@@ -24,13 +24,15 @@ from torch_mlir.eager_mode.torch_mlir_dispatch import build_script_function
 @run_test
 def simple():
     target = torch.ops.aten.addmm.default
-    A = torch.randn(1, 3, 32, 32)
-    B = torch.randn(1, 3, 32, 32)
-    C = torch.randn(1, 3, 32, 32)
-    args = (A, B, C)
-    kwargs = dict(beta=1, alpha=1)
+    kwargs = dict(
+        input=torch.randn(1, 3, 32, 32),
+        mat1=torch.randn(1, 3, 32, 32),
+        mat2=torch.randn(1, 3, 32, 32),
+        beta=1,
+        alpha=1,
+    )

-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -50,11 +52,10 @@ def simple():
 @run_test
 def handle_optional_tensor_input():
     target = torch.ops.aten.convolution.default
-    input = torch.randn(1, 3, 32, 32)
-    weight = torch.randn(3, 3, 3, 3)
-    bias = torch.randn(3)
-    args = (input, weight, bias)
     kwargs = dict(
+        input=torch.randn(1, 3, 32, 32),
+        weight=torch.randn(3, 3, 3, 3),
+        bias=torch.randn(3),
         stride=[1, 1],
         padding=[0, 0],
         dilation=[1, 1],
@@ -62,20 +63,19 @@ def handle_optional_tensor_input():
         output_padding=[0, 0],
         groups=1,
     )
-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)


 # CHECK: FAIL - fail_not_enough_args
-# CHECK: Errors: tuple index out of range
+# CHECK: Errors: 'groups'
 @run_test
 def fail_not_enough_args():
     target = torch.ops.aten.convolution.default
-    input = torch.randn(1, 3, 32, 32)
-    weight = torch.randn(3, 3, 3, 3)
-    bias = torch.randn(3)
-    args = (input, weight, bias)
     kwargs = dict(
+        input=torch.randn(1, 3, 32, 32),
+        weight=torch.randn(3, 3, 3, 3),
+        bias=torch.randn(3),
         stride=[1, 1],
         padding=[0, 0],
         dilation=[1, 1],
@@ -83,41 +83,41 @@ def fail_not_enough_args():
         output_padding=[0, 0],
         # Missing groups=1,
     )
-    build_script_function(target._schema, args, kwargs)
+    build_ts_script_function(target._schema, kwargs)


-# CHECK: PASS - simple_args_or_kwargs
 # CHECK: graph(%input : Tensor,
 # CHECK: %weight : Tensor,
 # CHECK: %bias : Tensor):
 # CHECK: %4 : int[] = prim::Constant[value=[1, 1]]()
 # CHECK: %5 : int[] = prim::Constant[value=[0, 0]]()
 # CHECK: %6 : int[] = prim::Constant[value=[1, 1]]()
 # CHECK: %7 : bool = prim::Constant[value=0]()
 # CHECK: %8 : int[] = prim::Constant[value=[0, 0]]()
 # CHECK: %9 : int = prim::Constant[value=1]()
 # CHECK: %0 : Tensor = aten::convolution(%input, %weight, %bias, %4, %5, %6, %7, %8, %9)
 # CHECK: return (%0)
 # -----
+# CHECK: PASS - simple_kwargs
 @run_test
-def simple_args_or_kwargs():
+def simple_kwargs():
     target = torch.ops.aten.convolution.default
-    input = torch.randn(1, 3, 32, 32)
-    weight = torch.randn(3, 3, 3, 3)
-    bias = torch.randn(3)
-    stride = [1, 1]
-    padding = [0, 0]
-    dilation = [1, 1]
-    transposed = False
-    output_padding = [0, 0]
-    groups = 1
-    script_fun1 = build_script_function(
+    script_fun1 = build_ts_script_function(
         target._schema,
-        (input, weight, bias),
         dict(
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            transposed=transposed,
-            output_padding=output_padding,
-            groups=groups,
+            input=torch.randn(1, 3, 32, 32),
+            weight=torch.randn(3, 3, 3, 3),
+            bias=torch.randn(3),
+            stride=[1, 1],
+            padding=[0, 0],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=1,
         ),
     )

-    script_fun2 = build_script_function(
-        target._schema,
-        (input, weight, bias, stride, padding, dilation),
-        dict(transposed=transposed, output_padding=output_padding, groups=groups),
-    )
-    assert str(script_fun1.graph) == str(script_fun2.graph)
     print(script_fun1.graph)


 # CHECK: graph(%[[C2:.*]] : Tensor):
@@ -134,15 +134,16 @@ def simple_args_or_kwargs():
 def handle_empty_lists():
     target = torch.ops.aten.max_pool2d_with_indices.default
     # print(target._schema)
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {
-        "kernel_size": [3, 3],
-        "stride": [],
-        "padding": [0, 0],
-        "dilation": [1, 1],
-        "ceil_mode": False,
-    }
-    script_fun = build_script_function(target._schema, args, kwargs)
+    input = torch.randn((1, 3, 32, 32))
+    kwargs = dict(
+        input=input,
+        kernel_size=[3, 3],
+        stride=[],
+        padding=[0, 0],
+        dilation=[1, 1],
+        ceil_mode=False,
+    )
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -160,15 +161,15 @@ def handle_empty_lists():
 def handle_nones():
     target = torch.ops.aten.max_pool2d_with_indices.default
     # print(target._schema)
-    args = (torch.randn((1, 3, 32, 32)),)
-    kwargs = {
-        "kernel_size": [3, 3],
-        "stride": None,
-        "padding": [0, 0],
-        "dilation": [1, 1],
-        "ceil_mode": False,
-    }
-    script_fun = build_script_function(target._schema, args, kwargs)
+    kwargs = dict(
+        input=torch.randn((1, 3, 32, 32)),
+        kernel_size=[3, 3],
+        stride=None,
+        padding=[0, 0],
+        dilation=[1, 1],
+        ceil_mode=False,
+    )
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -188,11 +189,10 @@ def handle_nones():
 @run_test
 def handle_optional_tensors():
     target = torch.ops.aten.convolution.default
-    input = torch.randn(1, 3, 32, 32)
-    weight = torch.randn(3, 3, 3, 3)
-    bias = torch.randn(3)
-    args = (input, weight, bias)
     kwargs = dict(
+        input=torch.randn(1, 3, 32, 32),
+        weight=torch.randn(3, 3, 3, 3),
+        bias=torch.randn(3),
         stride=[1, 1],
         padding=[0, 0],
         dilation=[1, 1],
@@ -200,7 +200,7 @@ def handle_optional_tensors():
         output_padding=[0, 0],
         groups=1,
     )
-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -217,12 +217,15 @@ def handle_optional_tensors():
 @run_test
 def handle_ones_like():
     target = torch.ops.aten.ones_like.default
-    input = torch.randn(1, 3, 32, 32)
-    args = (input,)
     kwargs = dict(
-        dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
+        input=torch.randn(1, 3, 32, 32),
+        dtype=None,
+        layout=None,
+        device=None,
+        pin_memory=None,
+        memory_format=None,
     )
-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -241,13 +244,18 @@ def handle_ones_like():
 @run_test
 def handle_multiple_outputs():
     target = torch.ops.aten.native_batch_norm.default
-    A = torch.randn(1, 3, 32, 32)
-    B = torch.randn(1, 3, 32, 32)
-    C = torch.randn(1, 3, 32, 32)
-    args = (A, B, C, None, None, False, 1.0, 1.0)
-    kwargs = dict()
+    kwargs = dict(
+        input=torch.randn(1, 3, 32, 32),
+        weight=torch.randn(1, 3, 32, 32),
+        bias=torch.randn(1, 3, 32, 32),
+        running_mean=None,
+        running_var=None,
+        training=False,
+        momentum=1.0,
+        eps=1.0
+    )

-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -256,13 +264,18 @@ def handle_multiple_outputs():
 @run_test
 def check_legal_name():
     target = torch.ops.aten.native_batch_norm.default
-    A = torch.randn(1, 3, 32, 32)
-    B = torch.randn(1, 3, 32, 32)
-    C = torch.randn(1, 3, 32, 32)
-    args = (A, B, C, None, None, False, 1.0, 1.0)
-    kwargs = dict()
+    kwargs = dict(
+        input=torch.randn(1, 3, 32, 32),
+        weight=torch.randn(1, 3, 32, 32),
+        bias=torch.randn(1, 3, 32, 32),
+        running_mean=None,
+        running_var=None,
+        training=False,
+        momentum=1.0,
+        eps=1.0
+    )

-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.name)
@@ -286,24 +299,22 @@ def correctly_order_kwargs():
     target = torch.ops.aten.native_batch_norm.out

     input = torch.randn(2, 5, 2, 3)
-    weight = torch.randn(5)
-    bias = torch.randn(5)
     running_mean = torch.randn(5)
     running_var = torch.randn(5)
-    args = (input, weight, bias, running_mean, running_var)

-    out = torch.empty_like(input)
-    save_mean = torch.empty_like(running_mean)
-    save_invstd = torch.empty_like(running_var)
-
     kwargs = dict(
+        input=torch.randn(2, 5, 2, 3),
+        weight=torch.randn(5),
+        bias=torch.randn(5),
+        running_mean=running_mean,
+        running_var=running_var,
         training=False,
         momentum=0.1,
         eps=0.0001,
-        out=out,
-        save_mean=save_mean,
-        save_invstd=save_invstd,
+        out=torch.empty_like(input),
+        save_mean=torch.empty_like(running_mean),
+        save_invstd=torch.empty_like(running_var),
     )

-    script_fun = build_script_function(target._schema, args, kwargs)
+    script_fun = build_ts_script_function(target._schema, kwargs)
     print(script_fun.graph)
@@ -15,8 +15,8 @@ from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
 # CHECK: PASS - should_normalize
 @run_test
 def should_normalize():
-    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
-    args = (torch.randn((1, 3, 32, 32)),)
+    target = torch.ops.aten.max_pool2d_with_indices.default
+    input = torch.randn((1, 3, 32, 32))
     kwargs = {"kernel_size": [3, 3]}
     golden = {
         "kernel_size": [3, 3],
@@ -28,18 +28,18 @@ def should_normalize():
         "ceil_mode": False,
     }

-    new_args, new_kwargs = normalize_args_kwargs(target, args, kwargs)
-    for arg, new_arg in zip(args, new_args):
-        assert torch.allclose(arg, new_arg)
+    new_kwargs = normalize_args_kwargs(target, (input,), kwargs)
+    assert torch.allclose(new_kwargs["input"], input)
     for k, v in new_kwargs.items():
+        if k == "input": continue
         assert v == golden[k]


 # CHECK: FAIL - shouldnt_normalize1
 # CHECK: Couldn't normalize args and kwargs
 # CHECK: Errors: missing a required argument: 'kernel_size'
 @run_test
 def shouldnt_normalize1():
-    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    target = torch.ops.aten.max_pool2d_with_indices.default
     args = (torch.randn((1, 3, 32, 32)),)
     kwargs = {"stride": []}
     normalize_args_kwargs(target, args, kwargs)
@@ -54,7 +54,7 @@ def shouldnt_normalize1():
 # CHECK: XPASS - shouldnt_normalize2
 @run_test(XPASS=True)
 def shouldnt_normalize2():
-    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    target = torch.ops.aten.max_pool2d_with_indices.default
     args = (torch.randn((1, 3, 32, 32)),)
     kwargs = {"kernel_size": []}
     normalize_args_kwargs(target, args, kwargs)
@@ -63,7 +63,7 @@ def shouldnt_normalize2():
 # CHECK: XPASS - shouldnt_normalize3
 @run_test(XPASS=True)
 def shouldnt_normalize3():
-    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    target = torch.ops.aten.max_pool2d_with_indices.default
     args = (torch.randn((1, 3, 32, 32)),)
     kwargs = {"kernel_size": [3, 3], "padding": None}
     normalize_args_kwargs(target, args, kwargs)
@@ -0,0 +1,3 @@
+import os
+
+EAGER_MODE_DEBUG = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')
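The new EAGER_MODE_DEBUG flag above uses a common boolean env-var idiom. A minimal sketch of how the parsing behaves (the value set here is illustrative):

import os

# Any of "true", "1", "t" (case-insensitive) enables debug; everything else disables it.
os.environ["EAGER_MODE_DEBUG"] = "1"
debug = os.environ.get("EAGER_MODE_DEBUG", 'False').lower() in ('true', '1', 't')
assert debug is True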
@@ -22,11 +22,15 @@ to convert the TorchScript function into MLIR using the `torch` dialect.
 """

 import abc
-from typing import Any, Optional, Iterable
+import re
+from typing import Any, Optional, Iterable, Dict
 from typing import Union

+import numpy as np
 import torch
-from torch.jit import ScriptFunction
+import torch._C
+import torch.jit
+from torch._ops import OpOverload

 from torch_mlir import ir
 from torch_mlir.dialects.func import FuncOp
@@ -202,26 +206,148 @@ def get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
     return None


-def build_module(jit_function: ScriptFunction, annotations) -> ir.Module:
+def is_tensor_type(typ: torch._C.Type):
+    return typ.isSubtypeOf(torch.TensorType.get()) or (
+        isinstance(typ, torch.OptionalType)
+        and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
+    )
+
+
+def is_list_of_tensors_type(typ: torch._C.Type):
+    return isinstance(typ, torch.ListType) and is_tensor_type(typ.getElementType())
+
+
+name_mangle_regex = re.compile("[^a-zA-Z0-9]")
+
+
+def build_ts_script_function(
+    schema: torch._C.FunctionSchema, kwargs: Dict[str, Any]
+) -> torch.jit.ScriptFunction:
+    """Build a torch.jit.ScriptFunction that corresponds to the schema.
+
+    Constants are inlined for the purposes of invalidating the compile cache when they change.
+
+    Parameters
+    ----------
+    schema: torch._C.FunctionSchema
+        PyTorch's representation for ops; it contains the type information needed for inlining constants into the TS graph.
+    kwargs: Dict
+        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).
+
+    Returns
+    -------
+    torch.jit.ScriptFunction
+        Fully specialized (all constants) TS graph whose only arguments are tensors.
+    """
+
+    # Creates empty TS graph.
+    graph = torch._C.Graph()
+    # Creates and inserts node with identifier `schema.name`; NB the node has no inputs or outputs at this point.
+    node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
+    # Associate graph inputs/outputs with node inputs/outputs.
+    graph_inputs = []
+    for arg in schema.arguments:
+        arg_name = arg.name if arg.name != "self" else "input"
+
+        # If arg is a flattened list of tensors, such as in the case of torch.cat,
+        # then add each element of the list to the graph corresponding to arg
+        # and insert a ListConstruct to function as input to the op.
+        if is_list_of_tensors_type(arg.type):
+            inps = []
+            for kwarg in [
+                kwarg for kwarg in kwargs if f"{arg_name}_flattened" in kwarg
+            ]:
+                inp = graph.addInput()
+                el_typ = arg.type.getElementType()
+                if isinstance(el_typ, torch.OptionalType):
+                    el_typ = el_typ.getElementType()
+                inp.setType(el_typ)
+                inp.setDebugName(kwarg)
+                inps.append(inp)
+                graph_inputs.append(kwarg)
+            list_cons = graph.insertNode(graph.create("prim::ListConstruct", inps))
+            list_cons.moveBefore(node)
+            inp = list_cons.output()
+            inp.setType(torch.ListType.ofTensors())
+        # If arg is a tensor, then add an input to the graph corresponding to arg.
+        elif is_tensor_type(arg.type) and kwargs[arg_name] is not None:
+            inp = graph.addInput()
+            if isinstance(arg.type, torch.OptionalType):
+                el_typ = arg.type.getElementType()
+            else:
+                el_typ = arg.type
+            inp.setType(el_typ)
+            inp.setDebugName(arg_name)
+            graph_inputs.append(arg_name)
+        # If arg is a constant, inline it (at the top of the graph).
+        else:
+            val = kwargs[arg_name]
+            if val == []:
+                # Some ops have empty list default values for args
+                # (such as aten::max_pool2d_with_indices with int[2] stride=[])
+                # but graph.insertConstant doesn't recognize [] as an empty list IValue.
+                # This might be an upstream bug but there doesn't seem to be a way to
+                # build a prim::ListConstruct list that's empty.
+                val = None
+            inp = graph.insertConstant(val)
+            inp.node().moveBefore(node)
+
+        node.addInput(inp)
+
+    # Reorder graph inputs to match kwargs.
+    permutes = [
+        {inp: i for i, inp in enumerate(graph_inputs)}[kwarg]
+        for kwarg in [kwarg for kwarg in kwargs if kwarg in graph_inputs]
+    ]
+    graph.permuteInputs(permutes)
+
+    if node.hasMultipleOutputs():
+        for outp in node.outputs():
+            graph.registerOutput(outp)
+    else:
+        graph.registerOutput(node.output())
+
+    fn = torch._C._create_function_from_graph(
+        f"{name_mangle_regex.sub('', str(graph))}", graph
+    )
+    return fn
+
+
+def build_mlir_module(op: OpOverload, kwargs: Dict[str, Any]) -> ir.Module:
     """Translate input function into an MLIR module in the `torch` dialect.

     Parameters
     ----------
-    jit_function: ScriptFunction
-        Function in TorchScript IR to turn into MLIR.
-    annotation: Annotation
-        Annotation object representing the types of
-        the operands of `jit_function`.
+    op: OpOverload
+        Callable from the torch.ops.aten module/namespace that has a _schema field.
+    kwargs: Dict
+        A dictionary with all arguments passed in through __torch_dispatch__ (including int/float/bool params).

     Returns
     -------
     ir.Module
-        Translation of the input module into an MLIR module
+        Translation of the input module into an MLIR module.
     """
-    mb = ModuleBuilder()
-    mb.import_function(jit_function)
-
-    func_op = get_func_op_with_name(mb.module, jit_function.name)
+    # The assert here is to catch tensor shapes that have size 0 dimensions, such as those produced in
+    # the course of evaluating SliceEndSleStartModule_basic and SliceOutOfLowerBoundEndIndexModule_basic.
+    # Such 0 size dimensions fail the assert at mlir/lib/IR/BuiltinTypes.cpp, line 887.
+    annotations = []
+    for arg_name, arg in kwargs.items():
+        if isinstance(arg, torch.Tensor):
+            assert np.prod(arg.shape) != 0, f"{arg_name} has invalid shape {arg.shape}"
+            annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg.dtype))
+    annotations = tuple(annotations)
+
+    script_fun = build_ts_script_function(op._schema, kwargs)
+    assert len(annotations) == len(
+        list(script_fun.graph.inputs())
+    ), "Number of annotations and number of graph inputs differs."
+
+    mb = ModuleBuilder()
+    mb.import_function(script_fun)
+
+    func_op = get_func_op_with_name(mb.module, script_fun.name)
     assert (
         func_op is not None
     ), "Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder"
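For orientation, a minimal usage sketch of the build_ts_script_function added above, mirroring what the updated tests exercise (it assumes the torch_mlir.eager_mode package from this change is built and on PYTHONPATH):

import torch

from torch_mlir.eager_mode.ir_building import build_ts_script_function

op = torch.ops.aten.addmm.default
kwargs = dict(
    input=torch.randn(2, 4),
    mat1=torch.randn(2, 3),
    mat2=torch.randn(3, 4),
    beta=1,
    alpha=1,
)
# Tensor kwargs become graph inputs; beta/alpha are inlined as constants.
script_fun = build_ts_script_function(op._schema, kwargs)
print(script_fun.graph)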
@@ -2,21 +2,20 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 from __future__ import annotations

-from typing import Any, Callable, Tuple, Union
-from typing import List, Dict
+from typing import Any, Callable, Tuple
+from typing import Dict

-import numpy as np
 import torch
 import torch._C
-from torch.fx.node import map_aggregate
-from torch.fx.operator_schemas import normalize_function, create_type_hint
-from torch.utils._pytree import tree_map
-from torch_mlir._mlir_libs._mlir.passmanager import PassManager
-from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops
+from torch.fx import immutable_collections
+from torch.fx.operator_schemas import (
+    _torchscript_schema_to_signature,
+    _args_kwargs_to_normalized_args_kwargs,
+)

-from torch_mlir.dialects import torch as torch_dialect
-from torch_mlir.eager_mode.ir_building import build_module, TorchTensorType
+from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops  # pytype: disable=import-error

 OP_REGISTRY = {op["name"]: op for op in get_registered_ops()}
 SUPPORTED_OPS = frozenset(
@@ -37,153 +36,44 @@ class UnsupportedByTorchMlirEagerMode(Exception):
         return self.value


-def check_supported_op(schema: torch._C.FunctionSchema) -> bool:
-    return (
-        "torch."
-        + schema.name.replace("::", ".")
-        + ("." + schema.overload_name if schema.overload_name else "")
-    ) in SUPPORTED_OPS
-
-
-def is_tensor_type(typ: torch._C.Type):
-    return typ.isSubtypeOf(torch.TensorType.get()) or (
-        isinstance(typ, torch.OptionalType)
-        and typ.getElementType().isSubtypeOf(torch._C.TensorType.get())
-    )
-
-
 def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str, Any]):
     """Fill in default values for optional args, which are dependent on the schema."""
-
-    arg_types = map_aggregate(args, type)
-    assert isinstance(arg_types, tuple)
-    arg_types = map_aggregate(map_aggregate(args, type), create_type_hint)
-    kwarg_types = {
-        k: create_type_hint(map_aggregate(v, type)) for k, v in kwargs.items()
-    }
-
-    new_args_and_kwargs = normalize_function(
-        target.op, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
+    sig = _torchscript_schema_to_signature(target._schema)
+    _, new_kwargs = _args_kwargs_to_normalized_args_kwargs(
+        sig, args, kwargs, normalize_to_only_use_kwargs=True
     )
-    assert new_args_and_kwargs, "Couldn't normalize args and kwargs"
-    new_args, new_kwargs = new_args_and_kwargs
-    return new_args, new_kwargs
+    if "self" in new_kwargs:
+        new_kwargs["input"] = new_kwargs.pop("self")
+
+    # Flatten lists of args for ops that take lists, such as torch.cat.
+    to_remove = set()
+    to_add = {}
+    for k, v in new_kwargs.items():
+        if isinstance(v, (tuple, list)) and len(v) and isinstance(v[0], torch.Tensor):
+            to_remove.add(k)
+            for i, vv in enumerate(v):
+                to_add[f"{k}_flattened_{i}"] = vv
+
+    for rem in to_remove:
+        del new_kwargs[rem]
+    new_kwargs.update(**to_add)
+
+    # Sort here in order to have consistency across the TS graph and
+    # MLIR module.
+    sorted_kwargs = dict(sorted(new_kwargs.items()))
+    return immutable_collections.immutable_dict(sorted_kwargs)


-def build_script_function(
-    schema: torch._C.FunctionSchema,
-    args: List[torch._C.Argument],
-    kwargs: Dict[str, Any],
-) -> torch.jit.ScriptFunction:
-    """Build a torch.jit.ScriptFunction that corresponds to the schema.
-
-    Constants are inlined for the purposes of invalidating the compile cache when they change.
-    """
-
-    # Creates empty TS graph.
-    graph = torch._C.Graph()
-    # Creates and inserts node with identifier `schema.name`; NB the node has no inputs or outputs at this point.
-    node = graph.insertNode(graph.create(schema.name, len(schema.returns)))
-    # Associate graph inputs/outputs with node inputs/outputs.
-    for i, arg in enumerate(schema.arguments):
-        # Find value corresponding to schema arg, either in positional or kw args.
-        kwarg = False
-        if arg.name in kwargs:
-            val = kwargs[arg.name]
-            kwarg = True
-        else:
-            val = args[i]
-
-        # If arg is a tensor, then add input to the graph corresponding to arg.
-        if is_tensor_type(arg.type) and val is not None:
-            inp = graph.addInput()
-            if isinstance(arg.type, torch.OptionalType):
-                inp.setType(arg.type.getElementType())
-            else:
-                inp.setType(arg.type)
-
-            if kwarg:
-                # Rename for debugging aid.
-                inp.setDebugName(arg.name)
-        # If arg is a constant, inline (at the top of the graph).
-        else:
-            if val == []:
-                # Some ops have empty list default values for args
-                # (such as aten::max_pool2d_with_indices with int[2] stride=[])
-                # but graph.insertConstant doesn't recognize [] as an empty list IValue.
-                # This might be an upstream bug but there doesn't seem to be a way to
-                # build a prim::ListConstruct list that's empty.
-                val = None
-            inp = graph.insertConstant(val)
-            inp.node().moveBefore(node)
-
-        node.addInput(inp)
-
-    if node.hasMultipleOutputs():
-        for outp in node.outputs():
-            graph.registerOutput(outp)
-    else:
-        graph.registerOutput(node.output())
-
-    fn = torch._C._create_function_from_graph("f", graph)
-    return fn
+def get_registered_op(op):
+    registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
+    return registered_op


-def annotate_args_kwargs(
-    script_fun: torch._C.ScriptFunction,
-    normalized_args: List[Any],
-    normalized_kwargs: Dict[str, Any],
-):
-    unwrapped_normalized_args = tree_map(
-        lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
-        normalized_args,
-    )
-    unwrapped_normalized_kwargs = tree_map(
-        lambda x: x.detach().contiguous().numpy() if isinstance(x, torch.Tensor) else x,
-        normalized_kwargs,
-    )
-
-    annotations = []
-    tensor_args = []
-    for i, arg in enumerate(unwrapped_normalized_args):
-        if isinstance(arg, np.ndarray):
-            # TODO: Remove once size zero dimensions are handled by torch-mlir.
-            shape = tuple(map(lambda x: x or -1, arg.shape))
-            annotations.append(
-                TorchTensorType(shape=shape, dtype=normalized_args[i].dtype)
-            )
-            tensor_args.append(arg)
-
-    # Pull out tensor kwargs and put them in positional order.
-    tensor_kwargs_flat = []
-    if unwrapped_normalized_kwargs:
-        tensor_kwargs = {}
-        arg_idxs = {
-            arg_name: i
-            for i, arg_name in enumerate(
-                [arg.name for arg in script_fun.schema.arguments]
-            )
-        }
-        for i, (kw, arg) in enumerate(unwrapped_normalized_kwargs.items()):
-            if isinstance(arg, np.ndarray):
-                tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype)
-
-        for _i, (arg, arg_dtype) in sorted(tensor_kwargs.items()):
-            annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype))
-            tensor_kwargs_flat.append(arg)
-
-    return annotations, tensor_args, tensor_kwargs_flat
-
-
-def write_back_to_mutable(
-    registered_op: Dict,
-    out: Union[np.ndarray, List[np.ndarray]],
-    all_tensor_args: List[np.ndarray],
-):
+def check_get_aliased_arg(func: Callable,):
     """Write back to mutable args that aren't properly handled otherwise.

-    Because of how we pass values to the backend (by copying the tensor to a numpy array) we don't currently support
-    ops that mutate operands. That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
+    Because of how we pass values to the backend we don't currently support ops that mutate operands.
+    That includes both inplace variants and outplace variants. Additionally, Torch-MLIR,
     as of right now, only handles arguments with value semantics, so we need to manually fake those semantics, which
     we can for these special cases. Hence, the solution is to manually write back to the same operand that the
     conventional pytorch op variant would write to.
@@ -191,82 +81,31 @@ def write_back_to_mutable(
     Note that there are ops where multiple operands are mutable (such as batchnorm outplace variants that
     mutate running_mean and running_var). We don't currently handle those.
     """
-
+    registered_op = get_registered_op(func)
     if not registered_op["is_mutable"]:
         return None

     if len(registered_op["returns"]) > 1:
         raise UnsupportedByTorchMlirEagerMode(
             "TorchMLIR doesn't handle multiple aliased returns yet."
         )
-    else:
-        aliased_arg = next(
-            arg
-            for arg in registered_op["arguments"]
-            if "alias_info" in arg and arg["alias_info"]["is_write"]
-        )
-        assert (
-            "alias_info" in registered_op["returns"][0]
-            and registered_op["returns"][0]["alias_info"]["is_write"]
-            and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
-            and registered_op["returns"][0]["alias_info"]["after"][0]
-        )
-        assert (
-            len(aliased_arg["alias_info"]["after"]) == 1
-            and aliased_arg["alias_info"]["after"][0]
-            == registered_op["returns"][0]["alias_info"]["after"][0]
-        )
-        np.copyto(all_tensor_args[0], out)
-
-    return out
-
-
-def try_torch_mlir_eager(op, args, kwargs, backend):
-    if hasattr(op, "op_name"):
-        op_name = op.op_name
-    elif hasattr(op, "__name__"):
-        # Handle builtin_function_or_method.
-        op_name = op.__name__
-    else:
-        raise RuntimeError(f"op {op} has no name")
-
-    if "detach" in op_name:
-        # We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
-        raise UnsupportedByTorchMlirEagerMode("detaching")
-
-    if not hasattr(op, "_schema"):
-        raise RuntimeError(f"op {op} has no schema.")
-
-    new_args, new_kwargs = normalize_args_kwargs(op.overloadpacket, args, kwargs)
-
-    if "layout" in new_kwargs and new_kwargs["layout"] not in {0, None}:
-        raise UnsupportedByTorchMlirEagerMode(
-            f"{new_kwargs['layout']} layout not supported."
-        )
-    if "memory_format" in new_kwargs and new_kwargs["memory_format"] not in {0, None}:
-        raise UnsupportedByTorchMlirEagerMode(
-            f"{new_kwargs['memory_format']} memory format not supported."
-        )
-
-    script_fun = build_script_function(op._schema, new_args, new_kwargs)
-    annotations, np_tensor_args, np_tensor_kwargs_flat = annotate_args_kwargs(
-        script_fun, new_args, new_kwargs
+    aliased_arg = next(
+        arg
+        for arg in registered_op["arguments"]
+        if "alias_info" in arg and arg["alias_info"]["is_write"]
     )
+    assert (
+        "alias_info" in registered_op["returns"][0]
+        and registered_op["returns"][0]["alias_info"]["is_write"]
+        and len(registered_op["returns"][0]["alias_info"]["after"]) == 1
+        and registered_op["returns"][0]["alias_info"]["after"][0]
+    )
+    assert (
+        len(aliased_arg["alias_info"]["after"]) == 1
+        and aliased_arg["alias_info"]["after"][0]
+        == registered_op["returns"][0]["alias_info"]["after"][0]
+    )

-    eager_module = build_module(script_fun, annotations)
-    with eager_module.context:
-        pm = PassManager.parse(
-            "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
-        )
-        pm.run(eager_module)
-    compiled_module = backend.compile(eager_module)
-    loaded_module = backend.load(compiled_module)
-    op_mlir_backend_callable = getattr(loaded_module, script_fun.name)
-    assert (
-        op_mlir_backend_callable is not None
-    ), f"Couldn't find function {script_fun.name} in module."
-
-    all_tensor_args = np_tensor_args + np_tensor_kwargs_flat
-    out = op_mlir_backend_callable(*all_tensor_args)
-
-    registered_op = OP_REGISTRY[(op._schema.name, op._schema.overload_name)]
-    if registered_op["is_mutable"]:
-        out = write_back_to_mutable(registered_op, out, all_tensor_args)
-
-    return out
+    return aliased_arg["name"] if aliased_arg["name"] != "self" else "input"
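A short sketch of the list-flattening behavior the rewritten normalize_args_kwargs above introduces. The "tensors" argument name comes from the aten::cat schema; the exact flattened key names are an assumption based on the f"{k}_flattened_{i}" pattern in the code:

import torch

from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs

op = torch.ops.aten.cat.default
new_kwargs = normalize_args_kwargs(op, ([torch.randn(2), torch.randn(2)],), {"dim": 0})
# The tensor list is split into per-element kwargs so each element can become a graph input.
assert "tensors_flattened_0" in new_kwargs and "tensors_flattened_1" in new_kwargs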
@@ -0,0 +1,102 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import abc
+from dataclasses import dataclass
+from typing import TypeVar, Tuple, Callable, List, Dict, Any
+
+import torch
+
+from torch_mlir._mlir_libs._mlir.ir import Module
+
+# TODO: This might need to be an ABC too, such as
+# to support finding the backend that created the tensor.
+DeviceTensor = TypeVar("DeviceTensor")
+
+
+@dataclass(frozen=True)
+class TensorMetaData:
+    """A small container for the metadata necessary for satisfying the pytorch dispatcher and other code (pytorch or
+    otherwise) that branches on these attributes.
+
+    There is a lot of code in the PyTorch codebase that branches based on these attributes; the obvious ones here
+    are dtype, device, and requires_grad (necessary for autograd itself). There is ample warning from PyTorch that,
+    in principle, these should be as close as possible to true; see
+    https://github.com/albanD/subclass_zoo/blob/1566e038f03cd89ab3cc37e670a44e3c2bbc1897/trivial_tensors.py#L90-L92
+
+    The defaults (properties) simplify the api and seem to work after some testing but
+    might malfunction in unexpected ways.
+    # TODO: revisit these assumptions
+    """
+
+    size: Tuple[int]
+    dtype: torch.dtype
+    requires_grad: bool
+
+    strides: Tuple[int]
+    storage_offset: int = 0
+    layout: torch.layout = torch.strided
+    device: torch.device = torch.device("cpu")
+
+    def __init__(
+        self,
+        size,
+        dtype,
+        requires_grad,
+        strides=None,
+        storage_offset=None,
+        layout=None,
+        device=None,
+    ):
+        super().__init__()
+        object.__setattr__(self, "size", size)
+        object.__setattr__(self, "dtype", dtype)
+        object.__setattr__(self, "requires_grad", requires_grad)
+
+        object.__setattr__(
+            self, "strides", strides if strides is not None else len(size) * [0]
+        )
+        object.__setattr__(
+            self, "storage_offset", storage_offset if storage_offset is not None else 0
+        )
+        object.__setattr__(
+            self, "layout", layout if layout is not None else torch.strided
+        )
+        object.__setattr__(
+            self, "device", device if device is not None else torch.device("cpu")
+        )
+
+
+class TorchMLIREagerBackend(abc.ABC):
+    @abc.abstractmethod
+    def compile(
+        self, module: Module
+    ) -> Callable[[List[DeviceTensor]], List[DeviceTensor]]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> DeviceTensor:
+        """Convert a torch.Tensor into the backend's device representation."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_torch_metadata(
+        self, tensor: DeviceTensor, kwargs: Dict[str, Any]
+    ) -> TensorMetaData:
+        """Parse the relevant tensor metadata from a backend device array (e.g., shape, stride, layout) in order to
+        build a wrapper tensor."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def transfer_from_device_to_torch(self, tensor: DeviceTensor) -> torch.Tensor:
+        """If compilation fails for some reason then device specific representations need to be munged into a
+        torch.Tensor representation.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def copy_into(self, dst: DeviceTensor, src: DeviceTensor):
+        """This method is needed for things like handling aliased arguments."""
+        raise NotImplementedError
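The TensorMetaData defaults above can be exercised directly; a small sketch:

import torch

from torch_mlir.eager_mode.torch_mlir_eager_backend import TensorMetaData

# Only size/dtype/requires_grad are required; strides, storage_offset, layout,
# and device fall back to the defaults set in __init__.
md = TensorMetaData(size=(2, 3), dtype=torch.float32, requires_grad=False)
assert md.layout == torch.strided
assert md.device == torch.device("cpu")
assert md.strides == [0, 0]  # default is len(size) * [0]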
@@ -2,26 +2,65 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
+import contextlib
+import re
+import traceback
 import warnings
+from typing import Any

 import torch
 from torch.utils._pytree import tree_map

+from torch_mlir.eager_mode.ir_building import build_mlir_module
 from torch_mlir.eager_mode.torch_mlir_dispatch import (
-    try_torch_mlir_eager,
     UnsupportedByTorchMlirEagerMode,
+    normalize_args_kwargs,
+    check_get_aliased_arg,
 )
-from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
+from torch_mlir.eager_mode import EAGER_MODE_DEBUG
+from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend
+
+
+@contextlib.contextmanager
+def no_dispatch():
+    """Prevent infinite recursion in case of accidentally calling a tensor method on a TorchMLIRTensor within
+    __torch_dispatch__."""
+
+    guard = torch._C._DisableTorchDispatch()
+    try:
+        yield
+    finally:
+        del guard
+
+
+backend = EagerModeRefBackend()
+
+UNSUPPORTED_OPS = re.compile(
+    "|".join([
+        # We don't handle detach as it only pertains to autograd graph construction, which is handled by pytorch.
+        "detach",
+        # We don't handle _local_scalar_dense because it's just a way to unwrap a tensor that wraps a number.
+        "_local_scalar_dense",
+        # https://github.com/llvm/torch-mlir/issues/878
+        "_unsafe_view",
+        "view",
+    ])
+)


 class TorchMLIRTensor(torch.Tensor):
-    """Wrap torch.Tensor in order to dispatch through torch-mlir instead of aten.
+    """This class serves the role of an abstract class with common functionality for dispatching through Torch-MLIR instead of aten.

-    This class uses the _make_wrapper_subclass pattern to override __torch_dispatch__
-    in order to dispatch through torch-mlir instead of aten. Here we basically only unwrap and wrap
-    torch.Tensors. Most of the heavy lifting is done in the adjacent torch_mlir_dispatch module.
+    It defers device specific behavior to device specific implementations. The deriving classes use the
+    make_bare_wrapper_subclass convenience method, adjacent here, and override __torch_dispatch__ in order to dispatch
+    through Torch-MLIR instead of aten. Backends are free to choose whatever representation of the buffers (i.e., `elem`)
+    they want and are expected to provide conversion mechanisms between their representation and torch.Tensor.

-    More documentation on how this pattern works can be found in this forum post
+    Here we only verify that inputs abide by the currently supported features of Torch-MLIR (contiguous memory and
+    strided tensor layout) and build the mlir module. Importantly, we also recover from any malfunctions in the
+    deriving classes and dispatch back to conventional PyTorch.
+
+    More documentation on how the __torch_dispatch__ pattern works can be found in this forum post
     https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
     and this RFC
     https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call
@@ -29,83 +68,188 @@ class TorchMLIRTensor(torch.Tensor):
     https://github.com/albanD/subclass_zoo
     """

-    elem: torch.Tensor
+    elem: Any

     __slots__ = ["elem"]

     @staticmethod
-    def __new__(cls, elem, *args, **kwargs):
-        r = torch.Tensor._make_wrapper_subclass(
-            cls,
-            elem.size(),
-            strides=elem.stride(),
-            storage_offset=elem.storage_offset(),
-            dtype=elem.dtype,
-            layout=elem.layout,
-            device=elem.device,
-            # Only float tensors can have gradients.
-            requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
-            and (kwargs.get("requires_grad", False) or elem.requires_grad),
-        )
-        r.elem = elem.detach() if r.requires_grad else elem
+    def __new__(cls, elem, **kwargs):
+        """Wrap elem (which could be a torch.Tensor or otherwise) in a torch.Tensor subclass.
+
+        Critically, this method needs to parse relevant metadata from the device representation
+        (such as shape, striding, dtype, etc.) and translate it into torch conventions.
+
+        Deriving classes must provide a way to construct themselves from either their device specific representation
+        or torch.Tensor; the latter is to handle the case that we dispatch to PyTorch to recover from an error.
+        """
+        if kwargs.get("constructing_from_device_tensor", False):
+            tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
+            r = make_bare_wrapper_subclass(
+                cls=cls,
+                size=tensor_meta_data.size,
+                strides=tensor_meta_data.strides,
+                storage_offset=tensor_meta_data.storage_offset,
+                dtype=tensor_meta_data.dtype,
+                layout=tensor_meta_data.layout,
+                device=tensor_meta_data.device,
+                requires_grad=tensor_meta_data.requires_grad,
+            )
+            r.elem = elem
+        elif isinstance(elem, torch.nn.Parameter):
+            r = make_wrapper_subclass_from_torch_tensor(cls, elem.data, **kwargs)
+            r.elem = backend.transfer_from_torch_to_device(elem.detach().data)
+        elif isinstance(elem, torch.Tensor):
+            r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
+            r.elem = backend.transfer_from_torch_to_device(elem)
+        # This branch handles the case when a python scalar is passed to some op
+        # or is returned from some aten op, such as _local_scalar_dense.
+        elif isinstance(elem, (int, float, bool)):
+            return elem
+        else:
+            raise ValueError(f"Unknown element type: {type(elem)}")

         return r

     def __repr__(self):
         if self.grad_fn:
-            return f"TorchMLIRTensor({self.elem}, grad_fn={self.grad_fn})"
+            return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
         else:
-            return f"TorchMLIRTensor({self.elem})"
+            return f"TorchMLIRTensor({self.elem}, backend={backend.__class__.__name__})"

     @classmethod
     def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
-        requires_grad = False
-
-        def check_grad(e):
-            nonlocal requires_grad
-            if isinstance(e, TorchMLIRTensor):
-                requires_grad |= e.requires_grad
-
-        tree_map(check_grad, args)
-        tree_map(check_grad, kwargs)
-
-        def unwrap(e):
-            if isinstance(e, TorchMLIRTensor):
-                return e.elem
-            if isinstance(e, torch.nn.Parameter):
-                return e.detach()
-            return e
-
-        def wrap(e):
-            nonlocal requires_grad
-            return (
-                TorchMLIRTensor(e, requires_grad=requires_grad)
-                if isinstance(e, torch.Tensor)
-                else e
-            )
-
-        unwrapped_args = tree_map(unwrap, args)
-        unwrapped_kwargs = tree_map(unwrap, kwargs)
-
+        requires_grad = check_requires_grad(*args, **kwargs)
         try:
-            out = try_torch_mlir_eager(
-                func,
-                unwrapped_args,
-                unwrapped_kwargs,
-                backend=refbackend.RefBackendLinalgOnTensorsBackend(),
+            with no_dispatch():
+                if hasattr(func, "op_name"):
+                    op_name = func.op_name
+                elif hasattr(func, "__name__"):
+                    # Handle builtin_function_or_method.
+                    op_name = func.__name__
+                else:
+                    raise RuntimeError(f"op {func} has no name")
+
+                if UNSUPPORTED_OPS.match(op_name):
+                    raise UnsupportedByTorchMlirEagerMode(op_name)
+
+                if not hasattr(func, "_schema"):
+                    raise RuntimeError(f"op {func} has no schema.")
+
+                normalized_kwargs = normalize_args_kwargs(func, args, kwargs)
+
+                if "layout" in normalized_kwargs and normalized_kwargs[
+                    "layout"
+                ] not in {0, None}:
+                    raise UnsupportedByTorchMlirEagerMode(
+                        f"{normalized_kwargs['layout']} layout not supported."
+                    )
+                if "memory_format" in normalized_kwargs and normalized_kwargs[
+                    "memory_format"
+                ] not in {0, None}:
+                    raise UnsupportedByTorchMlirEagerMode(
+                        f"{normalized_kwargs['memory_format']} memory format not supported."
+                    )
+            eager_module = build_mlir_module(func, normalized_kwargs)
+            device_tensor_args = [
+                kwarg.elem
+                for _, kwarg in normalized_kwargs.items()
+                if isinstance(kwarg, cls)
+            ]
+            assert len(eager_module.body.operations[0].arguments) == len(
+                device_tensor_args
+            ), "Number of parameters and number of arguments differs."
+            op_mlir_backend_callable = backend.compile(eager_module)
+            out = op_mlir_backend_callable(*device_tensor_args)
+            out = tree_map(
+                lambda x: cls(
+                    x, requires_grad=requires_grad, constructing_from_device_tensor=True
+                ),
+                out,
             )
-            if isinstance(out, tuple):
-                out = [torch.from_numpy(o) for o in out]
-            else:
-                out = torch.from_numpy(out)
-            return tree_map(wrap, out)
         except Exception as e:
-            if isinstance(e, UnsupportedByTorchMlirEagerMode):
-                warnings.warn(
-                    f"Couldn't use TorchMLIR eager because of a current incompatibility: *{str(e)}*; running through PyTorch eager."
-                )
-            else:
-                warnings.warn(
-                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
-                    f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
-                )
-            return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
+            if EAGER_MODE_DEBUG:
+                warnings.warn(traceback.format_exc())
+            if isinstance(e, UnsupportedByTorchMlirEagerMode):
+                warnings.warn(
+                    f"Couldn't use TorchMLIR eager because of a current incompatibility: *{str(e)}*; running through PyTorch eager."
+                )
+            else:
+                warnings.warn(
+                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
+                    f"running through PyTorch eager. Please file an issue at https://github.com/llvm/torch-mlir/issues"
+                )
+
+            with no_dispatch():
+                unwrapped_args = tree_map(cls.unwrap, args)
+                unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
+                out = func(*unwrapped_args, **unwrapped_kwargs)
+
+            out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
+
+        maybe_aliased_arg_name = check_get_aliased_arg(func)
+        if maybe_aliased_arg_name is not None:
+            backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)
+
+        return out
+
+    @classmethod
+    def unwrap(cls, e):
+        """Unwrap the TorchMLIRTensor representation in order to access the actual device specific representation."""
+        if isinstance(e, cls):
+            return backend.transfer_from_device_to_torch(e.elem)
+        return e
+
+
+def check_requires_grad(*args, **kwargs):
+    requires_grad = False
+
+    def check_grad(e):
+        nonlocal requires_grad
+        if isinstance(e, TorchMLIRTensor):
+            requires_grad |= e.requires_grad
+
+    tree_map(check_grad, args)
+    tree_map(check_grad, kwargs)
+
+    return requires_grad
+
+
+def make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs):
+    """Convenience method that parses out the relevant metadata from a torch.Tensor in order to produce
+    a wrapper subclass.
+
+    NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
+    of the device specific implementation.
+    """
+    r = make_bare_wrapper_subclass(
+        cls=cls,
+        size=elem.size(),
+        strides=elem.stride(),
+        storage_offset=elem.storage_offset(),
+        dtype=elem.dtype,
+        layout=elem.layout,
+        device=elem.device,
+        # Only float tensors can have gradients.
+        requires_grad=elem.dtype in {torch.float, torch.float32, torch.float64}
+        and (kwargs.get("requires_grad", False) or elem.requires_grad),
+    )
+    return r
+
+
+def make_bare_wrapper_subclass(
+    *, cls, size, strides, storage_offset, dtype, layout, device, requires_grad
+):
+    """Convenience method that builds a wrapper subclass.
+
+    NB: this convenience method does not set the `elem` attribute of the subclass, as that is the responsibility
+    of the device specific implementation.
+    """
+    return torch.Tensor._make_wrapper_subclass(
+        cls,
+        size,
+        strides=strides,
+        storage_offset=storage_offset,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        requires_grad=requires_grad,
+    )
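End to end, the wrapper is used like any tensor subclass. A hedged sketch (the module path is assumed from this diff's package layout, and it requires a built torch_mlir):

import torch

from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor

a = TorchMLIRTensor(torch.randn(2, 3))
b = TorchMLIRTensor(torch.randn(2, 3))
# Dispatches through __torch_dispatch__ above: compiled via Torch-MLIR when the op
# is supported, with a warning and a PyTorch-eager fallback otherwise.
c = a + b
print(type(c), c.elem)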
@@ -0,0 +1,91 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from __future__ import annotations
+
+from typing import Dict, Any
+
+import numpy as np
+import torch
+
+from torch_mlir.compiler_utils import (
+    get_module_name_for_debug_dump,
+    run_pipeline_with_repro_report,
+)
+from torch_mlir.eager_mode.torch_mlir_eager_backend import (
+    TorchMLIREagerBackend,
+    TensorMetaData,
+)
+from torch_mlir.ir import Module
+from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
+    RefBackendLinalgOnTensorsBackend,
+)
+
+NUMPY_TO_TORCH_DTYPE_DICT = {
+    np.bool: torch.bool,
+    np.bool_: torch.bool,
+    np.uint8: torch.uint8,
+    np.int8: torch.int8,
+    np.int16: torch.int16,
+    np.int32: torch.int32,
+    np.int64: torch.int64,
+    np.float16: torch.float16,
+    np.float32: torch.float32,
+    np.float64: torch.float64,
+    np.complex64: torch.complex64,
+    np.complex128: torch.complex128,
+}
+
+_ref_backend = RefBackendLinalgOnTensorsBackend()
+
+
+class EagerModeRefBackend(TorchMLIREagerBackend):
+    """Main entry-point for the reference backend for eager mode.
+
+    RefBackend uses numpy.ndarray representations of tensors, so all of the wrapping, unwrapping,
+    and munging here converts between torch.Tensor and numpy.ndarray.
+    """
+
+    module_to_refbackend_invoker = {}
+
+    def get_torch_metadata(
+        self, tensor: np.ndarray, kwargs: Dict[str, Any]
+    ) -> TensorMetaData:
+        return TensorMetaData(
+            size=tensor.shape,
+            dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
+            requires_grad=tensor.dtype in {np.float, np.float32, np.float64}
+            and kwargs.get("requires_grad", False),
+        )
+
+    def compile(self, imported_module: Module):
+        """Lower the imported TS module to linalg, then further compile it for the reference backend, returning the callable."""
+        fn_name = get_module_name_for_debug_dump(imported_module)
+        module_hash = str(imported_module)
+        if module_hash not in self.module_to_refbackend_invoker:
+            run_pipeline_with_repro_report(
+                imported_module,
+                "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
+                "EagerMode",
+            )
+            self.module_to_refbackend_invoker[module_hash] = _ref_backend.load(
+                _ref_backend.compile(imported_module)
+            )
+
+        ref_backend_invoker = self.module_to_refbackend_invoker[module_hash]
+        op_mlir_backend_callable = getattr(ref_backend_invoker, fn_name)
+        assert (
+            op_mlir_backend_callable is not None
+        ), f"Couldn't find function {fn_name} in module."
+        return op_mlir_backend_callable
+
+    def copy_into(self, dst: np.ndarray, src: np.ndarray):
+        np.copyto(dst, src)
+
+    def transfer_from_device_to_torch(self, e: np.ndarray):
+        return torch.from_numpy(e).clone()
+
+    def transfer_from_torch_to_device(self, tensor: torch.Tensor) -> np.ndarray:
+        return tensor.numpy()
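The numpy round trip this backend performs for every operand can be seen in isolation (this assumes the built torch_mlir e2e test package is importable):

import numpy as np
import torch

from torch_mlir_e2e_test.eager_backends.refbackend import EagerModeRefBackend

backend = EagerModeRefBackend()
dev = backend.transfer_from_torch_to_device(torch.ones(2, 2))  # torch.Tensor -> np.ndarray
assert isinstance(dev, np.ndarray)
back = backend.transfer_from_device_to_torch(dev)  # np.ndarray -> torch.Tensor (cloned)
assert torch.equal(back, torch.ones(2, 2))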
@@ -15,7 +15,27 @@ def wrap(e):


 def unwrap(e):
-    return e.elem.clone() if isinstance(e, TorchMLIRTensor) else e
+    return TorchMLIRTensor.unwrap(e) if isinstance(e, TorchMLIRTensor) else e
+
+
+def to_tmt(m: torch.nn.Module):
+    for buf_name, buf in m.named_buffers(recurse=True):
+        if isinstance(buf, TorchMLIRTensor):
+            continue
+        m.register_buffer(buf_name, TorchMLIRTensor(buf))
+    for param_name, param in m.named_parameters(recurse=True):
+        if isinstance(param, TorchMLIRTensor):
+            continue
+        m.register_parameter(
+            param_name,
+            torch.nn.Parameter(
+                TorchMLIRTensor(param), requires_grad=param.requires_grad
+            ),
+        )
+    for attr in dir(m):
+        field = getattr(m, attr)
+        if isinstance(field, torch.Tensor) and not isinstance(field, TorchMLIRTensor):
+            setattr(m, attr, TorchMLIRTensor(field))


 class EagerModeTestConfig(TestConfig):
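The new to_tmt helper is meant to be mapped over a module tree, which is exactly what compile() below does; a small sketch:

import torch

# torch.nn.Module.apply walks the module tree, so every submodule's parameters,
# buffers, and tensor attributes get wrapped as TorchMLIRTensors before the trace runs.
model = torch.nn.Linear(3, 2)
model.apply(to_tmt)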
@@ -25,13 +45,14 @@ class EagerModeTestConfig(TestConfig):
         super().__init__()

     def compile(self, program: torch.nn.Module) -> torch.nn.Module:
+        program.apply(to_tmt)
         return program

     def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         result: Trace = []
         for item in trace:
             attr = artifact
-            for part in item.symbol.split('.'):
+            for part in item.symbol.split("."):
                 attr = getattr(attr, part)

             inps = tree_map(wrap, item.inputs)
@@ -39,7 +60,6 @@ class EagerModeTestConfig(TestConfig):
             output = tree_map(unwrap, outps)

             result.append(
-                TraceItem(symbol=item.symbol,
-                          inputs=item.inputs,
-                          output=output))
+                TraceItem(symbol=item.symbol, inputs=item.inputs, output=output)
+            )
         return result