[TORCH][MLIR] Add E2E support for `aten.[ones_like|zeros_like]`

- This commit adds E2E support for the `aten.ones_like` and
  `aten.zeros_like` ops (a brief usage sketch follows this list).
- Adds support for a non-None `dtype` argument to the `aten.empty_like` op.
- All unit tests for the constant tensor allocation like ops are moved to a
  new file named `constant_alloc.py`.
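
A PyTorch-level sketch of what now lowers end to end (illustrative only; the
tensor `x` and the chosen dtypes below are arbitrary examples, not taken from
the test suite):

import torch

x = torch.rand(3, 4)

# `aten.ones_like` / `aten.zeros_like`, including an explicit `dtype` argument.
ones = torch.ones_like(x, dtype=torch.int64)
zeros = torch.zeros_like(x, dtype=torch.float64)

# `aten.empty_like` with a non-None `dtype`; the contents are uninitialized,
# so only the shape and dtype of the result are meaningful here.
empty = torch.empty_like(x, dtype=torch.int32)

print(ones.dtype, zeros.dtype, empty.shape)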

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/503/head
Gaurav Shukla 2021-12-21 17:21:19 +05:30
parent 9afaacedbd
commit 3c40539b34
7 changed files with 501 additions and 284 deletions

@@ -586,144 +586,6 @@ def ExpandModule_basic(module, tu: TestUtils):
# ==============================================================================
class OnesModuleInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: OnesModuleInt())
def OnesModuleInt_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class OnesModuleFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: OnesModuleFloat())
def OnesModuleFloat_basic(module, tu: TestUtils):
module.forward()
class OnesModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return 0 * torch.empty((3, 4), dtype=torch.int64)
@register_test_case(module_factory=lambda: EmptyIntModule())
def EmptyModule_int(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0)
@register_test_case(module_factory=lambda: EmptyFloatModule())
def EmptyModule_float(module, tu: TestUtils):
module.forward()
class EmptyFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.pow(torch.empty((3, 4), dtype=torch.float32,
pin_memory=False), 0)
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
def EmptyModule_falsePinMemory(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return 0 * torch.empty_like(a, dtype=torch.int64)
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
def EmptyLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
# ==============================================================================
class EmptyLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.pow(torch.empty_like(a, dtype=torch.float32), 0)
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
def EmptyLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
# ==============================================================================
class ContiguousModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -926,57 +788,6 @@ def DropoutModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64))
class Fill_TensorFloat64WithInt64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64))
class MeanModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1047,86 +858,6 @@ def NumelZeroRankModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10,[]))
class ZerosModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
def ZerosModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
def ZerosModuleInt3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
class BoolTensorReturnFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1181,6 +912,7 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
# ==============================================================================
class TModuleRank2(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1193,11 +925,11 @@ class TModuleRank2(torch.nn.Module):
def forward(self, lhs):
return torch.t(lhs)
@register_test_case(module_factory=lambda: TModuleRank2())
def TModuleRank2_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class TModuleRank1(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1210,11 +942,11 @@ class TModuleRank1(torch.nn.Module):
def forward(self, lhs):
return torch.t(lhs)
@register_test_case(module_factory=lambda: TModuleRank1())
def TModuleRank1_basic(module, tu: TestUtils):
module.forward(tu.rand(3))
class TModuleRank0(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1227,7 +959,6 @@ class TModuleRank0(torch.nn.Module):
def forward(self, lhs):
return torch.t(lhs)
@register_test_case(module_factory=lambda: TModuleRank0())
def TModuleRank0_basic(module, tu: TestUtils):
module.forward(torch.tensor(7, dtype=torch.float32))

@@ -0,0 +1,400 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ZerosModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
def ZerosModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleInt3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.int64)
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
def ZerosModuleInt3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFloat3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, 5, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class OnesModuleInt(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: OnesModuleInt())
def OnesModuleInt_basic(module, tu: TestUtils):
module.forward()
class OnesModuleFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: OnesModuleFloat())
def OnesModuleFloat_basic(module, tu: TestUtils):
module.forward()
class OnesModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return 0 * torch.empty((3, 4), dtype=torch.int64)
@register_test_case(module_factory=lambda: EmptyIntModule())
def EmptyModule_int(module, tu: TestUtils):
module.forward()
class EmptyFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0)
@register_test_case(module_factory=lambda: EmptyFloatModule())
def EmptyModule_float(module, tu: TestUtils):
module.forward()
class EmptyFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.pow(torch.empty((3, 4), dtype=torch.float32,
pin_memory=False), 0)
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
def EmptyModule_falsePinMemory(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return 0 * torch.empty_like(a, dtype=torch.int32)
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
def EmptyLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
class EmptyLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.pow(torch.empty_like(a, dtype=torch.float32), 0)
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
def EmptyLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.pow(torch.empty_like(a, dtype=torch.float64,
pin_memory=False), 0)
@register_test_case(module_factory=lambda: EmptyLikeFalsePinMemoryModule())
def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class ZerosLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.int32)
@register_test_case(module_factory=lambda: ZerosLikeIntModule())
def ZerosLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
class ZerosLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.float32)
@register_test_case(module_factory=lambda: ZerosLikeFloatModule())
def ZerosLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
@register_test_case(module_factory=lambda: ZerosLikeFalsePinMemoryModule())
def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class OnesLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.int32)
@register_test_case(module_factory=lambda: OnesLikeIntModule())
def OnesLikeModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)))
class OnesLikeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.float32)
@register_test_case(module_factory=lambda: OnesLikeFloatModule())
def OnesLikeModule_float(module, tu: TestUtils):
module.forward(tu.rand(4, 5))
class OnesLikeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
@register_test_case(module_factory=lambda: OnesLikeFalsePinMemoryModule())
def OnesLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64))
class Fill_TensorFloat64WithInt64(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3)
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).to(torch.float64))

@@ -46,6 +46,7 @@ from . import slice_like
from . import nll_loss
from . import index_select
from . import arange
from . import constant_alloc
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

@@ -2160,6 +2160,44 @@ def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
}
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dtype,
TorchOptionalIntType:$layout,
TorchOptionalDeviceType:$device,
TorchOptionalBoolType:$pin_memory,
TorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
}
def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$dtype,
TorchOptionalIntType:$layout,
TorchOptionalDeviceType:$device,
TorchOptionalBoolType:$pin_memory,
TorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dtype `,` $layout `,` $device `,` $pin_memory `,` $memory_format attr-dict `:` type($self) `,` type($dtype) `,` type($layout) `,` type($device) `,` type($pin_memory) `,` type($memory_format) `->` type($result)";
}
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
AllowsTypeRefinement,
HasValueSemantics

@@ -536,8 +536,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace
- // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
namespace {
+ // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -547,14 +547,6 @@
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.self());
- // TODO: Handle the case when `dtype` is NoneType.
- Type dtype = op.dtype().getType();
- if (dtype.isa<OptionalType>() || dtype.isa<Torch::NoneType>() ||
- dtype.isa<mlir::NoneType>())
- return rewriter.notifyMatchFailure(
- op, "unimplemented: None dtype is not supported");
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(),
op.pin_memory(), op.memory_format());
@@ -605,6 +597,28 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
};
} // namespace
namespace {
// Decompose constant tensor allocation like ops.
template <typename OpTy, int fillVal>
class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Allocate a memory block.
Value initTensor = rewriter.create<AtenEmptyLikeOp>(
loc, op.getType(), op.self(), op.dtype(), op.layout(), op.device(),
op.pin_memory(), op.memory_format());
Value constVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(fillVal));
// Initialize the allocated memory block with `fillVal`.
rewriter.replaceOpWithNewOp<AtenFill_ScalarOp>(op, initTensor.getType(),
initTensor, constVal);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -622,6 +636,12 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenEmptyLikeOp>(context);
target.addIllegalOp<AtenEmptyLikeOp>();
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
context);
target.addIllegalOp<AtenOnesLikeOp>();
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
context);
target.addIllegalOp<AtenZerosLikeOp>();
patterns.add<DecomposeAtenExpandOp>(context);
target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenSizeOp>(context);
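
At the PyTorch level, the decomposition registered above behaves roughly like
`aten.empty_like` followed by an in-place `aten.fill_` with the constant 0 or 1.
A minimal eager-mode sketch of that equivalence (illustrative only, not the
actual lowering; the helper names are made up):

import torch

def zeros_like_decomposed(x, dtype=None):
    # In spirit, what DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0> emits:
    # allocate with `empty_like`, then fill the block in place with 0.
    return torch.empty_like(x, dtype=dtype).fill_(0)

def ones_like_decomposed(x, dtype=None):
    # Same pattern with fill value 1, for `aten.ones_like`.
    return torch.empty_like(x, dtype=dtype).fill_(1)

x = torch.rand(2, 3)
assert torch.equal(zeros_like_decomposed(x, dtype=torch.int64),
                   torch.zeros_like(x, dtype=torch.int64))
assert torch.equal(ones_like_decomposed(x), torch.ones_like(x))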

@@ -234,15 +234,15 @@
visitOperation(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
- AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
+ AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenCeilOp,
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
- AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenEmptyLikeOp,
+ AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenReciprocalOp,
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
- AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
+ AtenAbsOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
@@ -395,6 +395,14 @@
} else if (auto emptyMemoryFormat = dyn_cast<AtenEmptyMemoryFormatOp>(op)) {
return visitConstantTensorAllocOp<AtenEmptyMemoryFormatOp>(
emptyMemoryFormat);
} else if (auto zerosLike = dyn_cast<AtenZerosLikeOp>(op)) {
return visitConstantTensorAllocLikeOp<AtenZerosLikeOp>(zerosLike,
operands);
} else if (auto onesLike = dyn_cast<AtenOnesLikeOp>(op)) {
return visitConstantTensorAllocLikeOp<AtenOnesLikeOp>(onesLike, operands);
} else if (auto emptyLike = dyn_cast<AtenEmptyLikeOp>(op)) {
return visitConstantTensorAllocLikeOp<AtenEmptyLikeOp>(emptyLike,
operands);
} else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
return visitAtenToDtypeOp(toDtype, operands);
} else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
@@ -566,6 +574,9 @@
ChangeResult visitScalarToTensorConversionOp(OpTy op);
ChangeResult visitAtenTensorOp(AtenTensorOp op);
template <typename OpTy> ChangeResult visitConstantTensorAllocOp(OpTy op);
template <typename OpTy>
ChangeResult visitConstantTensorAllocLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenToDtypeOp(AtenToDtypeOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1407,6 +1418,20 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
return getLatticeElement(op.getResult()).join(knowledge);
}
template <typename OpTy>
ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
if (input.hasSizes) {
knowledge.hasSizes = true;
knowledge.sizes = input.sizes;
}
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), input.dtype);
return getLatticeElement(op.getResult()).join(knowledge);
}
// Convert input tensor type to the given `dtype`.
ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
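
The refinement added here mirrors the eager semantics of the `*_like` ops: the
result keeps the input's sizes, and its dtype comes from the explicit `dtype`
argument when one is given, otherwise from the input. A quick eager-mode check
of that behavior (illustrative only):

import torch

x = torch.rand(2, 3)  # float32 input

# No explicit dtype: the result inherits both the sizes and the dtype of `x`.
z = torch.zeros_like(x)
assert z.shape == x.shape and z.dtype == torch.float32

# Explicit dtype: sizes still come from `x`, dtype from the argument.
y = torch.ones_like(x, dtype=torch.int32)
assert y.shape == x.shape and y.dtype == torch.int32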

@@ -565,6 +565,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)") emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")