mirror of https://github.com/llvm/torch-mlir
Add a canonicalization pattern for `aten.unflatten.int` (#3656)
Addresses an issue in <https://github.com/llvm/torch-mlir/issues/3651> where some unflatten ops generated from onnx models weren't propagating static shape information. It may be necessary to add further optimizations for the more general case when some static information is present in the unflatten (or possibly reshape/view) op's `sizes` list, but not reflected in the output shape. These ops will only successfully infer shapes if the `sizes` list is gotten from a list of constant ints (with possibly one -1). A common example where this fails is when some of the `sizes` are determined from `aten.size.int` ops on dynamic tensors, and other `sizes` are known statically. This PR includes: - a canonicalizer for `aten.unflatten.int` which converts to `aten.unsqueeze` when it is expanding one dim to two, and one of the new dims is statically 1. - an improvement to the folder for `aten.__or__.bool` which does not rely on *both* operands being static.pull/3686/head
parent
2960538c6d
commit
295bf418a4
|
@ -9538,6 +9538,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDimOp : Torch_Op<"aten.dim", [
|
||||
|
|
|
@ -739,12 +739,16 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
|
|||
OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
|
||||
auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
|
||||
auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
|
||||
if (!valueA || !valueB) {
|
||||
if (!valueA && !valueB)
|
||||
return nullptr;
|
||||
if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1))
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
|
||||
if (valueA && valueA.getValue() == 0)
|
||||
return getB();
|
||||
if (valueB && valueB.getValue() == 0)
|
||||
return getA();
|
||||
// unreachable
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1),
|
||||
valueA.getValue() | valueB.getValue());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2162,6 +2166,85 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenUnflattenIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
// if there are only two sizes and one of them is statically 1, then convert
|
||||
// to an unqueeze.
|
||||
patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) {
|
||||
SmallVector<Value> sizeValues;
|
||||
if (!getListConstructElements(op.getSizes(), sizeValues))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"sizes must come from list construct");
|
||||
if (sizeValues.size() != 2)
|
||||
return failure();
|
||||
int64_t dim0, dim1;
|
||||
bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0));
|
||||
bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1));
|
||||
if (!dim0Constant && !dim1Constant)
|
||||
return failure();
|
||||
if (dim0 != 1 && dim1 != 1)
|
||||
return failure();
|
||||
Value unflattenDim = op.getDim();
|
||||
Value self = op.getSelf();
|
||||
Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
|
||||
// the runtime asserts below are introduced to catch malformed unflatten ops
|
||||
// possibly generated from onnx IR.
|
||||
Value unsqueeze;
|
||||
if (dim0 == 1) {
|
||||
// unsqueeze at dim
|
||||
FailureOr<Value> maybeUnsqueeze =
|
||||
Torch::unsqueezeTensor(rewriter, op, self, unflattenDim);
|
||||
if (failed(maybeUnsqueeze))
|
||||
return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
|
||||
unsqueeze = maybeUnsqueeze.value();
|
||||
// check if the remaining size value is either -1 or equal to original
|
||||
// size at dim
|
||||
Value selfSizeAtDim =
|
||||
rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
|
||||
Value isSameSize = rewriter.create<AtenEqIntOp>(
|
||||
op.getLoc(), selfSizeAtDim, sizeValues[1]);
|
||||
Value isMinusOne =
|
||||
rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[1]);
|
||||
Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
|
||||
op.getLoc(), isMinusOne, isSameSize);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
op.getLoc(), isMOneOrSameSize,
|
||||
rewriter.getStringAttr("unflatten sizes must be compatible"));
|
||||
}
|
||||
if (dim1 == 1) {
|
||||
// unsqueeze at dim + 1
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
|
||||
Value dimPlusOne =
|
||||
rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
|
||||
FailureOr<Value> maybeUnsqueeze =
|
||||
Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
|
||||
if (failed(maybeUnsqueeze))
|
||||
return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
|
||||
unsqueeze = maybeUnsqueeze.value();
|
||||
// check if the remaining size value is either -1 or equal to original
|
||||
// size at dim
|
||||
Value selfSizeAtDim =
|
||||
rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
|
||||
Value isSameSize = rewriter.create<AtenEqIntOp>(
|
||||
op.getLoc(), selfSizeAtDim, sizeValues[0]);
|
||||
Value isMinusOne =
|
||||
rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[0]);
|
||||
Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
|
||||
op.getLoc(), isMinusOne, isSameSize);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
op.getLoc(), isMOneOrSameSize,
|
||||
rewriter.getStringAttr("unflatten sizes must be compatible"));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
|
||||
unsqueeze);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSelectIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -757,7 +757,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True
|
||||
)
|
||||
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
|
||||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
|
||||
|
|
Loading…
Reference in New Issue