[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (#3730)

Fixes https://github.com/iree-org/iree/issues/18562.

During canonicalization pass on `AtenUnflattenIntOp`, if the second dim
was statically equal to one, we would create an `AtenAddIntOp` to add
one to the dimension obtained from `op.getDim()`. This, when passed into
`Torch::unsqueezeTensor()`, would make it get interpreted as
non-constant, which would lead to MLIR failing an assertion when
`UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the
output of the `UnsqueezeOp` would consist of only dynamic dims.

This patch fixes this behavior, by extracting the integer value from the
dim if it was constant, and then emitting a `ConstantIntOp` from
(dim+1). This creates an output with static shape.
pull/3085/merge
Vinayak Dev 2024-09-24 22:15:18 +05:30 committed by GitHub
parent e4f2bdf0db
commit 67732883fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 19 additions and 3 deletions

View File

@ -2189,6 +2189,9 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
if (dim0 != 1 && dim1 != 1)
return failure();
Value unflattenDim = op.getDim();
int64_t dimAsInt;
bool dimWasConstant =
matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt));
Value self = op.getSelf();
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
// the runtime asserts below are introduced to catch malformed unflatten ops
@ -2217,9 +2220,22 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
}
if (dim1 == 1) {
// unsqueeze at dim + 1
Value dimPlusOne;
if (!dimWasConstant) {
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
Value dimPlusOne =
dimPlusOne =
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
} else {
// If dim was constant, creating an AtenAddIntOp will make
// Torch::unsqueezeTensor() interpret it as still not being a constant,
// and the resultant shape would consist of only dynamic dims. To fix
// this, emit a ConstantIntOp for (dim + 1) to avoid an assertion
// failure, when AtenUnsqueezeOp is in a later pass converted to
// ExpandShapeOp, which is bound to fail shape inference in MLIR if
// output dims are dynamic.
dimPlusOne = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1));
}
FailureOr<Value> maybeUnsqueeze =
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
if (failed(maybeUnsqueeze))