Commit Graph

15 Commits

Author SHA1 Message Date
Driss Guessous
23e8fa5a26 Add the option for the macro and note (#893) 2024-03-27 19:12:11 -07:00
Tri Dao
b32efb1a4d Don't need to reduce row_sum during online softmax 2024-02-20 13:33:38 -08:00
Tri Dao
ed4959b2eb Change inline to __forceinline__, use __grid_constant__ param 2024-01-20 17:38:47 -08:00
Tri Dao
6f706eff96 Make Softmax an object 2024-01-19 16:09:31 -08:00
Tri Dao
df1418f9db Move softmax_rescale_o to softmax.h 2024-01-14 15:06:06 -08:00
Tri Dao
6777336a1c Move masking to a separate file (mask.h) 2024-01-14 12:43:47 -08:00
Tri Dao
1274ec3e7e Move dropout to a separate file (dropout.h) 2024-01-14 12:19:17 -08:00
Tri Dao
10dad61277 apply_dropout now takes tensor of rowcol layout 2024-01-14 01:03:23 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Tri Dao
083e8f525f Implement local attention (Co-authored-by: Timothee Lacroix <t@mistral.ai>) 2023-09-26 16:31:08 -07:00
Tri Dao
bb9beb3645 Remove some unused headers 2023-09-12 12:37:10 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
BoxiangW
e07aa036db Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. (#436) 2023-08-24 16:42:34 -07:00
Tri Dao
a4f148b6ab Fix masking of bwd when seqlen is not divisible by 128 2023-07-31 17:46:34 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00