[Torch] Fold no-op reshape (#3769)

This was preventing dynamic dims in an ONNX model from being reified (causing the generation of `tensor.cast`s and preventing fusion in iree):

```mlir
%2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>]
%7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%8 = torch.aten.reshape %2, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
//... chain of foldable ops linking %2 to the `shape` operand of a `torch.aten.broadcast_to ... -> !torch.vtensor<[?,?],si64>`
```
pull/3786/head
Ian Wood 2024-10-10 18:54:27 -07:00 committed by GitHub
parent 2665ed343b
commit 8787970afe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 1 deletions

View File

@ -11455,6 +11455,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [

View File

@ -2261,6 +2261,19 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// AtenReshapeOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
auto selfTy = dyn_cast<ValueTensorType>(getSelf().getType());
auto opTy = dyn_cast<ValueTensorType>(getType());
if (selfTy && selfTy == opTy && selfTy.hasSizes() &&
selfTy.toBuiltinTensor().hasStaticShape())
return getSelf();
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenSelectIntOp
//===----------------------------------------------------------------------===//

View File

@ -856,7 +856,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")