mirror of https://github.com/llvm/torch-mlir
[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (#3730)
Fixes https://github.com/iree-org/iree/issues/18562. During canonicalization pass on `AtenUnflattenIntOp`, if the second dim was statically equal to one, we would create an `AtenAddIntOp` to add one to the dimension obtained from `op.getDim()`. This, when passed into `Torch::unsqueezeTensor()`, would make it get interpreted as non-constant, which would lead to MLIR failing an assertion when `UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the output of the `UnsqueezeOp` would consist of only dynamic dims. This patch fixes this behavior, by extracting the integer value from the dim if it was constant, and then emitting a `ConstantIntOp` from (dim+1). This creates an output with static shape.pull/3761/head
parent
e4f2bdf0db
commit
67732883fa
|
@ -2189,6 +2189,9 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
|||
if (dim0 != 1 && dim1 != 1)
|
||||
return failure();
|
||||
Value unflattenDim = op.getDim();
|
||||
int64_t dimAsInt;
|
||||
bool dimWasConstant =
|
||||
matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt));
|
||||
Value self = op.getSelf();
|
||||
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
|
||||
// the runtime asserts below are introduced to catch malformed unflatten ops
|
||||
|
@ -2217,9 +2220,22 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
|||
}
|
||||
if (dim1 == 1) {
|
||||
// unsqueeze at dim + 1
|
||||
Value dimPlusOne;
|
||||
if (!dimWasConstant) {
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
|
||||
Value dimPlusOne =
|
||||
dimPlusOne =
|
||||
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
|
||||
} else {
|
||||
// If dim was constant, creating an AtenAddIntOp will make
|
||||
// Torch::unsqueezeTensor() interpret it as still not being a constant,
|
||||
// and the resultant shape would consist of only dynamic dims. To fix
|
||||
// this, emit a ConstantIntOp for (dim + 1) to avoid an assertion
|
||||
// failure, when AtenUnsqueezeOp is in a later pass converted to
|
||||
// ExpandShapeOp, which is bound to fail shape inference in MLIR if
|
||||
// output dims are dynamic.
|
||||
dimPlusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1));
|
||||
}
|
||||
FailureOr<Value> maybeUnsqueeze =
|
||||
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
|
||||
if (failed(maybeUnsqueeze))
|
||||
|
|
Loading…
Reference in New Issue