mirror of https://github.com/llvm/torch-mlir
Don't fold `aten.detach` if result isn't same type as input. (#2824)
We were seeing some assertion failures after some checks around folders were tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 . This PR essentially moves the logic that used to be applied at the LLVM level into the folder, which seems to be the suggested fix. I'm not sure if the IR that caused issues for us _should_ be valid? ``` %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor ``` A better fix might be to create a verifier ensuring the result of `aten.detach` has the same type as its operand. --------- Co-authored-by: aaron-stgeorge <aaron.stgeorge@getcruise.com>pull/2831/head
parent
db67bc555a
commit
4c557847bd
|
@ -1465,7 +1465,11 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
|
|||
// AtenDetachOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); }
|
||||
OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) {
|
||||
if (getSelf().getType() != getResult().getType())
|
||||
return {};
|
||||
return getSelf();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenNeIntOp
|
||||
|
|
|
@ -2146,3 +2146,10 @@ func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?
|
|||
%1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %1 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.detach$canonicalize
|
||||
// CHECK-NEXT: torch.aten.detach
|
||||
func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor {
|
||||
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue