mirror of https://github.com/llvm/torch-mlir
parent
18ef40acaf
commit
66de821eaf
|
@ -0,0 +1,309 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from framework import run_test
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import build_script_function
|
||||
|
||||
|
||||
# CHECK: graph(%[[A1:.*]] : Tensor,
|
||||
# CHECK: %[[A2:.*]] : Tensor,
|
||||
# CHECK: %[[A3:.*]] : Tensor):
|
||||
# CHECK: %[[A4:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[A5:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[A0:.*]] : Tensor = aten::addmm(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]])
|
||||
# CHECK: return (%[[A0]])
|
||||
# -----
|
||||
# CHECK: PASS - simple
|
||||
@run_test
|
||||
def simple():
|
||||
target = torch.ops.aten.addmm.default
|
||||
A = torch.randn(1, 3, 32, 32)
|
||||
B = torch.randn(1, 3, 32, 32)
|
||||
C = torch.randn(1, 3, 32, 32)
|
||||
args = (A, B, C)
|
||||
kwargs = dict(beta=1, alpha=1)
|
||||
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[B1:.*]] : Tensor,
|
||||
# CHECK: %[[B2:.*]] : Tensor,
|
||||
# CHECK: %[[B3:.*]] : Tensor):
|
||||
# CHECK: %[[B4:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[B5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[B6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[B7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[B8:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[B9:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[B0:.*]] : Tensor = aten::convolution(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[B9]])
|
||||
# CHECK: return (%[[B0]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_optional_tensor_input
|
||||
@run_test
|
||||
def handle_optional_tensor_input():
|
||||
target = torch.ops.aten.convolution.default
|
||||
input = torch.randn(1, 3, 32, 32)
|
||||
weight = torch.randn(3, 3, 3, 3)
|
||||
bias = torch.randn(3)
|
||||
args = (input, weight, bias)
|
||||
kwargs = dict(
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
)
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: FAIL - fail_not_enough_args
|
||||
# CHECK: Errors: tuple index out of range
|
||||
@run_test
|
||||
def fail_not_enough_args():
|
||||
target = torch.ops.aten.convolution.default
|
||||
input = torch.randn(1, 3, 32, 32)
|
||||
weight = torch.randn(3, 3, 3, 3)
|
||||
bias = torch.randn(3)
|
||||
args = (input, weight, bias)
|
||||
kwargs = dict(
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
# Missing groups=1,
|
||||
)
|
||||
build_script_function(target._schema, args, kwargs)
|
||||
|
||||
|
||||
# CHECK: PASS - simple_args_or_kwargs
|
||||
@run_test
|
||||
def simple_args_or_kwargs():
|
||||
target = torch.ops.aten.convolution.default
|
||||
input = torch.randn(1, 3, 32, 32)
|
||||
weight = torch.randn(3, 3, 3, 3)
|
||||
bias = torch.randn(3)
|
||||
stride = [1, 1]
|
||||
padding = [0, 0]
|
||||
dilation = [1, 1]
|
||||
transposed = False
|
||||
output_padding = [0, 0]
|
||||
groups = 1
|
||||
script_fun1 = build_script_function(
|
||||
target._schema,
|
||||
(input, weight, bias),
|
||||
dict(
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
transposed=transposed,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
),
|
||||
)
|
||||
|
||||
script_fun2 = build_script_function(
|
||||
target._schema,
|
||||
(input, weight, bias, stride, padding, dilation),
|
||||
dict(transposed=transposed, output_padding=output_padding, groups=groups),
|
||||
)
|
||||
assert str(script_fun1.graph) == str(script_fun2.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[C2:.*]] : Tensor):
|
||||
# CHECK: %[[C3:.*]] : int[] = prim::Constant[value=[3, 3]]()
|
||||
# CHECK: %[[C4:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[C5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[C6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[C7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[C0:.*]] : Tensor, %[[C1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]])
|
||||
# CHECK: return (%[[C0]], %[[C1]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_empty_lists
|
||||
@run_test
|
||||
def handle_empty_lists():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
# print(target._schema)
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {
|
||||
"kernel_size": [3, 3],
|
||||
"stride": [],
|
||||
"padding": [0, 0],
|
||||
"dilation": [1, 1],
|
||||
"ceil_mode": False,
|
||||
}
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[D2:.*]] : Tensor):
|
||||
# CHECK: %[[D3:.*]] : int[] = prim::Constant[value=[3, 3]]()
|
||||
# CHECK: %[[D4:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[D5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[D6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[D7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[D0:.*]] : Tensor, %[[D1:.*]] : Tensor = aten::max_pool2d_with_indices(%[[D2]], %[[D3]], %[[D4]], %[[D5]], %[[D6]], %[[D7]])
|
||||
# CHECK: return (%[[D0]], %[[D1]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_nones
|
||||
@run_test
|
||||
def handle_nones():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default
|
||||
# print(target._schema)
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {
|
||||
"kernel_size": [3, 3],
|
||||
"stride": None,
|
||||
"padding": [0, 0],
|
||||
"dilation": [1, 1],
|
||||
"ceil_mode": False,
|
||||
}
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[E1:.*]] : Tensor,
|
||||
# CHECK: %[[E2:.*]] : Tensor,
|
||||
# CHECK: %[[E3:.*]] : Tensor):
|
||||
# CHECK: %[[E4:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[E5:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[E6:.*]] : int[] = prim::Constant[value=[1, 1]]()
|
||||
# CHECK: %[[E7:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[E8:.*]] : int[] = prim::Constant[value=[0, 0]]()
|
||||
# CHECK: %[[E9:.*]] : int = prim::Constant[value=1]()
|
||||
# CHECK: %[[E0:.*]] : Tensor = aten::convolution(%[[E1]], %[[E2]], %[[E3]], %[[E4]], %[[E5]], %[[E6]], %[[E7]], %[[E8]], %[[E9]])
|
||||
# CHECK: return (%[[E0]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_optional_tensors
|
||||
@run_test
|
||||
def handle_optional_tensors():
|
||||
target = torch.ops.aten.convolution.default
|
||||
input = torch.randn(1, 3, 32, 32)
|
||||
weight = torch.randn(3, 3, 3, 3)
|
||||
bias = torch.randn(3)
|
||||
args = (input, weight, bias)
|
||||
kwargs = dict(
|
||||
stride=[1, 1],
|
||||
padding=[0, 0],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
)
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[F1:.*]] : Tensor):
|
||||
# CHECK: %[[F2:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F3:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F4:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F5:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F6:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[F0:.*]] : Tensor = aten::ones_like(%[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], %[[F6]])
|
||||
# CHECK: return (%[[F0]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_ones_like
|
||||
@run_test
|
||||
def handle_ones_like():
|
||||
target = torch.ops.aten.ones_like.default
|
||||
input = torch.randn(1, 3, 32, 32)
|
||||
args = (input,)
|
||||
kwargs = dict(
|
||||
dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
|
||||
)
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: graph(%[[G3:.*]] : Tensor,
|
||||
# CHECK: %[[G4:.*]] : Tensor,
|
||||
# CHECK: %[[G5:.*]] : Tensor):
|
||||
# CHECK: %[[G6:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[G7:.*]] : NoneType = prim::Constant()
|
||||
# CHECK: %[[G8:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[G9:.*]] : float = prim::Constant[value=1.]()
|
||||
# CHECK: %[[G10:.*]] : float = prim::Constant[value=1.]()
|
||||
# CHECK: %[[G0:.*]] : Tensor, %[[G1:.*]] : Tensor, %[[G2:.*]] : Tensor = aten::native_batch_norm(%[[G3]], %[[G4]], %[[G5]], %[[G6]], %[[G7]], %[[G8]], %[[G9]], %[[G10]])
|
||||
# CHECK: return (%[[G0]], %[[G1]], %[[G2]])
|
||||
# -----
|
||||
# CHECK: PASS - handle_multiple_outputs
|
||||
@run_test
|
||||
def handle_multiple_outputs():
|
||||
target = torch.ops.aten.native_batch_norm.default
|
||||
A = torch.randn(1, 3, 32, 32)
|
||||
B = torch.randn(1, 3, 32, 32)
|
||||
C = torch.randn(1, 3, 32, 32)
|
||||
args = (A, B, C, None, None, False, 1.0, 1.0)
|
||||
kwargs = dict()
|
||||
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
||||
|
||||
|
||||
# CHECK: f
|
||||
# CHECK: PASS - check_legal_name
|
||||
@run_test
|
||||
def check_legal_name():
|
||||
target = torch.ops.aten.native_batch_norm.default
|
||||
A = torch.randn(1, 3, 32, 32)
|
||||
B = torch.randn(1, 3, 32, 32)
|
||||
C = torch.randn(1, 3, 32, 32)
|
||||
args = (A, B, C, None, None, False, 1.0, 1.0)
|
||||
kwargs = dict()
|
||||
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.name)
|
||||
|
||||
|
||||
# CHECK: graph(%[[H3:.*]] : Tensor,
|
||||
# CHECK: %[[H4:.*]] : Tensor,
|
||||
# CHECK: %[[H5:.*]] : Tensor,
|
||||
# CHECK: %[[H6:.*]] : Tensor,
|
||||
# CHECK: %[[H7:.*]] : Tensor,
|
||||
# CHECK: %out : Tensor,
|
||||
# CHECK: %save_mean : Tensor,
|
||||
# CHECK: %save_invstd : Tensor):
|
||||
# CHECK: %[[H8:.*]] : bool = prim::Constant[value=0]()
|
||||
# CHECK: %[[H9:.*]] : float = prim::Constant[value=0.10000000000000001]()
|
||||
# CHECK: %[[H10:.*]] : float = prim::Constant[value=0.0001]()
|
||||
# CHECK: %[[H0:.*]] : Tensor, %[[H1:.*]] : Tensor, %[[H2:.*]] : Tensor = aten::native_batch_norm(%[[H3]], %[[H4]], %[[H5]], %[[H6]], %[[H7]], %[[H8]], %[[H9]], %[[H10]], %out, %save_mean, %save_invstd)
|
||||
# CHECK: return (%[[H0]], %[[H1]], %[[H2]])
|
||||
# -----
|
||||
# CHECK: PASS - correctly_order_kwargs
|
||||
@run_test
|
||||
def correctly_order_kwargs():
|
||||
target = torch.ops.aten.native_batch_norm.out
|
||||
|
||||
input = torch.randn(2, 5, 2, 3)
|
||||
weight = torch.randn(5)
|
||||
bias = torch.randn(5)
|
||||
running_mean = torch.randn(5)
|
||||
running_var = torch.randn(5)
|
||||
args = (input, weight, bias, running_mean, running_var)
|
||||
|
||||
out = torch.empty_like(input)
|
||||
save_mean = torch.empty_like(running_mean)
|
||||
save_invstd = torch.empty_like(running_var)
|
||||
|
||||
kwargs = dict(
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.0001,
|
||||
out=out,
|
||||
save_mean=save_mean,
|
||||
save_invstd=save_invstd,
|
||||
)
|
||||
|
||||
script_fun = build_script_function(target._schema, args, kwargs)
|
||||
print(script_fun.graph)
|
|
@ -0,0 +1,23 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: true
|
||||
|
||||
|
||||
def run_test(*args, XPASS=False, XFAIL=False):
|
||||
def _run_test(test):
|
||||
test_name = test.__name__
|
||||
try:
|
||||
test()
|
||||
print(("X" if XPASS else "") + f"PASS - {test_name}")
|
||||
except Exception as e:
|
||||
print(("X" if XFAIL else "") + f"FAIL - {test_name}")
|
||||
print("Errors: ", e)
|
||||
print()
|
||||
|
||||
if len(args):
|
||||
_run_test(args[0])
|
||||
else:
|
||||
return _run_test
|
|
@ -8,19 +8,12 @@
|
|||
|
||||
import torch
|
||||
|
||||
from framework import run_test
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
|
||||
|
||||
|
||||
def run_test(test, XFAIL=False, XPASS=False):
|
||||
try:
|
||||
test()
|
||||
print(("X" if XPASS else "") + f"PASS - {test.__name__}")
|
||||
except Exception as e:
|
||||
print(("X" if XFAIL else "") + f"FAIL - {test.__name__}")
|
||||
print(e)
|
||||
|
||||
|
||||
# CHECK: PASS - should_normalize
|
||||
@run_test
|
||||
def should_normalize():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
|
@ -44,6 +37,7 @@ def should_normalize():
|
|||
|
||||
# CHECK: FAIL - shouldnt_normalize1
|
||||
# CHECK: Couldn't normalize args and kwargs
|
||||
@run_test
|
||||
def shouldnt_normalize1():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
|
@ -58,6 +52,7 @@ def shouldnt_normalize1():
|
|||
# TODO(max): change these to FAIL when the upstream bug is fixed.
|
||||
|
||||
# CHECK: XPASS - shouldnt_normalize2
|
||||
@run_test(XPASS=True)
|
||||
def shouldnt_normalize2():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
|
@ -66,19 +61,9 @@ def shouldnt_normalize2():
|
|||
|
||||
|
||||
# CHECK: XPASS - shouldnt_normalize3
|
||||
@run_test(XPASS=True)
|
||||
def shouldnt_normalize3():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"kernel_size": [3, 3], "padding": None}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
run_test(should_normalize)
|
||||
run_test(shouldnt_normalize1)
|
||||
run_test(shouldnt_normalize2, XPASS=True)
|
||||
run_test(shouldnt_normalize3, XPASS=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -125,8 +125,7 @@ def build_script_function(
|
|||
else:
|
||||
graph.registerOutput(node.output())
|
||||
|
||||
fn_name = str(node).strip()
|
||||
fn = torch._C._create_function_from_graph(fn_name, graph)
|
||||
fn = torch._C._create_function_from_graph("f", graph)
|
||||
return fn
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue