mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten.[ones_like|zeros_like]`
- This commit adds E2E support for `aten.ones_like` and `aten.zeros_like` ops. - Adds support for non-None `dtype` argument of `aten.empty_like` op. - All the unit test cases related to constant tensor allocation like ops are moved to a different file named `constant_alloc.py`. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/503/head
parent
9afaacedbd
commit
3c40539b34
|
@ -586,144 +586,6 @@ def ExpandModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class OnesModuleInt(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.ones(3, 4, dtype=torch.int64)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: OnesModuleInt())
|
|
||||||
def OnesModuleInt_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class OnesModuleFloat(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.ones(3, 4, dtype=torch.float32)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: OnesModuleFloat())
|
|
||||||
def OnesModuleFloat_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class OnesModuleFalsePinMemory(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
|
|
||||||
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class EmptyIntModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return 0 * torch.empty((3, 4), dtype=torch.int64)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: EmptyIntModule())
|
|
||||||
def EmptyModule_int(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class EmptyFloatModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: EmptyFloatModule())
|
|
||||||
def EmptyModule_float(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyFalsePinMemoryModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.pow(torch.empty((3, 4), dtype=torch.float32,
|
|
||||||
pin_memory=False), 0)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
|
|
||||||
def EmptyModule_falsePinMemory(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class EmptyLikeIntModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([-1, -1], torch.int64, True),
|
|
||||||
])
|
|
||||||
def forward(self, a):
|
|
||||||
return 0 * torch.empty_like(a, dtype=torch.int64)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
|
|
||||||
def EmptyLikeModule_int(module, tu: TestUtils):
|
|
||||||
module.forward(torch.randint(10, (3, 5)))
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class EmptyLikeFloatModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([-1, -1], torch.float32, True),
|
|
||||||
])
|
|
||||||
def forward(self, a):
|
|
||||||
return torch.pow(torch.empty_like(a, dtype=torch.float32), 0)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
|
|
||||||
def EmptyLikeModule_float(module, tu: TestUtils):
|
|
||||||
module.forward(tu.rand(4, 5))
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
class ContiguousModule(torch.nn.Module):
|
class ContiguousModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -926,57 +788,6 @@ def DropoutModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([-1, -1, -1], torch.float32, True),
|
|
||||||
])
|
|
||||||
def forward(self, tensor):
|
|
||||||
return torch.ops.aten.fill_(tensor, 3.0)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
|
|
||||||
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
|
|
||||||
module.forward(torch.randn(3, 2, 4))
|
|
||||||
|
|
||||||
|
|
||||||
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([-1, -1, -1], torch.float64, True),
|
|
||||||
])
|
|
||||||
def forward(self, tensor):
|
|
||||||
return torch.ops.aten.fill_(tensor, 3.0)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
|
|
||||||
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
|
|
||||||
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
|
||||||
|
|
||||||
|
|
||||||
class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
([-1, -1, -1], torch.float64, True),
|
|
||||||
])
|
|
||||||
def forward(self, tensor):
|
|
||||||
return torch.ops.aten.fill_(tensor, 3)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
|
|
||||||
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
|
|
||||||
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
|
||||||
|
|
||||||
|
|
||||||
class MeanModule(torch.nn.Module):
|
class MeanModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1047,86 +858,6 @@ def NumelZeroRankModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(10,[]))
|
module.forward(torch.randint(10,[]))
|
||||||
|
|
||||||
|
|
||||||
class ZerosModuleInt2D(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.zeros(3, 4, dtype=torch.int64)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
|
|
||||||
def ZerosModuleInt2D_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class ZerosModuleInt3D(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.zeros(3, 4, 5, dtype=torch.int64)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
|
|
||||||
def ZerosModuleInt3D_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class ZerosModuleFloat2D(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.zeros(3, 4, dtype=torch.float32)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
|
|
||||||
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class ZerosModuleFloat3D(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.zeros(3, 4, 5, dtype=torch.float32)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
|
|
||||||
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class ZerosModuleFalsePinMemory(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
@export
|
|
||||||
@annotate_args([
|
|
||||||
None,
|
|
||||||
])
|
|
||||||
def forward(self):
|
|
||||||
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
|
|
||||||
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|
||||||
module.forward()
|
|
||||||
|
|
||||||
|
|
||||||
class BoolTensorReturnFalseModule(torch.nn.Module):
|
class BoolTensorReturnFalseModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1181,6 +912,7 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
|
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class TModuleRank2(torch.nn.Module):
|
class TModuleRank2(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1193,11 +925,11 @@ class TModuleRank2(torch.nn.Module):
|
||||||
def forward(self, lhs):
|
def forward(self, lhs):
|
||||||
return torch.t(lhs)
|
return torch.t(lhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TModuleRank2())
|
@register_test_case(module_factory=lambda: TModuleRank2())
|
||||||
def TModuleRank2_basic(module, tu: TestUtils):
|
def TModuleRank2_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
class TModuleRank1(torch.nn.Module):
|
class TModuleRank1(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1210,11 +942,11 @@ class TModuleRank1(torch.nn.Module):
|
||||||
def forward(self, lhs):
|
def forward(self, lhs):
|
||||||
return torch.t(lhs)
|
return torch.t(lhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TModuleRank1())
|
@register_test_case(module_factory=lambda: TModuleRank1())
|
||||||
def TModuleRank1_basic(module, tu: TestUtils):
|
def TModuleRank1_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3))
|
module.forward(tu.rand(3))
|
||||||
|
|
||||||
|
|
||||||
class TModuleRank0(torch.nn.Module):
|
class TModuleRank0(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1227,7 +959,6 @@ class TModuleRank0(torch.nn.Module):
|
||||||
def forward(self, lhs):
|
def forward(self, lhs):
|
||||||
return torch.t(lhs)
|
return torch.t(lhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TModuleRank0())
|
@register_test_case(module_factory=lambda: TModuleRank0())
|
||||||
def TModuleRank0_basic(module, tu: TestUtils):
|
def TModuleRank0_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.tensor(7, dtype=torch.float32))
|
module.forward(torch.tensor(7, dtype=torch.float32))
|
||||||
|
|
|
@ -0,0 +1,400 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ZerosModuleInt2D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, dtype=torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
|
||||||
|
def ZerosModuleInt2D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleInt3D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, 5, dtype=torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
|
||||||
|
def ZerosModuleInt3D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleFloat2D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
|
||||||
|
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleFloat3D(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, 5, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
|
||||||
|
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosModuleFalsePinMemory(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
|
||||||
|
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class OnesModuleInt(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ones(3, 4, dtype=torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesModuleInt())
|
||||||
|
def OnesModuleInt_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class OnesModuleFloat(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ones(3, 4, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesModuleFloat())
|
||||||
|
def OnesModuleFloat_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class OnesModuleFalsePinMemory(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
|
||||||
|
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class EmptyIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return 0 * torch.empty((3, 4), dtype=torch.int64)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyIntModule())
|
||||||
|
def EmptyModule_int(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyFloatModule())
|
||||||
|
def EmptyModule_float(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyFalsePinMemoryModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.pow(torch.empty((3, 4), dtype=torch.float32,
|
||||||
|
pin_memory=False), 0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
|
||||||
|
def EmptyModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class EmptyLikeIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return 0 * torch.empty_like(a, dtype=torch.int32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
|
||||||
|
def EmptyLikeModule_int(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (3, 5)))
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyLikeFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.pow(torch.empty_like(a, dtype=torch.float32), 0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
|
||||||
|
def EmptyLikeModule_float(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.pow(torch.empty_like(a, dtype=torch.float64,
|
||||||
|
pin_memory=False), 0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyLikeFalsePinMemoryModule())
|
||||||
|
def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ZerosLikeIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.zeros_like(a, dtype=torch.int32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosLikeIntModule())
|
||||||
|
def ZerosLikeModule_int(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (3, 5)))
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosLikeFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.zeros_like(a, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosLikeFloatModule())
|
||||||
|
def ZerosLikeModule_float(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosLikeFalsePinMemoryModule())
|
||||||
|
def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class OnesLikeIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ones_like(a, dtype=torch.int32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesLikeIntModule())
|
||||||
|
def OnesLikeModule_int(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (3, 5)))
|
||||||
|
|
||||||
|
|
||||||
|
class OnesLikeFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ones_like(a, dtype=torch.float32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesLikeFloatModule())
|
||||||
|
def OnesLikeModule_float(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class OnesLikeFalsePinMemoryModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesLikeFalsePinMemoryModule())
|
||||||
|
def OnesLikeModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, tensor):
|
||||||
|
return torch.ops.aten.fill_(tensor, 3.0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
|
||||||
|
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 2, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, tensor):
|
||||||
|
return torch.ops.aten.fill_(tensor, 3.0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
|
||||||
|
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
|
class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, tensor):
|
||||||
|
return torch.ops.aten.fill_(tensor, 3)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
|
||||||
|
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
||||||
|
|
|
@ -46,6 +46,7 @@ from . import slice_like
|
||||||
from . import nll_loss
|
from . import nll_loss
|
||||||
from . import index_select
|
from . import index_select
|
||||||
from . import arange
|
from . import arange
|
||||||
|
from . import constant_alloc
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||||
|
|
|
@ -2160,6 +2160,44 @@ def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
|
||||||
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
|
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
TorchOptionalIntType:$dtype,
|
||||||
|
TorchOptionalIntType:$layout,
|
||||||
|
TorchOptionalDeviceType:$device,
|
||||||
|
TorchOptionalBoolType:$pin_memory,
|
||||||
|
TorchOptionalIntType:$memory_format
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
TorchOptionalIntType:$dtype,
|
||||||
|
TorchOptionalIntType:$layout,
|
||||||
|
TorchOptionalDeviceType:$device,
|
||||||
|
TorchOptionalBoolType:$pin_memory,
|
||||||
|
TorchOptionalIntType:$memory_format
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -536,8 +536,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
|
|
||||||
namespace {
|
namespace {
|
||||||
|
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
|
||||||
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
|
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
@ -547,14 +547,6 @@ public:
|
||||||
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
||||||
Value sizeList =
|
Value sizeList =
|
||||||
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.self());
|
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.self());
|
||||||
|
|
||||||
// TODO: Handle the case when `dtype` is NoneType.
|
|
||||||
Type dtype = op.dtype().getType();
|
|
||||||
if (dtype.isa<OptionalType>() || dtype.isa<Torch::NoneType>() ||
|
|
||||||
dtype.isa<mlir::NoneType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: None dtype is not supported");
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
||||||
op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(),
|
op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(),
|
||||||
op.pin_memory(), op.memory_format());
|
op.pin_memory(), op.memory_format());
|
||||||
|
@ -605,6 +597,28 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose constant tensor allocation like ops.
|
||||||
|
template <typename OpTy, int fillVal>
|
||||||
|
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
|
||||||
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
// Allocate a memory block.
|
||||||
|
Value initTensor = rewriter.create<AtenEmptyLikeOp>(
|
||||||
|
loc, op.getType(), op.self(), op.dtype(), op.layout(), op.device(),
|
||||||
|
op.pin_memory(), op.memory_format());
|
||||||
|
Value constVal = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(fillVal));
|
||||||
|
// Initialize the allocated memory block with `fillVal`.
|
||||||
|
rewriter.replaceOpWithNewOp<AtenFill_ScalarOp>(op, initTensor.getType(),
|
||||||
|
initTensor, constVal);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -622,6 +636,12 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||||
patterns.add<DecomposeAtenEmptyLikeOp>(context);
|
patterns.add<DecomposeAtenEmptyLikeOp>(context);
|
||||||
target.addIllegalOp<AtenEmptyLikeOp>();
|
target.addIllegalOp<AtenEmptyLikeOp>();
|
||||||
|
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
|
||||||
|
context);
|
||||||
|
target.addIllegalOp<AtenOnesLikeOp>();
|
||||||
|
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
|
||||||
|
context);
|
||||||
|
target.addIllegalOp<AtenZerosLikeOp>();
|
||||||
patterns.add<DecomposeAtenExpandOp>(context);
|
patterns.add<DecomposeAtenExpandOp>(context);
|
||||||
target.addIllegalOp<AtenExpandOp>();
|
target.addIllegalOp<AtenExpandOp>();
|
||||||
patterns.add<DecomposeAtenSizeOp>(context);
|
patterns.add<DecomposeAtenSizeOp>(context);
|
||||||
|
|
|
@ -234,15 +234,15 @@ public:
|
||||||
visitOperation(Operation *op,
|
visitOperation(Operation *op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
|
||||||
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
|
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
|
||||||
AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
|
AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenCeilOp,
|
||||||
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
|
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
|
||||||
AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
|
AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
|
||||||
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenEmptyLikeOp,
|
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenReciprocalOp,
|
||||||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
|
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
|
||||||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
||||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||||
AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
|
AtenAbsOp>(op)) {
|
||||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -395,6 +395,14 @@ public:
|
||||||
} else if (auto emptyMemoryFormat = dyn_cast<AtenEmptyMemoryFormatOp>(op)) {
|
} else if (auto emptyMemoryFormat = dyn_cast<AtenEmptyMemoryFormatOp>(op)) {
|
||||||
return visitConstantTensorAllocOp<AtenEmptyMemoryFormatOp>(
|
return visitConstantTensorAllocOp<AtenEmptyMemoryFormatOp>(
|
||||||
emptyMemoryFormat);
|
emptyMemoryFormat);
|
||||||
|
} else if (auto zerosLike = dyn_cast<AtenZerosLikeOp>(op)) {
|
||||||
|
return visitConstantTensorAllocLikeOp<AtenZerosLikeOp>(zerosLike,
|
||||||
|
operands);
|
||||||
|
} else if (auto onesLike = dyn_cast<AtenOnesLikeOp>(op)) {
|
||||||
|
return visitConstantTensorAllocLikeOp<AtenOnesLikeOp>(onesLike, operands);
|
||||||
|
} else if (auto emptyLike = dyn_cast<AtenEmptyLikeOp>(op)) {
|
||||||
|
return visitConstantTensorAllocLikeOp<AtenEmptyLikeOp>(emptyLike,
|
||||||
|
operands);
|
||||||
} else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
} else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||||
return visitAtenToDtypeOp(toDtype, operands);
|
return visitAtenToDtypeOp(toDtype, operands);
|
||||||
} else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
|
} else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
|
||||||
|
@ -566,6 +574,9 @@ private:
|
||||||
ChangeResult visitScalarToTensorConversionOp(OpTy op);
|
ChangeResult visitScalarToTensorConversionOp(OpTy op);
|
||||||
ChangeResult visitAtenTensorOp(AtenTensorOp op);
|
ChangeResult visitAtenTensorOp(AtenTensorOp op);
|
||||||
template <typename OpTy> ChangeResult visitConstantTensorAllocOp(OpTy op);
|
template <typename OpTy> ChangeResult visitConstantTensorAllocOp(OpTy op);
|
||||||
|
template <typename OpTy>
|
||||||
|
ChangeResult visitConstantTensorAllocLikeOp(
|
||||||
|
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
ChangeResult
|
ChangeResult
|
||||||
visitAtenToDtypeOp(AtenToDtypeOp op,
|
visitAtenToDtypeOp(AtenToDtypeOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
|
@ -1407,6 +1418,20 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
|
||||||
return getLatticeElement(op.getResult()).join(knowledge);
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
|
||||||
|
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
|
auto input = operands[0]->getValue();
|
||||||
|
auto knowledge =
|
||||||
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
|
if (input.hasSizes) {
|
||||||
|
knowledge.hasSizes = true;
|
||||||
|
knowledge.sizes = input.sizes;
|
||||||
|
}
|
||||||
|
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), input.dtype);
|
||||||
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
// Convert input tensor type to the given `dtype`.
|
// Convert input tensor type to the given `dtype`.
|
||||||
ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
|
ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
|
||||||
AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
|
|
|
@ -565,6 +565,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::detach : (Tensor) -> (Tensor)")
|
emit("aten::detach : (Tensor) -> (Tensor)")
|
||||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
|
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
|
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue