mirror of https://github.com/llvm/torch-mlir
[LINALG] Add support for preserve memory format in aten_empty_like op.
The `preserve` memory format specifies that "if any of the input tensors is in channels_last format, the operator output should be in channels_last format", and hence it can be supported as-is in the aten.empty_like op.
parent
5a6210b35b
commit
2b1b0f6e19
|
@ -196,13 +196,19 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: pin_memory must be either None or false");
|
||||
|
||||
// Only `none` and `contiguous` memory_format is supported.
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.memory_format().getType().isa<Torch::NoneType>()) {
|
||||
int64_t memoryFormat;
|
||||
if (!op.memory_format().getType().isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) ||
|
||||
memoryFormat != torch_upstream::MemoryFormat::Contiguous))
|
||||
if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only default memory format is supported");
|
||||
op, "unimplemented: the memory format should be specified in "
|
||||
"an integer constant");
|
||||
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::Preserve)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only none, contiguous and preserve "
|
||||
"memory_format is supported");
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = this->getTypeConverter();
|
||||
|
|
|
@ -331,6 +331,25 @@ def EmptyLikeModule_int(module, tu: TestUtils):
|
|||
module.forward(torch.randint(10, (3, 5)))
|
||||
|
||||
|
||||
class EmptyLikeMemoryFormatModule(torch.nn.Module):
    """E2E test module: `torch.empty_like` with `memory_format=torch.preserve_format`.

    The result of `empty_like` holds uninitialized values, so the output is
    immediately overwritten with zeros via `fill_(0)` to make the test
    deterministic; what is exercised is the memory-format handling, not the
    contents.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, a):
        # Allocate a tensor shaped like `a`, asking the backend to preserve
        # the input's memory format, then zero-fill it.
        result = torch.empty_like(a, memory_format=torch.preserve_format)
        return result.fill_(0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule())
def EmptyLikeMemoryFormatModule_basic(module, tu: TestUtils):
    # Drive the module with a random rank-4 input, matching the
    # ([-1, -1, -1, -1], torch.float32) annotation on `forward`.
    sample = tu.rand(3, 5, 2, 1)
    module.forward(sample)
|
||||
|
||||
|
||||
class EmptyLikeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue