diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index c4223ae55..bed228671 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2189,6 +2189,9 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
     if (dim0 != 1 && dim1 != 1)
       return failure();
     Value unflattenDim = op.getDim();
+    int64_t dimAsInt;
+    bool dimWasConstant =
+        matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt));
     Value self = op.getSelf();
     Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
     // the runtime asserts below are introduced to catch malformed unflatten ops
@@ -2217,9 +2220,22 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
     }
     if (dim1 == 1) {
       // unsqueeze at dim + 1
-      Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
-      Value dimPlusOne =
-          rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      Value dimPlusOne;
+      if (!dimWasConstant) {
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
+        dimPlusOne =
+            rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      } else {
+        // If dim was a constant, creating an AtenAddIntOp would make
+        // Torch::unsqueezeTensor() treat the result as non-constant, and the
+        // resultant shape would consist of only dynamic dims. To fix this,
+        // emit a ConstantIntOp for (dim + 1) directly. This avoids an
+        // assertion failure when AtenUnsqueezeOp is later converted to
+        // ExpandShapeOp, which is bound to fail shape inference in MLIR if
+        // the output dims are dynamic.
+        dimPlusOne = rewriter.create<Torch::ConstantIntOp>(
+            op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1));
+      }
       FailureOr<Value> maybeUnsqueeze =
           Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
       if (failed(maybeUnsqueeze))
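
For illustration only (not part of the patch): a hypothetical before/after of the IR this canonicalization handles, with made-up shapes and SSA names, and eliding the runtime-assert ops the pattern also emits. Unflattening dim 1 of a [3,4] tensor into sizes [4,1] is equivalent to unsqueezing at dim + 1:

    // Before: unflatten with a statically known dim and a trailing size of 1.
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %sizes = torch.prim.ListConstruct %int4, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %0 = torch.aten.unflatten.int %arg0, %int1, %sizes : !torch.vtensor<[3,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,4,1],f32>

With this patch, since dim is constant, dim + 1 is folded into a torch.constant.int rather than built as a torch.aten.add.int, so Torch::unsqueezeTensor() can compute a static result shape:

    // After: unsqueeze at the constant-folded dim + 1.
    %int2 = torch.constant.int 2
    %0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>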