torch-mlir/lib/Dialect/TMTensor
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled the mask and is_causal parameters for torch.aten.scaled_dot_product_attention,
with relevant comments and tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Ensured that mask and is_causal cannot both be supplied.
2024-09-09 15:51:41 -07:00
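To make the semantics of the new parameters concrete, here is a minimal pure-Python sketch of scaled dot-product attention with an optional mask and an is_causal flag. It follows the reference behaviour that PyTorch documents for `torch.nn.functional.scaled_dot_product_attention`: a boolean mask entry of True means "attend", a floating-point mask is added to the scores, is_causal applies a top-left-aligned lower-triangular mask, and supplying both is rejected. It is not the TMTensor lowering from this commit. The function name `sdpa_reference` is hypothetical, and dropout and a custom scale are deliberately omitted.

```python
import math


def sdpa_reference(q, k, v, attn_mask=None, is_causal=False):
    """Hypothetical reference attention on 2-D lists: q is L x E, k is S x E, v is S x Ev.

    attn_mask is None, an L x S list of bools (True = attend),
    or an L x S list of floats that are added to the scores.
    """
    # Mirrors the commit: a mask and is_causal cannot both be supplied.
    if attn_mask is not None and is_causal:
        raise ValueError("attn_mask and is_causal cannot both be set")
    L, S = len(q), len(k)
    scale = 1.0 / math.sqrt(len(q[0]))
    # scores[i][j] = (q_i . k_j) / sqrt(E)
    scores = [[sum(a * b for a, b in zip(q[i], k[j])) * scale for j in range(S)]
              for i in range(L)]
    if is_causal:
        # Assumed top-left-aligned lower triangle: query i sees keys 0..i.
        attn_mask = [[j <= i for j in range(S)] for i in range(L)]
    if attn_mask is not None:
        for i in range(L):
            for j in range(S):
                m = attn_mask[i][j]
                if isinstance(m, bool):
                    # Boolean mask: False positions are excluded from the softmax.
                    if not m:
                        scores[i][j] = -math.inf
                else:
                    # Float mask: added to the raw scores.
                    scores[i][j] += m
    out = []
    for i in range(L):
        # Numerically stable softmax over each row (fully masked rows are not handled specially).
        row_max = max(scores[i])
        weights = [math.exp(s - row_max) for s in scores[i]]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(weights[j] * v[j][c] for j in range(S)) for c in range(len(v[0]))])
    return out
```

With is_causal=True, the first query can only see the first key, so its output row equals the first row of v. The same result is obtained from an explicit lower-triangular boolean mask or an equivalent float mask that uses -inf.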
IR [Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690) 2024-09-09 15:51:41 -07:00
Transforms [NFC reformat] Run pre-commit on all files and format misc. 2024-04-27 14:08:09 -07:00
CMakeLists.txt Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00