From 2ea2ab518b9be2c0fb57238148d06c3afd40f510 Mon Sep 17 00:00:00 2001
From: George Petterson
Date: Fri, 29 Oct 2021 03:15:05 -0400
Subject: [PATCH] Add contiguous

---
 e2e_testing/torchscript/basic.py          | 18 ++++++++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp       | 19 +++++++++++++++++++
 .../Transforms/MaximizeValueSemantics.cpp |  8 ++++++--
 3 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index a78f3fe82..bec79eac1 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -446,6 +446,7 @@ class BroadcastToModule(torch.nn.Module):
 def BroadcastToModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1, 1))
 
+
 class OnesModuleInt(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -475,3 +476,20 @@ class OnesModuleFloat(torch.nn.Module):
 @register_test_case(module_factory=lambda: OnesModuleFloat())
 def OnesModuleFloat_basic(module, tu: TestUtils):
     module.forward()
+
+class ContiguousModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x.contiguous()
+
+
+@register_test_case(module_factory=lambda: ContiguousModule())
+def ContiguousModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1))
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index c1c6f8eb1..fa569565f 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -2555,6 +2555,23 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenContiguousOp op, llvm::ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    AtenContiguousOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.self());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenOnesOp : public OpConversionPattern<AtenOnesOp> {
 public:
@@ -2685,6 +2702,8 @@ public:
     patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
     target.addIllegalOp<AtenOnesOp>();
     patterns.add<ConvertAtenOnesOp>(typeConverter, context);
+    target.addIllegalOp<AtenContiguousOp>();
+    patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 993a3dbff..1afba8ad3 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -90,8 +90,12 @@ public:
     if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
       copyToValueTensorOps.push_back(copyToValueTensor);
     } else if (isa<AtenSqueezeOp, AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
-                   AtenTransposeIntOp, TensorStaticInfoCastOp,
-                   AtenBroadcastToOp>(op)) {
+                   AtenTransposeIntOp, TensorStaticInfoCastOp,
+                   AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp>(op)) {
+      // AtenContiguousOp might return a view, so this is conservatively
+      // correct. We could potentially be more precise and identify the cases
+      // that it does not return a view and treat those as having value
+      // semantics.
       viewLikeOps.push_back(op);
       llvm::append_range(workList, op->getResult(0).getUsers());
     } else {
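
Note for reviewers (placed after the final hunk, where patch tools ignore
trailing text): the lowering is a pure identity rewrite, replacing
aten.contiguous with its input operand. That is sound because
linalg-on-tensors is value-semantic: contiguous() only changes memory
layout, never element values, and layout has no meaning once tensors are
values. A minimal plain-PyTorch sketch of the invariant the e2e test relies
on (not part of the patch; variable names are illustrative):

    import torch

    x = torch.rand(3, 4).t()   # transposed view; x.is_contiguous() is False
    y = x.contiguous()         # repacks into a fresh contiguous buffer
    assert torch.equal(x, y)   # element values are unchanged
    assert y.is_contiguous()

Since ContiguousModule_basic compares only values against eager PyTorch,
forwarding the operand passes the test.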