From 4c557847bdd44bdfff90fa6d56089529ef065843 Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Tue, 30 Jan 2024 09:45:51 -0800 Subject: [PATCH] Don't fold `aten.detach` if result isn't same type as input. (#2824) We were seeing some assertion failures after some checks around folders were tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 . This PR essentially moves the logic that used to be applied at the LLVM level into the folder, which seems to be the suggested fix. I'm not sure if the IR that caused issues for us _should_ be valid? ``` %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor ``` A better fix might be to create a verifier ensuring the result of `aten.detach` has the same type as its operand. --------- Co-authored-by: aaron-stgeorge --- lib/Dialect/Torch/IR/TorchOps.cpp | 6 +++++- test/Dialect/Torch/canonicalize.mlir | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4af9bcfc1..4aacd8d76 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1465,7 +1465,11 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, // AtenDetachOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } +OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { + if (getSelf().getType() != getResult().getType()) + return {}; + return getSelf(); +} //===----------------------------------------------------------------------===// // AtenNeIntOp diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 9172d4642..3cf82d9ed 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2146,3 +2146,10 @@ func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,? %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> return %1 : !torch.vtensor<[?,?],f32> } + +// CHECK-LABEL: func.func @torch.aten.detach$canonicalize +// CHECK-NEXT: torch.aten.detach +func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor { + %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor + return %1 : !torch.tensor +}