From 67732883fa0d5e50cd449a3c5a6e80d60337d099 Mon Sep 17 00:00:00 2001
From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com>
Date: Tue, 24 Sep 2024 22:15:18 +0530
Subject: [PATCH] [torch] Fix unsqueezed output shape in canonicalization of
 AtenUnflattenIntOp (#3730)

Fixes https://github.com/iree-org/iree/issues/18562.

During canonicalization pass on `AtenUnflattenIntOp`, if the second dim
was statically equal to one, we would create an `AtenAddIntOp` to add
one to the dimension obtained from `op.getDim()`. This, when passed into
`Torch::unsqueezeTensor()`, would make it get interpreted as
non-constant, which would lead to MLIR failing an assertion when
`UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the
output of the `UnsqueezeOp` would consist of only dynamic dims.

This patch fixes this behavior, by extracting the integer value from the
dim if it was constant, and then emitting a `ConstantIntOp` from
(dim+1). This creates an output with static shape.
---
 lib/Dialect/Torch/IR/TorchOps.cpp | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index c4223ae55..bed228671 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2189,6 +2189,9 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
     if (dim0 != 1 && dim1 != 1)
       return failure();
     Value unflattenDim = op.getDim();
+    int64_t dimAsInt;
+    bool dimWasConstant =
+        matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt));
     Value self = op.getSelf();
     Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
     // the runtime asserts below are introduced to catch malformed unflatten ops
@@ -2217,9 +2220,22 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
     }
     if (dim1 == 1) {
       // unsqueeze at dim + 1
-      Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
-      Value dimPlusOne =
-          rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      Value dimPlusOne;
+      if (!dimWasConstant) {
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
+        dimPlusOne =
+            rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      } else {
+        // If dim was constant, creating an AtenAddIntOp will make
+        // Torch::unsqueezeTensor() interpret it as still not being a constant,
+        // and the resultant shape would consist of only dynamic dims. To fix
+        // this, emit a ConstantIntOp for (dim + 1) to avoid an assertion
+        // failure, when AtenUnsqueezeOp is in a later pass converted to
+        // ExpandShapeOp, which is bound to fail shape inference in MLIR if
+        // output dims are dynamic.
+        dimPlusOne = rewriter.create<Torch::ConstantIntOp>(
+            op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1));
+      }
       FailureOr<Value> maybeUnsqueeze =
           Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
       if (failed(maybeUnsqueeze))