torch-mlir/include/torch-mlir-dialects/Dialect/TMTensor/IR
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
2024-09-09 15:51:41 -07:00
..
CMakeLists.txt Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
ScalarLoopOpInterface.h Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
ScalarLoopOpInterface.td Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
TMTensorBase.td Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
TMTensorDialect.h Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
TMTensorInterfaces.h Integrate LLVM at llvm/llvm-project@593f6fdcb4 (#3260) 2024-04-29 12:01:40 -07:00
TMTensorInterfaces.td [NFC] Change to *cast instead of .*cast variants (#3405) 2024-05-30 23:45:13 -07:00
TMTensorOpInterface.h Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
TMTensorOps.h Re-organize project structure to separate PyTorch dependencies from core project. (#2542) 2023-11-02 19:45:55 -07:00
TMTensorOps.td [Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690) 2024-09-09 15:51:41 -07:00