[MLIR][TORCH] Add support for clone op with channels last memory format

Fixes https://github.com/llvm/torch-mlir/issues/1829
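
For context: `memory_format` affects only the physical layout (strides) of the cloned tensor, never its shape or values. A quick eager-mode illustration (illustrative only, not part of this change):

```python
import torch

x = torch.rand(2, 3, 4, 5)             # NCHW, contiguous strides (60, 20, 5, 1)
y = torch.clone(x, memory_format=torch.channels_last)

print(y.shape)                          # torch.Size([2, 3, 4, 5]): unchanged
print(y.stride())                       # (60, 1, 15, 3): NHWC order in memory
print(y.is_contiguous(memory_format=torch.channels_last))  # True
```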

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>

4 changed files with 34 additions and 4 deletions


@@ -270,6 +270,7 @@ MHLO_PASS_SET = {
     "Convolution2DStaticModule_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ElementwiseCloneContiguousModule_basic",
+    "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
     "ElementwiseCloneModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
     "ReturnThreeTensorFloat32_basic",
@@ -425,6 +426,7 @@ MHLO_PASS_SET = {
 # and very few tests work yet.
 TOSA_PASS_SET = {
     "ElementwiseCloneContiguousModule_basic",
+    "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
     "ElementwiseCloneModule_basic",
     "ElementwiseUnaryModule_basic",
     "ElementwiseBinaryModule_basic",


@@ -177,8 +177,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
         (!matchPattern(clone.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)) ||
-         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
-      clone.emitError("unimplemented: only default memory format is supported");
+         (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+          memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
+      clone.emitError("unimplemented: only contiguous and channels last memory "
+                      "format is supported");
       return nullptr;
     }
     return payloadArgs[0];
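
A note on why this guard is a `matchPattern` on an integer: by the time the linalg lowering runs, the `memory_format` keyword has been folded to a constant `torch.int` (or `none` when unspecified), and since `clone` never changes element values, forwarding `payloadArgs[0]` is correct for channels last just as for contiguous. A sketch of how to observe the constant from Python (illustrative; assumes channels last maps to 2 in the upstream `MemoryFormat` enum):

```python
import torch

def f(x):
    return torch.clone(x, memory_format=torch.channels_last)

# The scripted graph carries memory_format as an integer constant
# (channels_last == 2 in c10::MemoryFormat), which is what
# m_TorchConstantInt(&memoryFormat) matches in the pattern above.
print(torch.jit.script(f).graph)
```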


@@ -4077,9 +4077,11 @@ public:
     if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
         (!matchPattern(op.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)) ||
-         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
+         (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+          memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
       return op.emitError(
-          "unimplemented: only default memory format is supported");
+          "unimplemented: only contiguous and channels last memory "
+          "format is supported");
     }
     auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
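
With the TOSA-side guard relaxed the same way, a module using channels last should now convert instead of hitting the `unimplemented` diagnostic. A minimal sketch, assuming the `torch_mlir.compile` entry point and the "tosa" output type as they existed around this revision:

```python
import torch
import torch_mlir

class CloneChannelsLast(torch.nn.Module):
    def forward(self, x):
        return torch.clone(x, memory_format=torch.channels_last)

# Assumption: torch_mlir.compile accepting output_type="tosa"; spelling may
# differ in other torch-mlir revisions.
compiled = torch_mlir.compile(CloneChannelsLast(), torch.rand(2, 3, 4, 5),
                              output_type="tosa")
print(compiled)
```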


@@ -1963,6 +1963,30 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):

 # ==============================================================================

+
+class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.clone(x, memory_format=torch.channels_last)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule())
+def ElementwiseCloneChannelsLastMemoryFormatModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4, 5))
+
+
+# ==============================================================================
+
 class LiftFreshCopyModule(torch.nn.Module):

     def __init__(self):
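
One reviewer-facing note on the new test: the e2e harness compares output values against eager PyTorch, and strides are not part of that comparison. Since `clone` never changes values, the identity-style payload above is enough for this case to pass. In eager terms:

```python
import torch

x = torch.rand(2, 3, 4, 5)
# Same values regardless of the requested layout; this is the property the
# e2e comparison relies on for the channels-last clone test.
assert torch.equal(x, torch.clone(x, memory_format=torch.channels_last))
```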