Add an info cast to `prims.squeeze` decomposition (#3844)

The onnx ingest sometimes has poorly propagated shape information. E.g.:

```mlir
...
    %9020 = torch.prims.squeeze %9010#1, %9019 : !torch.vtensor<[?,384,1],f32>, !torch.list<int> -> !torch.vtensor<[1,384],f32>
    return %9015, %9020 : !torch.vtensor<[1,384],f32>, !torch.vtensor<[1,384],f32>
  }
}
```

This occurred at the boundary of the onnx model
`migraphx_bert__bert-large-uncased`. Evidently, the output value tensor
info had more information than could be propagated forward. The
`PrimsSqueeze` lowering was returning a `!torch.vtensor<[?,384],f32>`
which was causing a type mismatch with the `func.return`.
pull/3848/head
zjgarvey 2024-11-01 12:10:47 -05:00 committed by GitHub
parent a82ba1c422
commit 3cfb7c8df6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -8958,7 +8958,8 @@ public:
} }
result = *squeezeTensorInfo; result = *squeezeTensorInfo;
} }
rewriter.replaceOp(op, result); rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
result);
return success(); return success();
} }
}; };