From ed9d8d1fb7eafb14c2d1704bd5f73599017b2128 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Wed, 1 Feb 2023 16:43:59 +0530
Subject: [PATCH] [MLIR][TORCH] Add support for clone op with channels last
 memory format

Fixes https://github.com/llvm/torch-mlir/issues/1829

Signed-Off By: Vivek Khandelwal
---
 e2e_testing/xfail_sets.py                      |  2 ++
 .../TorchToLinalg/Uncategorized.cpp            |  6 +++--
 lib/Conversion/TorchToTosa/TorchToTosa.cpp     |  6 +++--
 .../test_suite/elementwise.py                  | 24 +++++++++++++++++++
 4 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 52fe3467b..3de357f98 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -270,6 +270,7 @@ MHLO_PASS_SET = {
     "Convolution2DStaticModule_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ElementwiseCloneContiguousModule_basic",
+    "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
     "ElementwiseCloneModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
     "ReturnThreeTensorFloat32_basic",
@@ -425,6 +426,7 @@ MHLO_PASS_SET = {
 # and very few tests work yet.
 TOSA_PASS_SET = {
     "ElementwiseCloneContiguousModule_basic",
+    "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
     "ElementwiseCloneModule_basic",
     "ElementwiseUnaryModule_basic",
     "ElementwiseBinaryModule_basic",
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index d5ad4f50b..0cbb81d36 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -177,8 +177,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
         (!matchPattern(clone.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)) ||
-         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
-      clone.emitError("unimplemented: only default memory format is supported");
+         (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+          memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
+      clone.emitError("unimplemented: only contiguous and channels last memory "
+                      "format is supported");
       return nullptr;
     }
     return payloadArgs[0];
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index b285a8a0d..cbd26b17d 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4077,9 +4077,11 @@ public:
     if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
         (!matchPattern(op.getMemoryFormat(),
                        m_TorchConstantInt(&memoryFormat)) ||
-         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
+         (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+          memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
       return op.emitError(
-          "unimplemented: only default memory format is supported");
+          "unimplemented: only contiguous and channels last memory "
+          "format is supported");
     }
     auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 46411df66..47fecf0ec 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1963,6 +1963,30 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.clone(x, memory_format=torch.channels_last)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule())
+def ElementwiseCloneChannelsLastMemoryFormatModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4, 5))
+
+
+# ==============================================================================
+
+
 class LiftFreshCopyModule(torch.nn.Module):
 
     def __init__(self):