mirror of https://github.com/llvm/torch-mlir
[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (#3730)
Fixes https://github.com/iree-org/iree/issues/18562. During canonicalization pass on `AtenUnflattenIntOp`, if the second dim was statically equal to one, we would create an `AtenAddIntOp` to add one to the dimension obtained from `op.getDim()`. This, when passed into `Torch::unsqueezeTensor()`, would make it get interpreted as non-constant, which would lead to MLIR failing an assertion when `UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the output of the `UnsqueezeOp` would consist of only dynamic dims. This patch fixes this behavior, by extracting the integer value from the dim if it was constant, and then emitting a `ConstantIntOp` from (dim+1). This creates an output with static shape.pull/3761/head
parent
e4f2bdf0db
commit
67732883fa
|
@ -2189,6 +2189,9 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
||||||
if (dim0 != 1 && dim1 != 1)
|
if (dim0 != 1 && dim1 != 1)
|
||||||
return failure();
|
return failure();
|
||||||
Value unflattenDim = op.getDim();
|
Value unflattenDim = op.getDim();
|
||||||
|
int64_t dimAsInt;
|
||||||
|
bool dimWasConstant =
|
||||||
|
matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt));
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
|
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
|
||||||
// the runtime asserts below are introduced to catch malformed unflatten ops
|
// the runtime asserts below are introduced to catch malformed unflatten ops
|
||||||
|
@ -2217,9 +2220,22 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
||||||
}
|
}
|
||||||
if (dim1 == 1) {
|
if (dim1 == 1) {
|
||||||
// unsqueeze at dim + 1
|
// unsqueeze at dim + 1
|
||||||
|
Value dimPlusOne;
|
||||||
|
if (!dimWasConstant) {
|
||||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
|
||||||
Value dimPlusOne =
|
dimPlusOne =
|
||||||
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
|
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
|
||||||
|
} else {
|
||||||
|
// If dim was constant, creating an AtenAddIntOp will make
|
||||||
|
// Torch::unsqueezeTensor() interpret it as still not being a constant,
|
||||||
|
// and the resultant shape would consist of only dynamic dims. To fix
|
||||||
|
// this, emit a ConstantIntOp for (dim + 1) to avoid an assertion
|
||||||
|
// failure, when AtenUnsqueezeOp is in a later pass converted to
|
||||||
|
// ExpandShapeOp, which is bound to fail shape inference in MLIR if
|
||||||
|
// output dims are dynamic.
|
||||||
|
dimPlusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1));
|
||||||
|
}
|
||||||
FailureOr<Value> maybeUnsqueeze =
|
FailureOr<Value> maybeUnsqueeze =
|
||||||
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
|
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
|
||||||
if (failed(maybeUnsqueeze))
|
if (failed(maybeUnsqueeze))
|
||||||
|
|
Loading…
Reference in New Issue