torch-mlir/e2e_testing/torchscript/constant_alloc.py

1009 lines
26 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ZerosModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4)
@register_test_case(module_factory=lambda: ZerosModuleDefaultDtype())
def ZerosModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
def ZerosModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
def ZerosModuleInt3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class OnesModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4)
@register_test_case(module_factory=lambda: OnesModuleDefaultDtype())
def OnesModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class OnesModuleInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: OnesModuleInt())
def OnesModuleInt_basic(module, tu: TestUtils):
module.forward()
class OnesModuleFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: OnesModuleFloat())
def OnesModuleFloat_basic(module, tu: TestUtils):
module.forward()
class OnesModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyContiguousModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.empty((3, 4),
memory_format=torch.contiguous_format).fill_(0)
@register_test_case(module_factory=lambda: EmptyContiguousModule())
def EmptyModule_contiguous(module, tu: TestUtils):
module.forward()
class EmptyDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.empty((3, 4)).fill_(0)
@register_test_case(module_factory=lambda: EmptyDefaultDtypeModule())
def EmptyModule_defaultDtype(module, tu: TestUtils):
module.forward()
class EmptyIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.empty((3, 4), dtype=torch.int64).fill_(0)
@register_test_case(module_factory=lambda: EmptyIntModule())
def EmptyModule_int(module, tu: TestUtils):
module.forward()
class EmptyFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.empty((3, 4), dtype=torch.float32).fill_(0)
@register_test_case(module_factory=lambda: EmptyFloatModule())
def EmptyModule_float(module, tu: TestUtils):
module.forward()
class EmptyFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.empty((3, 4), dtype=torch.float32,
pin_memory=False).fill_(0)
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
def EmptyModule_falsePinMemory(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyLikeDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.empty_like(a).fill_(0)
@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeModule())
def EmptyLikeModule_defaultDtype(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class EmptyLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.empty_like(a, dtype=torch.int32).fill_(0)
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
def EmptyLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
class EmptyLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.empty_like(a, dtype=torch.float32).fill_(0)
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
def EmptyLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.empty_like(a, dtype=torch.float64,
pin_memory=False).fill_(0)
@register_test_case(module_factory=lambda: EmptyLikeFalsePinMemoryModule())
def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class ZerosLikeDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a)
@register_test_case(module_factory=lambda: ZerosLikeDefaultDtypeModule())
def ZerosLikeModule_defaultDtype(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class ZerosLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.int32)
@register_test_case(module_factory=lambda: ZerosLikeIntModule())
def ZerosLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
class ZerosLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosLikeFloatModule())
def ZerosLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
@register_test_case(module_factory=lambda: ZerosLikeFalsePinMemoryModule())
def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class OnesLikeDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a)
@register_test_case(module_factory=lambda: OnesLikeDefaultDtypeModule())
def OnesLikeModule_defaultDtype(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class OnesLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.int32)
@register_test_case(module_factory=lambda: OnesLikeIntModule())
def OnesLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
class OnesLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.float32)
@register_test_case(module_factory=lambda: OnesLikeFloatModule())
def OnesLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
class OnesLikeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
@register_test_case(module_factory=lambda: OnesLikeFalsePinMemoryModule())
def OnesLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64))
class Fill_TensorFloat64WithInt64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64))
# ==============================================================================
class NewZerosModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4])
@register_test_case(module_factory=lambda: NewZerosModuleDefaultDtype())
def NewZerosModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))
class NewZerosModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64)
@register_test_case(module_factory=lambda: NewZerosModuleInt2D())
def NewZerosModuleInt2D_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
class NewZerosModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.int64)
@register_test_case(module_factory=lambda: NewZerosModuleInt3D())
def NewZerosModuleInt3D_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))
class NewZerosModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32)
@register_test_case(module_factory=lambda: NewZerosModuleFloat2D())
def NewZerosModuleFloat2D_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3, 4)))
class NewZerosModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.float32)
@register_test_case(module_factory=lambda: NewZerosModuleFloat3D())
def NewZerosModuleFloat3D_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3)))
class NewZerosModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory())
def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3)))
# ==============================================================================
class NewOnesModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4])
@register_test_case(module_factory=lambda: NewOnesModuleDefaultDtype())
def NewOnesModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))
class NewOnesModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64)
@register_test_case(module_factory=lambda: NewOnesModuleInt2D())
def NewOnesModuleInt2D_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
class NewOnesModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.int64)
@register_test_case(module_factory=lambda: NewOnesModuleInt3D())
def NewOnesModuleInt3D_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))
class NewOnesModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32)
@register_test_case(module_factory=lambda: NewOnesModuleFloat2D())
def NewOnesModuleFloat2D_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3, 4)))
class NewOnesModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.float32)
@register_test_case(module_factory=lambda: NewOnesModuleFloat3D())
def NewOnesModuleFloat3D_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3)))
class NewOnesModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory())
def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3)))
# ==============================================================================
class FullModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.full([2, 3], 5.0)
@register_test_case(module_factory=lambda: FullModuleDefaultDtype())
def FullModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class FullModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.full([10, 5], 10.5, dtype=torch.int64)
@register_test_case(module_factory=lambda: FullModuleInt2D())
def FullModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class FullModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.full([2, 3, 4], 5)
@register_test_case(module_factory=lambda: FullModuleInt3D())
def FullModuleInt3D_basic(module, tu: TestUtils):
module.forward()
class FullModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.full([10, 5], 10, dtype=torch.float32)
@register_test_case(module_factory=lambda: FullModuleFloat2D())
def FullModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class FullModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.full([2, 3, 4], 5.0)
@register_test_case(module_factory=lambda: FullModuleFloat3D())
def FullModuleFloat3D_basic(module, tu: TestUtils):
module.forward()
class FullModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.full([2, 3], 5.0, dtype=torch.int64, pin_memory=False)
@register_test_case(module_factory=lambda: FullModuleFalsePinMemory())
def FullModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class FullLikeModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 5)
@register_test_case(module_factory=lambda: FullLikeModuleDefaultDtype())
def FullLikeModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))
class FullLikeModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 10.5)
@register_test_case(module_factory=lambda: FullLikeModuleInt2D())
def FullLikeModuleInt2D_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (4, 5)))
class FullLikeModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 5.0, dtype=torch.int64)
@register_test_case(module_factory=lambda: FullLikeModuleInt3D())
def FullLikeModuleInt3D_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10, 4, 5)).to(torch.int32))
class FullLikeModuleInt2DStatic(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 5], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 10)
@register_test_case(module_factory=lambda: FullLikeModuleInt2DStatic())
def FullLikeModuleInt2DStatic_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (4, 5)))
class FullLikeModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 10)
@register_test_case(module_factory=lambda: FullLikeModuleFloat2D())
def FullLikeModuleFloat2D_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class FullLikeModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 15, dtype=torch.float32)
@register_test_case(module_factory=lambda: FullLikeModuleFloat3D())
def FullLikeModuleFloat3D_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
class FullLikeModuleFloat3DStatic(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 15.3, dtype=torch.float32)
@register_test_case(module_factory=lambda: FullLikeModuleFloat3DStatic())
def FullLikeModuleFloat3DStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
class FullLikeModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 5, dtype=torch.int64, pin_memory=False)
@register_test_case(module_factory=lambda: FullLikeModuleFalsePinMemory())
def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10, 4)))
# ==============================================================================
class ZeroFloat32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten.zero_(tensor)
@register_test_case(module_factory=lambda: ZeroFloat32Module())
def ZeroFloat32Module_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 2))
class ZeroInt32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, tensor):
return torch.ops.aten.zero_(tensor)
@register_test_case(module_factory=lambda: ZeroInt32Module())
def ZeroInt32Module_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10, 4), dtype=torch.int32))
class ZeroInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, tensor):
return torch.ops.aten.zero_(tensor)
@register_test_case(module_factory=lambda: ZeroInt64Module())
def ZeroInt64Module_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10, 4)))