From 88c216da13dbce32e9d122b5c962a584925c1400 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 24 Mar 2022 17:33:12 +0530
Subject: [PATCH] [MLIR][TORCH] Add support for same input and output shapes
 for view op

This commit adds support for the cases of the view op where the rank
and the shapes of the input and result are equal.

Signed-Off By: Vivek Khandelwal
---
 e2e_testing/torchscript/reshape_like.py       | 73 +++++++++++++++++++
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 47 ++++++------
 2 files changed, 99 insertions(+), 21 deletions(-)

diff --git a/e2e_testing/torchscript/reshape_like.py b/e2e_testing/torchscript/reshape_like.py
index e998b806e..108f40433 100644
--- a/e2e_testing/torchscript/reshape_like.py
+++ b/e2e_testing/torchscript/reshape_like.py
@@ -315,3 +315,76 @@ class ReshapeCollapseModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ReshapeCollapseModule())
 def ReshapeCollapseModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4))
+
+# ==============================================================================
+
+class ViewNoChange1dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(6)
+
+@register_test_case(module_factory=lambda: ViewNoChange1dModule())
+def ViewNoChange1dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6))
+
+
+class ViewNoChange2dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(5, 6)
+
+@register_test_case(module_factory=lambda: ViewNoChange2dModule())
+def ViewNoChange2dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 6))
+
+
+class ViewNoChange3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(4, 5, 6)
+
+@register_test_case(module_factory=lambda: ViewNoChange3dModule())
+def ViewNoChange3dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6))
+
+
+class ViewNoChangeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(4, 5, 6)
+
+@register_test_case(module_factory=lambda: ViewNoChangeStaticModule())
+def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6))
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index f1f1d7329..34f43f2c1 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -107,19 +107,6 @@ public:
     auto resultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     int64_t resultRank = resultType.getRank();
-    // Currently, we only handle the expanding OR collapsing cases, we do not
-    // handle expanding And collapsing happening at the same time or cases where
-    // it's neither collapsing nor expanding like view of [2,3] for 3x2 tensor.
-    // TODO: For the expanding And collapsing case, we will need to identify
-    // which dimensions are collapsing and which are expanding and do it in two
-    // steps.
-    // TODO: For neither collapsing nor expanding, we could find a intermediate
-    // shape to collapse and then expanded to the target shape. Like [2,3] =>
-    // [6] => [3, 2].
-    if (inputRank == resultRank)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: the view op is neither expanding nor collapsing");
-
     if (resultRank == 0)
       return rewriter.notifyMatchFailure(op,
                                          "result shape of rank 0 is invalid");
@@ -147,11 +134,30 @@ public:
       return rewriter.notifyMatchFailure(
           op, "desired size list length mismatches with the result type rank");
     }
-    SmallVector<Value> inputSizeTorchInt = getTensorSizes(rewriter, loc, input);
-    ArrayRef<Value> expandedShapeTorchInt =
-        llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt);
-    ArrayRef<Value> collapsedShapeTorchInt =
-        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt);
+
+    SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
+    ArrayRef<Value> expandedShapeInt =
+        llvm::makeArrayRef(isCollapse ? inputSize : outputSizeInt);
+    ArrayRef<Value> collapsedShapeInt =
+        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSize);
+
+    // Currently, we only handle the expanding or collapsing cases, plus the
+    // identity case where the rank and shape of the input and result are
+    // equal and the input itself is returned as the result. We do not handle
+    // expanding and collapsing happening at the same time, or cases that are
+    // neither collapsing nor expanding, like a view of [2,3] for a 3x2
+    // tensor.
+    // TODO: For the expanding and collapsing case, identify which dimensions
+    // are collapsing and which are expanding, and do it in two steps.
+    // TODO: For neither collapsing nor expanding, we could find an
+    // intermediate shape to collapse to and then expand to the target
+    // shape, like [2,3] => [6] => [3, 2].
+    if (inputRank == resultRank) {
+      for (unsigned i = 0; i < inputRank; i++)
+        checkDimEqualHelper(rewriter, loc, inputSize[i], outputSizeInt[i]);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+      return success();
+    }
 
     // Iterate through the view op size list to do the following:
     //
@@ -307,9 +313,8 @@ public:
           return rewriter.notifyMatchFailure(
               op, "desired size is not compatible with the input tensor size");
         }
-        checkDimEqualHelper(rewriter, loc,
-                            collapsedShapeTorchInt[collapsedDim],
-                            expandedShapeTorchInt[expandedDim]);
+        checkDimEqualHelper(rewriter, loc, collapsedShapeInt[collapsedDim],
+                            expandedShapeInt[expandedDim]);
         // To meet the second requirement from tensor.expand_shape
         // verification code.
         expandedShape[expandedDim] = kUnknownSize;
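
For reference, here is a minimal standalone PyTorch sketch of the identity
case this patch starts handling (the tensor name and shapes are illustrative,
not taken from the patch). A view whose target shape equals the input shape
aliases the input storage instead of moving data, which is why the lowering
above can emit runtime dim-equality checks and then replace the op with a
plain cast of the input:

import torch

t = torch.rand(4, 5, 6)
v = t.view(4, 5, 6)  # same rank and shape: an identity view

# The result aliases the input's storage; no data movement occurs.
assert v.data_ptr() == t.data_ptr()
assert v.shape == t.shape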