mirror of https://github.com/llvm/torch-mlir
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing some assertion failures after the checks around folders were tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887. This PR essentially moves the logic that used to be applied at the LLVM level into the folder, which seems to be the suggested fix.

branch: pull/3349/head, parent: 5928f68e60, commit: ba32b9cee7
@@ -2581,7 +2581,8 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
 OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
   // note: memory_format would be ignored
-  if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
+  if (getSelf().getType() == getResult().getType() &&
+      llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
     // self should have value semantics
     return getSelf();
   }
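For reference, here is roughly what the folder reads like after this change, reconstructed from the hunk above. Only the two-line type guard is from this patch; the trailing fall-through `return {};` is an assumption about the rest of the function, which lies outside the hunk:

```cpp
// Sketch of AtenCloneOp::fold after this patch (reconstructed from the diff;
// the fall-through at the end is assumed, not shown in the hunk).
OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
  // note: memory_format would be ignored
  // Only fold when the result type is exactly the input type: a folder that
  // returns a value of a different type is what triggers the tightened LLVM
  // assertions this PR works around.
  if (getSelf().getType() == getResult().getType() &&
      llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
    // self should have value semantics
    return getSelf();
  }
  return {}; // assumed: not foldable, leave the aten.clone op in place
}
```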
@@ -3015,3 +3015,14 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
   %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
   return %result0 : !torch.vtensor<[10,64,56,56],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.clone$no_fold(
+func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
+  // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
+  %none = torch.constant.none
+  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
+  %1 = torch.copy.to_tensor %0 : !torch.tensor
+  return %1 : !torch.tensor
+}
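The `// -----` separator indicates this is a split-input-file FileCheck test. As a usage sketch (the actual RUN line and flags of the test file are an assumption, since they are not part of this diff), such a test is typically driven by something like:

```mlir
// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s
```

The CHECK line in the new test then verifies that `torch.aten.clone` survives canonicalization when the result type (`!torch.vtensor`) differs from the input type (`!torch.vtensor<[1,2,50,4],f32>`), i.e. that the folder no longer fires in that case.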