mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for clone op with channels last memory format
Fixes https://github.com/llvm/torch-mlir/issues/1829 Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1851/head snapshot-20230202.737
parent
95361747c2
commit
ed9d8d1fb7
|
@ -270,6 +270,7 @@ MHLO_PASS_SET = {
|
|||
"Convolution2DStaticModule_basic",
|
||||
"ConvolutionModule2DTransposeStridedStatic_basic",
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseBinaryStaticShapeModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
|
@ -425,6 +426,7 @@ MHLO_PASS_SET = {
|
|||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseBinaryModule_basic",
|
||||
|
|
|
@ -177,8 +177,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
|
||||
(!matchPattern(clone.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)) ||
|
||||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
|
||||
clone.emitError("unimplemented: only default memory format is supported");
|
||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
|
||||
clone.emitError("unimplemented: only contiguous and channels last memory "
|
||||
"format is supported");
|
||||
return nullptr;
|
||||
}
|
||||
return payloadArgs[0];
|
||||
|
|
|
@ -4077,9 +4077,11 @@ public:
|
|||
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)) ||
|
||||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
|
||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
|
||||
return op.emitError(
|
||||
"unimplemented: only default memory format is supported");
|
||||
"unimplemented: only contiguous and channels last memory "
|
||||
"format is supported");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
|
|
@ -1963,6 +1963,30 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.clone(x, memory_format=torch.channels_last)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule())
|
||||
def ElementwiseCloneChannelsLastMemoryFormatModule_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LiftFreshCopyModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue