mirror of https://github.com/llvm/torch-mlir
[LINALG] Add torch.layout information

torch.layout information has been added.

branch pull/693/head
parent eaf34fa02b
commit 1d5b5a89e8
@@ -97,6 +97,13 @@ enum MemoryFormat {
  ChannelsLast3d
};

//===----------------------------------------------------------------------===//
// Possible values for `layout` argument in PyTorch ops that support it.
// Source:
// https://github.com/pytorch/pytorch/blob/master/c10/core/Layout.h
//===----------------------------------------------------------------------===//
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };

} // namespace torch_upstream
} // namespace torch
} // namespace mlir

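The new enum mirrors c10::Layout from the PyTorch header cited above, so the enumerator order fixes the integer a constant `layout` operand encodes: Strided is 0 (the default, `torch.strided`), Sparse is 1, and so on. Below is a minimal standalone sketch of that mapping; it is not part of the diff, and the helper name `layoutName` is made up for illustration.

#include <cstdint>
#include <iostream>

namespace torch_upstream {
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
} // namespace torch_upstream

// Map a constant `layout` integer to the PyTorch layout it denotes.
static const char *layoutName(int64_t layout) {
  switch (layout) {
  case torch_upstream::Layout::Strided:
    return "torch.strided"; // 0, the default layout
  case torch_upstream::Layout::Sparse:
    return "torch.sparse_coo"; // 1
  case torch_upstream::Layout::SparseCsr:
    return "torch.sparse_csr"; // 2
  case torch_upstream::Layout::Mkldnn:
    return "torch._mkldnn"; // 3
  default:
    return "unknown";
  }
}

int main() {
  // A constant `layout = 0` operand (as in the e2e test below) denotes the
  // strided layout.
  std::cout << layoutName(0) << "\n"; // prints "torch.strided"
  return 0;
}
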
@@ -121,9 +121,8 @@ public:

    // TODO: Add support for layout, pin_memory features.
    // Only `none` layout is supported.
    if (!op.layout().getType().template isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only default layout is supported");
    // At this point all tensors should have value semantics, and hence the
    // `layout` check can be ignored.

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;

@@ -181,14 +180,13 @@ public:
  LogicalResult
  matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Add support for layout, pin_memory and memory_format features.
    // Only `none` layout is supported.
    if (!op.layout().getType().template isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only default layout is supported");
    // TODO: Add support for pin_memory and memory_format features.
    // At this point all tensors should have value semantics, and hence the
    // `layout` check can be ignored.

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;

@@ -258,11 +256,9 @@ public:
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Add support for layout, pin_memory features.
    // Only `none` layout is supported.
    if (!op.layout().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only default layout is supported");
    // TODO: Add support for pin_memory features.
    // At this point all tensors should have value semantics, and hence the
    // `layout` check can be ignored.

    // The pin_memory should be either `False` or `none`.
    bool pinMemory;

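Taken together, the patterns above gate the lowering on the same two operands. Below is a rough standalone model of that guard, assuming (as the new comments suggest) that a missing layout and the default strided layout are treated alike, and that pin_memory must be `none` or the constant False; none of the names or messages below are taken from the patch itself.

#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };

// Result of the hypothetical guard: whether the pattern would match, and a
// reason string in the spirit of rewriter.notifyMatchFailure otherwise.
struct MatchResult {
  bool matched;
  std::string reason;
};

// `layout` / `pinMemory` stand in for constant-folded operand values; an
// empty optional plays the role of a `none` operand.
static MatchResult checkEmptyOpOperands(std::optional<int64_t> layout,
                                        std::optional<bool> pinMemory) {
  if (layout && *layout != Layout::Strided)
    return {false, "unimplemented: only default layout is supported"};
  if (pinMemory && *pinMemory)
    return {false, "unimplemented: pinned memory is not supported"};
  return {true, ""};
}

int main() {
  assert(checkEmptyOpOperands(std::nullopt, std::nullopt).matched);
  assert(checkEmptyOpOperands(Layout::Strided, false).matched);
  assert(!checkEmptyOpOperands(Layout::Sparse, std::nullopt).matched);
  assert(!checkEmptyOpOperands(std::nullopt, true).matched);
  return 0;
}
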
@@ -1143,3 +1143,19 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):

@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (2, 3)).to(torch.int32))


class NewEmptyModuleLayoutIntDtype(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.ops.aten.new_empty(a, [3, 4], layout=0).fill_(0)


@register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype())
def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (2, 3)).to(torch.int32))