Don't fold `aten.detach` if result isn't same type as input. (#2824)

We were seeing some assertion failures after some checks around folders
were tightened up in LLVM:
https://github.com/llvm/llvm-project/pull/75887 . This PR essentially
moves the logic that used to be applied at the LLVM level into the
folder, which seems to be the suggested fix.

I'm not sure if the IR that caused issues for us _should_ be valid?
```
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
```
A better fix might be to create a verifier ensuring the result of
`aten.detach` has the same type as its operand.

---------

Co-authored-by: aaron-stgeorge <aaron.stgeorge@getcruise.com>
pull/2831/head
Aaron St George 2024-01-30 09:45:51 -08:00 committed by GitHub
parent db67bc555a
commit 4c557847bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 1 deletions

View File

@ -1465,7 +1465,11 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
// AtenDetachOp // AtenDetachOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) {
if (getSelf().getType() != getResult().getType())
return {};
return getSelf();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenNeIntOp // AtenNeIntOp

View File

@ -2146,3 +2146,10 @@ func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?
%1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32> return %1 : !torch.vtensor<[?,?],f32>
} }
// CHECK-LABEL: func.func @torch.aten.detach$canonicalize
// CHECK-NEXT: torch.aten.detach
func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor {
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
return %1 : !torch.tensor
}