Add a canonicalization pattern for `aten.unflatten.int` (#3656)

Addresses an issue in <https://github.com/llvm/torch-mlir/issues/3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is gotten from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.

This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
pull/3686/head
zjgarvey 2024-09-03 16:38:20 -07:00 committed by GitHub
parent 2960538c6d
commit 295bf418a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 92 additions and 6 deletions

View File

@ -9538,6 +9538,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
printDefaultTorchOp(printer, *this, 3, 1); printDefaultTorchOp(printer, *this, 3, 1);
} }
}]; }];
let hasCanonicalizer = 1;
} }
def Torch_AtenDimOp : Torch_Op<"aten.dim", [ def Torch_AtenDimOp : Torch_Op<"aten.dim", [

View File

@ -739,12 +739,16 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA()); auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB()); auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
if (!valueA || !valueB) { if (!valueA && !valueB)
return nullptr; return nullptr;
} if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1))
return IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
return IntegerAttr::get(IntegerType::get(getContext(), 1), if (valueA && valueA.getValue() == 0)
valueA.getValue() | valueB.getValue()); return getB();
if (valueB && valueB.getValue() == 0)
return getA();
// unreachable
return nullptr;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2162,6 +2166,85 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}); });
} }
//===----------------------------------------------------------------------===//
// AtenUnflattenIntOp
//===----------------------------------------------------------------------===//
void AtenUnflattenIntOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
// if there are only two sizes and one of them is statically 1, then convert
// to an unqueeze.
patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) {
SmallVector<Value> sizeValues;
if (!getListConstructElements(op.getSizes(), sizeValues))
return rewriter.notifyMatchFailure(op,
"sizes must come from list construct");
if (sizeValues.size() != 2)
return failure();
int64_t dim0, dim1;
bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0));
bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1));
if (!dim0Constant && !dim1Constant)
return failure();
if (dim0 != 1 && dim1 != 1)
return failure();
Value unflattenDim = op.getDim();
Value self = op.getSelf();
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
// the runtime asserts below are introduced to catch malformed unflatten ops
// possibly generated from onnx IR.
Value unsqueeze;
if (dim0 == 1) {
// unsqueeze at dim
FailureOr<Value> maybeUnsqueeze =
Torch::unsqueezeTensor(rewriter, op, self, unflattenDim);
if (failed(maybeUnsqueeze))
return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
unsqueeze = maybeUnsqueeze.value();
// check if the remaining size value is either -1 or equal to original
// size at dim
Value selfSizeAtDim =
rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
Value isSameSize = rewriter.create<AtenEqIntOp>(
op.getLoc(), selfSizeAtDim, sizeValues[1]);
Value isMinusOne =
rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[1]);
Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
op.getLoc(), isMinusOne, isSameSize);
rewriter.create<Torch::RuntimeAssertOp>(
op.getLoc(), isMOneOrSameSize,
rewriter.getStringAttr("unflatten sizes must be compatible"));
}
if (dim1 == 1) {
// unsqueeze at dim + 1
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
Value dimPlusOne =
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
FailureOr<Value> maybeUnsqueeze =
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
if (failed(maybeUnsqueeze))
return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
unsqueeze = maybeUnsqueeze.value();
// check if the remaining size value is either -1 or equal to original
// size at dim
Value selfSizeAtDim =
rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
Value isSameSize = rewriter.create<AtenEqIntOp>(
op.getLoc(), selfSizeAtDim, sizeValues[0]);
Value isMinusOne =
rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[0]);
Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
op.getLoc(), isMinusOne, isSameSize);
rewriter.create<Torch::RuntimeAssertOp>(
op.getLoc(), isMOneOrSameSize,
rewriter.getStringAttr("unflatten sizes must be compatible"));
}
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
unsqueeze);
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSelectIntOp // AtenSelectIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -757,7 +757,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)") emit(
"aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True
)
emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)") emit("aten::Bool.Tensor : (Tensor) -> (bool)")