Tri Dao | b4cc152e97 | 2023-07-17 21:54:44 -07:00
  Make sure dout is contiguous

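A minimal sketch of the issue behind this commit and the related 2022-12-13 fix (88c4e5dbf6) below: autograd can hand the backward pass a non-contiguous `dout` (for example the gradient of a transposed view), and a fused kernel that assumes packed strides needs a `.contiguous()` guard first. The `AttnFunc` wrapper here is illustrative, not the repo's actual interface.

```python
import torch

class AttnFunc(torch.autograd.Function):
    """Illustrative autograd wrapper; names and interface are placeholders."""

    @staticmethod
    def forward(ctx, q, k, v):
        out = torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v   # stand-in for the fused kernel
        ctx.save_for_backward(q, k, v)
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v = ctx.saved_tensors
        # Autograd can hand back a non-contiguous dout (e.g. the grad of a transposed view);
        # a kernel that assumes packed strides needs a contiguous buffer first.
        dout = dout.contiguous()
        # A fused backward kernel would run here; this sketch falls back to autograd math.
        with torch.enable_grad():
            q2, k2, v2 = (t.detach().requires_grad_() for t in (q, k, v))
            out = torch.softmax(q2 @ k2.transpose(-2, -1), dim=-1) @ v2
            return torch.autograd.grad(out, (q2, k2, v2), dout)

q, k, v = (torch.randn(2, 4, 128, 64, requires_grad=True) for _ in range(3))
AttnFunc.apply(q, k, v).transpose(1, 2).sum().backward()   # the transpose makes dout non-contiguous
```
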
Tri Dao | 4f285b3547 | 2023-07-17 06:21:34 -07:00
  FlashAttention-2 release

Tri Dao | e8a0b4acdd | 2023-07-02 17:23:52 -07:00
  [Doc] Change total -> total_q

Kirthi Shankar Sivamani | 7d25a4ec4f | 2023-04-13 06:25:52 +00:00
  Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

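A hedged sketch of the pattern this entry points at, with placeholder names and an assumed packed (batch, seqlen, 3, nheads, headdim) layout rather than the real FlashAttnQKVPackedSplitFunc code: when dropout is disabled there is no RNG state to save, so the backward has to treat `rng_state` as optional instead of assuming it was stashed.

```python
import torch

class PackedAttnFunc(torch.autograd.Function):
    """Illustrative only: shows an optional rng_state, not the real kernel interface."""

    @staticmethod
    def forward(ctx, qkv, dropout_p):
        # Capture RNG state only when dropout is actually applied (the real code tracks
        # the CUDA Philox state; torch.get_rng_state() is just a stand-in here).
        ctx.rng_state = torch.get_rng_state() if dropout_p > 0 else None
        ctx.dropout_p = dropout_p
        out = qkv[:, :, 0].clone()    # placeholder for the fused kernel on packed input
        return out

    @staticmethod
    def backward(ctx, dout):
        # rng_state may be None (dropout_p == 0), so don't assume it was saved.
        if ctx.rng_state is not None:
            torch.set_rng_state(ctx.rng_state)   # replay the same dropout pattern
        dqkv = torch.zeros(dout.shape[0], dout.shape[1], 3, *dout.shape[2:],
                           dtype=dout.dtype, device=dout.device)
        dqkv[:, :, 0] = dout                      # placeholder gradient
        return dqkv, None

qkv = torch.randn(2, 128, 3, 4, 64, requires_grad=True)
PackedAttnFunc.apply(qkv, 0.0).sum().backward()   # works with no saved rng_state
```
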
Kirthi Shankar Sivamani | 31018c5fa0 | 2023-04-12 16:53:22 -07:00
  Support CUDA graph capture
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

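Roughly what CUDA graph support enables on the user side, sketched with a plain scaled-dot-product stand-in instead of the flash_attn call (requires a CUDA build of PyTorch): capture the forward once into a graph, then replay it while updating the static input buffers in place.

```python
import torch

def attn(q, k, v):
    # Stand-in for the fused attention call; any capture-safe op works here.
    return torch.softmax((q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5, dim=-1) @ v

q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Warm up on a side stream before capturing (recommended by PyTorch's CUDA graph docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    attn(q, k, v)
torch.cuda.current_stream().wait_stream(s)

# Capture once, then replay with the same static input/output buffers.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = attn(q, k, v)

q.copy_(torch.randn_like(q))   # update inputs in place between replays
g.replay()                     # re-runs the captured kernels; `out` is refreshed in place
print(out.shape)
```

Capture also rules out per-call host-side work (such as reading RNG state on the CPU), which is broadly why capturable kernels keep state like the dropout seed/offset on the device.
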
Kirthi Shankar Sivamani | b6aa059bbf | 2023-03-30 18:23:35 -07:00
  Add option for deterministic execution

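Some context on why a deterministic mode is useful, plus a hedged way to test the property without relying on whatever the actual flag is called: floating-point atomic adds in the backward accumulate in a nondeterministic order, so two identical backward passes can produce bitwise-different gradients unless the kernel takes a deterministic path.

```python
import torch

def grads_are_bitwise_equal(fn, *inputs):
    """Run fn's backward twice on identical inputs and compare gradients bitwise."""
    runs = []
    for _ in range(2):
        xs = [x.detach().clone().requires_grad_() for x in inputs]
        fn(*xs).sum().backward()
        runs.append([x.grad.clone() for x in xs])
    return all(torch.equal(a, b) for a, b in zip(*runs))

def attn(q, k, v):
    # Plain attention stand-in; a kernel that accumulates dq (or dk/dv) with
    # floating-point atomics may fail this check unless a deterministic mode is used.
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

q, k, v = (torch.randn(2, 4, 256, 64) for _ in range(3))
print(grads_are_bitwise_equal(attn, q, k, v))   # True for this pure-PyTorch reference
```
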
Tri Dao | 88c4e5dbf6 | 2022-12-13 13:58:17 -08:00
  Fix the case when dout is not contiguous

Tri Dao | 557781933d | 2022-11-05 16:26:17 -07:00
  Parallelize CUDA bwd along seqlen_k instead of seqlen_q
  This is faster since we only need to do atomic adds on dq, instead of atomic adds on both dk and dv.

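A reference-level illustration of the trade-off described above, with plain PyTorch loops standing in for independent thread blocks (softmax scaling omitted; this is not the kernel code): when the backward is tiled over blocks of keys, each block owns its slice of dk and dv outright, and only dq gathers contributions from every block, which is the one place atomics are needed.

```python
import torch

def attn_bwd_kblocks(q, k, v, dout, block=64):
    """Backward of out = softmax(q @ k^T) @ v (no scaling), tiled over blocks of seqlen_k."""
    p = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
    out = p @ v
    dp = dout @ v.transpose(-1, -2)
    delta = (dout * out).sum(-1, keepdim=True)        # rowsum(dout * out)
    ds = p * (dp - delta)                             # softmax backward

    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    for j in range(0, k.shape[-2], block):            # each iteration ~ one independent k-block
        sl = slice(j, j + block)
        dk[..., sl, :] = ds[..., :, sl].transpose(-1, -2) @ q     # local to this block
        dv[..., sl, :] = p[..., :, sl].transpose(-1, -2) @ dout   # local to this block
        dq += ds[..., :, sl] @ k[..., sl, :]          # shared across blocks -> atomic adds on GPU
    return dq, dk, dv

# Check against autograd.
q, k, v = (torch.randn(2, 4, 256, 64, dtype=torch.float64, requires_grad=True) for _ in range(3))
out = torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v
dout = torch.randn_like(out)
ref = torch.autograd.grad(out, (q, k, v), dout)
got = attn_bwd_kblocks(q.detach(), k.detach(), v.detach(), dout)
print(all(torch.allclose(a, b) for a, b in zip(got, ref)))   # True
```

Tiling over seqlen_q instead would flip this: dq would be local to each block, but both dk and dv would need the atomic accumulation path.
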
Tri Dao | 46fd2a20b2 | 2022-10-24 16:04:21 -07:00
  Support all head dims that are multiples of 8, up to 128

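The constraint in practice, as a hedged usage sketch rather than anything the library does internally: if a model's head dimension is not already a multiple of 8, one common workaround is zero-padding the last dimension before the call and slicing the output back to the original width.

```python
import torch
import torch.nn.functional as F

def pad_headdim(x, multiple=8, max_dim=128):
    """Zero-pad the last (head) dimension up to the next multiple of `multiple`."""
    d = x.shape[-1]
    assert d <= max_dim, f"head dim {d} exceeds the supported maximum of {max_dim}"
    pad = (-d) % multiple
    return F.pad(x, (0, pad)), d        # keep the original dim to slice the output back

q = torch.randn(2, 128, 4, 60)          # head dim 60 -> padded to 64
q_padded, d = pad_headdim(q)
print(q_padded.shape[-1], d)            # 64 60
```
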
Tri Dao | a5a8806d1a | 2022-10-23 11:35:15 -07:00
  Split bwd on the seqlen_q dimension

Tri Dao | a44f48df5a | 2022-10-21 12:04:27 -07:00
  Split fwd on the seqlen_q dimension

Tri Dao | 1aa6d7d9b6 | 2022-10-21 12:04:27 -07:00
  Rework dropout to decouple forward and backward
  They don't have to have the same block size, number of threads, etc.

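The property that lets the forward and backward dropout passes use different block sizes and thread counts: the mask is a pure function of a seed and the element's position, so it can be regenerated under any tiling. The integer hash below is a toy stand-in for the counter-based (Philox) RNG the kernels actually use.

```python
import torch

def dropout_keep_mask(seed, rows, cols, p, row0=0, col0=0):
    """Keep-mask whose decision for element (i, j) depends only on (seed, i, j),
    not on how the matrix is tiled, so any kernel geometry regenerates the same mask."""
    i = torch.arange(row0, row0 + rows).unsqueeze(1)
    j = torch.arange(col0, col0 + cols).unsqueeze(0)
    h = (i * 2654435761 + j * 40503 + seed) & 0xFFFFFFFF     # toy stateless hash
    h = ((h ^ (h >> 16)) * 2246822519) & 0xFFFFFFFF
    u = (h & 0xFFFFFF).double() / float(1 << 24)             # uniform in [0, 1)
    return u >= p

# The forward could generate the mask in one geometry and the backward in another:
full = dropout_keep_mask(seed=42, rows=8, cols=256, p=0.1)
tile = dropout_keep_mask(seed=42, rows=8, cols=64, p=0.1, col0=128)
print(torch.equal(full[:, 128:192], tile))                   # True: same positions, same mask
```
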
Tri Dao | 1b9facacc3 | 2022-10-14 03:33:41 -07:00
  Fix QKV interface to allocate output in Python

Tri Dao | 5badfb7848 | 2022-10-13 20:49:02 -07:00
  Implement attention kernel that splits the batch into two

Tri Dao | a5559a0e75 | 2022-07-03 17:52:05 -07:00
  Do P * dP (pointwise) in the bwd in fp32 instead of fp16

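A small numerical illustration of the precision choice, not the kernel code: in the softmax backward the pointwise product of P and dP feeds dS, and doing that product in fp16 loses noticeably more accuracy than doing it in fp32, even when the stored tensors are half precision.

```python
import torch

torch.manual_seed(0)
p = torch.softmax(4 * torch.randn(1024, 1024, dtype=torch.float64), dim=-1)   # attention probs
dp = torch.randn(1024, 1024, dtype=torch.float64)                             # grad wrt probs

ref = p * dp                                   # fp64 reference for the pointwise product
prod_fp16 = (p.half() * dp.half()).double()    # product done in fp16
prod_fp32 = (p.float() * dp.float()).double()  # product done in fp32

print((prod_fp16 - ref).abs().max().item())    # noticeably larger error (fp16 rounding)
print((prod_fp32 - ref).abs().max().item())    # orders of magnitude closer to the fp64 reference
```
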
Tri Dao | 6c3a8c65af | 2022-07-03 17:48:12 -07:00
  Implement cross attention

Tri Dao | 5a61cb7729 | 2022-06-01 18:50:26 -07:00
  Rename src -> flash_attn