Don't fold `aten.clone` if result isn't same type as input (#3347)

Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
pull/3349/head
Aaron St George 2024-05-15 09:07:45 -07:00 committed by GitHub
parent 5928f68e60
commit ba32b9cee7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 1 deletions

View File

@ -2581,7 +2581,8 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
// note: memory_format would be ignored
if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
if (getSelf().getType() == getResult().getType() &&
llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
// self should have value semantics
return getSelf();
}

View File

@ -3015,3 +3015,14 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
%result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
return %result0 : !torch.vtensor<[10,64,56,56],f32>
}
// -----
// CHECK-LABEL: @torch.aten.clone$no_fold(
func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
// CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
%none = torch.constant.none
%0 = torch.aten.clone %arg0, %none : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
return %1 : !torch.tensor
}