[LINALG] Add support for preserve memory format in aten_empty_like op.

The `preserve` memory format specifies that `If any of the input tensors is in channels_last format,
operator output should be in channels_last format`, and hence it can be
accepted as-is in the aten_empty_like op.
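
For context, the rule being quoted can be observed directly in eager PyTorch. This snippet is illustrative only, not part of the patch:

    import torch

    # A 4-D tensor laid out in channels_last (NHWC) order.
    t = torch.rand(2, 3, 4, 5).to(memory_format=torch.channels_last)

    # With preserve_format, empty_like keeps the input's layout ...
    out = torch.empty_like(t, memory_format=torch.preserve_format)
    print(out.is_contiguous(memory_format=torch.channels_last))  # True

    # ... whereas contiguous_format would force the default layout.
    out = torch.empty_like(t, memory_format=torch.contiguous_format)
    print(out.is_contiguous())  # True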
Prashant Kumar 2022-05-06 04:53:41 +00:00
parent 5a6210b35b
commit 2b1b0f6e19
2 changed files with 32 additions and 7 deletions


@@ -196,13 +196,19 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: pin_memory must be either None or false");
 
-    // Only `none` and `contiguous` memory_format is supported.
+    // Only `none`, `contiguous` and `preserve` memory_format is supported.
+    if (!op.memory_format().getType().isa<Torch::NoneType>()) {
       int64_t memoryFormat;
-    if (!op.memory_format().getType().isa<Torch::NoneType>() &&
-        (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) ||
-         memoryFormat != torch_upstream::MemoryFormat::Contiguous))
+      if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
         return rewriter.notifyMatchFailure(
-            op, "unimplemented: only default memory format is supported");
+            op, "unimplemented: the memory format should be specified in "
+                "an integer constant");
+      if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+          memoryFormat != torch_upstream::MemoryFormat::Preserve)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only none, contiguous and preserve "
+                "memory_format is supported");
+    }
 
     Location loc = op.getLoc();
     TypeConverter *typeConverter = this->getTypeConverter();
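
For readers less familiar with the rewriter idiom above, a rough Python restatement of the new check follows. This is a sketch, not code from the patch, and the integer values are an assumption based on c10::MemoryFormat, which torch_upstream::MemoryFormat mirrors:

    # Assumed enum values, mirroring c10::MemoryFormat.
    CONTIGUOUS = 0  # torch_upstream::MemoryFormat::Contiguous
    PRESERVE = 1    # torch_upstream::MemoryFormat::Preserve

    def check_memory_format(memory_format):
        # A `none` memory_format always matches.
        if memory_format is None:
            return
        # Otherwise the operand must be a compile-time integer constant ...
        if not isinstance(memory_format, int):
            raise NotImplementedError(
                "the memory format should be specified in an integer constant")
        # ... and only contiguous and preserve are handled.
        if memory_format not in (CONTIGUOUS, PRESERVE):
            raise NotImplementedError(
                "only none, contiguous and preserve memory_format is supported")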


@@ -331,6 +331,25 @@ def EmptyLikeModule_int(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 5)))
 
 
+class EmptyLikeMemoryFormatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.empty_like(a, memory_format=torch.preserve_format).fill_(0)
+
+
+@register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule())
+def EmptyLikeMemoryFormatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2, 1))
+
+
 class EmptyLikeFloatModule(torch.nn.Module):
     def __init__(self):
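
Note the trailing .fill_(0) in the test's forward: empty_like returns uninitialized memory, so the test fills the result with zeros to make the output deterministic and hence comparable against the eager reference. A quick eager-mode sanity check of the same computation, again illustrative only and not part of the patch:

    import torch

    # Mirrors what the test exercises, without the e2e harness.
    a = torch.rand(3, 5, 2, 1)
    out = torch.empty_like(a, memory_format=torch.preserve_format).fill_(0)
    assert out.shape == a.shape and out.dtype == a.dtype
    assert torch.all(out == 0)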