torch-mlir/projects/pt1
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
2024-09-09 15:51:41 -07:00
..
e2e_testing [Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690) 2024-09-09 15:51:41 -07:00
examples [FxImporter] Add an e2e test example for FxImporter (#3331) 2024-05-14 00:45:19 +08:00
python [Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690) 2024-09-09 15:51:41 -07:00
test [NFC] Update black version (#3256) 2024-04-29 11:06:01 +08:00
tools Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
CMakeLists.txt Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00