[LINALG] Add torch.layout information

torch.layout information has been added.
pull/693/head
Prashant Kumar 2022-04-06 13:28:21 +00:00
parent eaf34fa02b
commit 1d5b5a89e8
3 changed files with 32 additions and 13 deletions

View File

@ -97,6 +97,13 @@ enum MemoryFormat {
ChannelsLast3d
};
//===----------------------------------------------------------------------===//
// Possible values for `layout` argument in PyTorch ops that support it.
// Source:
// https://github.com/pytorch/pytorch/blob/master/c10/core/Layout.h
//===----------------------------------------------------------------------===//
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
} // namespace torch_upstream
} // namespace torch
} // namespace mlir

View File

@ -121,9 +121,8 @@ public:
// TODO: Add support for layout, pin_memory features.
// Only `none` layout is supported.
if (!op.layout().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored.
// The pin_memory should be either `False` or `none`.
bool pinMemory;
@ -181,14 +180,13 @@ public:
LogicalResult
matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Add support for layout, pin_memory and memory_format features.
// Only `none` layout is supported.
if (!op.layout().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// TODO: Add support pin_memory and memory_format features.
// At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored.
// The pin_memory should be either `False` or `none`.
bool pinMemory;
@ -258,11 +256,9 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Add support for layout, pin_memory features.
// Only `none` layout is supported.
if (!op.layout().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// TODO: Add support for pin_memory features.
// At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored.
// The pin_memory should be either `False` or `none`.
bool pinMemory;

View File

@ -1143,3 +1143,19 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3)).to(torch.int32))
class NewEmptyModuleLayoutIntDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.new_empty(a, [3, 4], layout = 0).fill_(0)
@register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype())
def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3)).to(torch.int32))