mirror of https://github.com/llvm/torch-mlir
e86f56bc76
Enabled the mask and is_causal parameters for torch.aten.scaled_dot_product_attention, with relevant comments and tests. The added tests highlight the new capabilities introduced in this PR: attention with an F16 mask, attention with a Boolean mask, causal attention with the same Q, K, V shapes, and causal attention without matching Q, K, V shapes. The PR also ensures that mask and is_causal cannot both be supplied.
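The mask and is_causal semantics that the lowering needs to reproduce can be illustrated with a small pure-Python reference. This is a sketch modelled on PyTorch's documented behavior for torch.nn.functional.scaled_dot_product_attention, not torch-mlir's actual lowering. The function name sdpa_reference is hypothetical, and inputs are assumed to be unbatched 2-D lists (L x E queries, S x E keys, S x Ev values).

```python
import math


def sdpa_reference(q, k, v, attn_mask=None, is_causal=False):
    """Hypothetical reference for scaled dot-product attention on 2-D lists.

    attn_mask may be None, an L x S list of bools (True = may attend),
    or an L x S list of floats that is added to the scaled scores.
    Assumes every query row can attend to at least one key.
    """
    # Mirrors the PR's rule: a mask and is_causal cannot both be supplied.
    if attn_mask is not None and is_causal:
        raise ValueError("attn_mask and is_causal are mutually exclusive")
    num_q, num_k = len(q), len(k)
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(num_q):
        scores = []
        for j in range(num_k):
            s = scale * sum(a * b for a, b in zip(q[i], k[j]))
            if is_causal and j > i:
                s = -math.inf  # top-left aligned causal mask; works when L != S
            elif attn_mask is not None:
                m = attn_mask[i][j]
                if isinstance(m, bool):
                    if not m:
                        s = -math.inf  # Boolean mask: False blocks attention
                else:
                    s += m  # float mask: additive bias on the score
            scores.append(s)
        # Numerically stable softmax over the row of scores.
        mx = max(scores)
        weights = [math.exp(s - mx) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out
```

With these semantics, a lower-triangular Boolean mask and a float mask of 0.0 / -inf produce the same result as is_causal=True, which is why the two options overlap and supplying both is rejected.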
TorchConversionToMLProgram
TorchOnnxToTorch
TorchToArith
TorchToLinalg
TorchToSCF
TorchToStablehlo
TorchToTMTensor
TorchToTensor
TorchToTosa
Utils
CMakeLists.txt
PassDetail.h
Passes.cpp