2021-12-21 19:51:19 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
|
|
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
|
|
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
2022-02-11 10:02:18 +08:00
|
|
|
class ZerosModuleDefaultDtype(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.zeros(3, 4)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosModuleDefaultDtype())
|
|
|
|
def ZerosModuleDefaultDtype_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
class ZerosModuleInt2D(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.zeros(3, 4, dtype=torch.int64)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
|
|
|
|
def ZerosModuleInt2D_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class ZerosModuleInt3D(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.zeros(3, 4, 5, dtype=torch.int64)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
|
|
|
|
def ZerosModuleInt3D_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class ZerosModuleFloat2D(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.zeros(3, 4, dtype=torch.float32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
|
|
|
|
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class ZerosModuleFloat3D(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.zeros(3, 4, 5, dtype=torch.float32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
|
|
|
|
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class ZerosModuleFalsePinMemory(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
|
|
|
|
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
2022-02-11 10:02:18 +08:00
|
|
|
class OnesModuleDefaultDtype(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.ones(3, 4)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesModuleDefaultDtype())
|
|
|
|
def OnesModuleDefaultDtype_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
class OnesModuleInt(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.ones(3, 4, dtype=torch.int64)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesModuleInt())
|
|
|
|
def OnesModuleInt_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class OnesModuleFloat(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.ones(3, 4, dtype=torch.float32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesModuleFloat())
|
|
|
|
def OnesModuleFloat_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class OnesModuleFalsePinMemory(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
|
|
|
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
|
|
|
|
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
2022-02-11 10:02:18 +08:00
|
|
|
class EmptyDefaultDtypeModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty((3, 4)).fill_(0)
|
2022-02-11 10:02:18 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyDefaultDtypeModule())
|
|
|
|
def EmptyModule_defaultDtype(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
class EmptyIntModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty((3, 4), dtype=torch.int64).fill_(0)
|
2021-12-21 19:51:19 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyIntModule())
|
|
|
|
def EmptyModule_int(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyFloatModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty((3, 4), dtype=torch.float32).fill_(0)
|
2021-12-21 19:51:19 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyFloatModule())
|
|
|
|
def EmptyModule_float(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyFalsePinMemoryModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
])
|
|
|
|
def forward(self):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty((3, 4), dtype=torch.float32,
|
|
|
|
pin_memory=False).fill_(0)
|
2021-12-21 19:51:19 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
|
|
|
|
def EmptyModule_falsePinMemory(module, tu: TestUtils):
|
|
|
|
module.forward()
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
2022-02-11 10:02:18 +08:00
|
|
|
class EmptyLikeDefaultDtypeModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty_like(a).fill_(0)
|
2022-02-11 10:02:18 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeModule())
|
|
|
|
def EmptyLikeModule_defaultDtype(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(3, 5))
|
|
|
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
class EmptyLikeIntModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.int64, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty_like(a, dtype=torch.int32).fill_(0)
|
2021-12-21 19:51:19 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
|
|
|
|
def EmptyLikeModule_int(module, tu: TestUtils):
|
|
|
|
module.forward(torch.randint(10, (3, 5)))
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyLikeFloatModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty_like(a, dtype=torch.float32).fill_(0)
|
2021-12-21 19:51:19 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
|
|
|
|
def EmptyLikeModule_float(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(4, 5))
|
|
|
|
|
|
|
|
|
|
|
|
class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
2022-02-16 03:58:03 +08:00
|
|
|
return torch.empty_like(a, dtype=torch.float64,
|
|
|
|
pin_memory=False).fill_(0)
|
2021-12-21 19:51:19 +08:00
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: EmptyLikeFalsePinMemoryModule())
|
|
|
|
def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(2, 3, 4))
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
2022-02-11 10:02:18 +08:00
|
|
|
class ZerosLikeDefaultDtypeModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.zeros_like(a)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosLikeDefaultDtypeModule())
|
|
|
|
def ZerosLikeModule_defaultDtype(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(3, 5))
|
|
|
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
class ZerosLikeIntModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.int64, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.zeros_like(a, dtype=torch.int32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosLikeIntModule())
|
|
|
|
def ZerosLikeModule_int(module, tu: TestUtils):
|
|
|
|
module.forward(torch.randint(10, (3, 5)))
|
|
|
|
|
|
|
|
|
|
|
|
class ZerosLikeFloatModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.zeros_like(a, dtype=torch.float32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosLikeFloatModule())
|
|
|
|
def ZerosLikeModule_float(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(4, 5))
|
|
|
|
|
|
|
|
|
|
|
|
class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: ZerosLikeFalsePinMemoryModule())
|
|
|
|
def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(2, 3, 4))
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
2022-02-11 10:02:18 +08:00
|
|
|
class OnesLikeDefaultDtypeModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.ones_like(a)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesLikeDefaultDtypeModule())
|
|
|
|
def OnesLikeModule_defaultDtype(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(3, 5))
|
|
|
|
|
|
|
|
|
2021-12-21 19:51:19 +08:00
|
|
|
class OnesLikeIntModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.int64, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.ones_like(a, dtype=torch.int32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesLikeIntModule())
|
|
|
|
def OnesLikeModule_int(module, tu: TestUtils):
|
|
|
|
module.forward(torch.randint(10, (3, 5)))
|
|
|
|
|
|
|
|
|
|
|
|
class OnesLikeFloatModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.ones_like(a, dtype=torch.float32)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesLikeFloatModule())
|
|
|
|
def OnesLikeModule_float(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(4, 5))
|
|
|
|
|
|
|
|
|
|
|
|
class OnesLikeFalsePinMemoryModule(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, a):
|
|
|
|
return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: OnesLikeFalsePinMemoryModule())
|
|
|
|
def OnesLikeModule_falsePinMemory(module, tu: TestUtils):
|
|
|
|
module.forward(tu.rand(2, 3, 4))
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
|
|
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1, -1], torch.float32, True),
|
|
|
|
])
|
|
|
|
def forward(self, tensor):
|
|
|
|
return torch.ops.aten.fill_(tensor, 3.0)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
|
|
|
|
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
|
|
|
|
module.forward(torch.randn(3, 2, 4))
|
|
|
|
|
|
|
|
|
|
|
|
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1, -1], torch.float64, True),
|
|
|
|
])
|
|
|
|
def forward(self, tensor):
|
|
|
|
return torch.ops.aten.fill_(tensor, 3.0)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
|
|
|
|
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
|
|
|
|
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
|
|
|
|
|
|
|
|
|
|
|
class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
@export
|
|
|
|
@annotate_args([
|
|
|
|
None,
|
|
|
|
([-1, -1, -1], torch.float64, True),
|
|
|
|
])
|
|
|
|
def forward(self, tensor):
|
|
|
|
return torch.ops.aten.fill_(tensor, 3)
|
|
|
|
|
|
|
|
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
|
|
|
|
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
|
|
|
|
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
|
|
|
|