mirror of https://github.com/llvm/torch-mlir
[Torch] Fold no-op reshape (#3769)
This was preventing dynamic dims in an ONNX model from being reified (causing the generation of `tensor.cast`s and preventing fusion in iree): ```mlir %2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>] %7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int> %8 = torch.aten.reshape %2, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64> //... chain of foldable ops linking %2 to the `shape` operand of a `torch.aten.broadcast_to ... -> !torch.vtensor<[?,?],si64>` ```pull/3786/head
parent
2665ed343b
commit
8787970afe
|
@ -11455,6 +11455,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [
|
||||
|
|
|
@ -2261,6 +2261,19 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) {
|
||||
auto selfTy = dyn_cast<ValueTensorType>(getSelf().getType());
|
||||
auto opTy = dyn_cast<ValueTensorType>(getType());
|
||||
if (selfTy && selfTy == opTy && selfTy.hasSizes() &&
|
||||
selfTy.toBuiltinTensor().hasStaticShape())
|
||||
return getSelf();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSelectIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -856,7 +856,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
|
||||
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
|
||||
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue