torch-mlir/lib/Conversion
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled the mask and is_causal parameters for torch.aten.scaled_dot_product_attention, with relevant comments and tests.

The added tests exercise the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention with different Q K V shapes

Ensured that mask and is_causal cannot both be specified.
2024-09-09 15:51:41 -07:00
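As an illustration of the semantics described in the commit message above, here is a minimal pure-Python sketch of scaled dot-product attention with an optional float or boolean mask and an optional causal flag. It is not the TMTensor lowering from this PR. It follows the behavior of the reference implementation in PyTorch's `scaled_dot_product_attention` documentation, under these assumptions: a single 2-D head (lists of rows), a top-left-aligned causal mask (query `i` sees keys `j <= i`, which also covers the different-Q/K-length case), and at least one unmasked key per query row. The helper names `attention_bias` and `sdpa` are hypothetical.

```python
import math

NEG_INF = float("-inf")


def attention_bias(L, S, attn_mask=None, is_causal=False):
    """Build an L x S additive bias (hypothetical helper)."""
    # Mirrors the PR's rule that mask and is_causal cannot both be given.
    if is_causal and attn_mask is not None:
        raise ValueError("attn_mask and is_causal are mutually exclusive")
    bias = [[0.0] * S for _ in range(L)]
    if is_causal:
        # Top-left-aligned lower triangle: key j is visible to query i iff j <= i.
        for i in range(L):
            for j in range(S):
                if j > i:
                    bias[i][j] = NEG_INF
    if attn_mask is not None:
        # Treat the whole mask as boolean or float, like a tensor dtype.
        is_bool = isinstance(attn_mask[0][0], bool)
        for i in range(L):
            for j in range(S):
                if is_bool:
                    # Boolean mask: True keeps a position, False masks it out.
                    if not attn_mask[i][j]:
                        bias[i][j] = NEG_INF
                else:
                    # Float mask: added directly to the attention scores.
                    bias[i][j] += attn_mask[i][j]
    return bias


def softmax(row):
    # Assumes the row has at least one finite entry.
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def sdpa(q, k, v, attn_mask=None, is_causal=False):
    """Hypothetical single-head attention: q is L x E, k is S x E, v is S x Ev."""
    L, S, E = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(E)
    bias = attention_bias(L, S, attn_mask, is_causal)
    out = []
    for i in range(L):
        scores = [
            sum(q[i][e] * k[j][e] for e in range(E)) * scale + bias[i][j]
            for j in range(S)
        ]
        w = softmax(scores)
        out.append([sum(w[j] * v[j][d] for j in range(S)) for d in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(sdpa(q, k, v, is_causal=True)[0])  # first query sees only key 0
```

With equal Q and K lengths, `is_causal=True` produces the same bias as the boolean mask `[[True, False], [True, True]]` and as the float mask `[[0.0, -inf], [0.0, 0.0]]`. That is why the sketch, like the PR, rejects specifying both options instead of trying to combine them.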
Name | Last commit | Date
TorchConversionToMLProgram | [NFC] Remove unused header files (#3386) | 2024-05-30 14:30:36 +08:00
TorchOnnxToTorch | Fix onnx.Gather lowering with dynamic shapes (#3675) | 2024-08-29 17:02:16 -07:00
TorchToArith | [TorchToArith] Add a lowering for `torch.add.float_int` (#3594) | 2024-08-07 11:55:27 -05:00
TorchToLinalg | [TorchToLinalg] Use `linalg.transpose` instead of `generic` when lowering `aten.T` (#3660) | 2024-09-07 08:09:10 +02:00
TorchToSCF | [NFC] Change to *cast instead of .*cast variants (#3405) | 2024-05-30 23:45:13 -07:00
TorchToStablehlo | [Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592) | 2024-08-15 20:06:29 +08:00
TorchToTMTensor | [Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690) | 2024-09-09 15:51:41 -07:00
TorchToTensor | [NFC] Remove unused header files (#3386) | 2024-05-30 14:30:36 +08:00
TorchToTosa | [TOSA] Add Torch to Tosa Legalization for torch.tril (#3678) | 2024-09-05 11:27:29 -07:00
Utils | Added support for integer to complex conversion (#3604) | 2024-08-14 18:13:00 +05:30
CMakeLists.txt | [torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055) | 2024-03-26 12:41:40 -07:00
PassDetail.h | Minor fixes for `ConvertTorchConversionToMLProgram`. (#1991) | 2023-04-04 09:09:58 -07:00
Passes.cpp | Clang format refresh (#2812) | 2024-01-29 12:59:33 -05:00