From 1d5b5a89e8e9ac75e4837b158fc4230eb7388b28 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Wed, 6 Apr 2022 13:28:21 +0000
Subject: [PATCH] [LINALG] Add torch.layout information

torch.layout information has been added.

The `Layout` enum from c10/core/Layout.h is mirrored in TorchUpstream.h,
and the explicit `layout` checks in the TorchToLinalg tensor constructor
lowerings are dropped, since all tensors have value semantics at that
point and the default layout can be assumed.
---
 .../Dialect/Torch/Utils/TorchUpstream.h       |  7 ++++++
 .../TorchToLinalg/TensorConstructors.cpp      | 22 ++++++++-----------
 .../test_suite/constant_alloc.py              | 16 ++++++++++++++
 3 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 91fff0078..124ebdf92 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -97,6 +97,13 @@ enum MemoryFormat {
   ChannelsLast3d
 };
 
+//===----------------------------------------------------------------------===//
+// Possible values for `layout` argument in PyTorch ops that support it.
+// Source:
+// https://github.com/pytorch/pytorch/blob/master/c10/core/Layout.h
+//===----------------------------------------------------------------------===//
+enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
+
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index a9e22dbbe..2075732a8 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -121,9 +121,8 @@ public:
 
     // TODO: Add support for layout, pin_memory features.
     // Only `none` layout is supported.
-    if (!op.layout().getType().template isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only default layout is supported");
+    // At this point all tensors should have value semantics, and hence the
+    // `layout` check can be ignored.
 
     // The pin_memory should be either `False` or `none`.
     bool pinMemory;
@@ -181,14 +180,13 @@ public:
   LogicalResult
   matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
-    // TODO: Add support for layout, pin_memory and memory_format features.
-    // Only `none` layout is supported.
-    if (!op.layout().getType().template isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only default layout is supported");
+    // TODO: Add support for pin_memory and memory_format features.
+    // At this point all tensors should have value semantics, and hence the
+    // `layout` check can be ignored.
 
     // The pin_memory should be either `False` or `none`.
     bool pinMemory;
@@ -258,11 +256,9 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
-    // TODO: Add support for layout, pin_memory features.
-    // Only `none` layout is supported.
-    if (!op.layout().getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: only default layout is supported");
+    // TODO: Add support for pin_memory features.
+    // At this point all tensors should have value semantics, and hence the
+    // `layout` check can be ignored.
 
     // The pin_memory should be either `False` or `none`.
     bool pinMemory;
diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 5e14d51c7..61444167a 100644
--- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1143,3 +1143,19 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
 @register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
 def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (2, 3)).to(torch.int32))
+
+class NewEmptyModuleLayoutIntDtype(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.new_empty(a, [3, 4], layout = 0).fill_(0)
+
+@register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype())
+def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (2, 3)).to(torch.int32))
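
For reference, the new NewEmptyModuleLayoutIntDtype test passes `layout` as a
plain integer; 0 corresponds to the first value of the `Layout` enum added in
TorchUpstream.h (Strided, i.e. torch.strided). A minimal eager-mode sketch of
the same call, assuming only the behavior the test above already relies on:

    import torch

    a = torch.randint(10, (2, 3)).to(torch.int32)
    # layout=0 maps to torch.strided (Layout::Strided above).
    t = torch.ops.aten.new_empty(a, [3, 4], layout=0).fill_(0)
    assert t.layout == torch.strided and t.dtype == torch.int32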