Tri Dao
dfe29f5e2b
[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead
2023-09-18 15:29:06 -07:00
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
dan_the_3rd
c3f2a632aa
[ft_attention] Fix for seqlen=8136 ( #488 )
...
When seqlen=8136, `smem_sz = 48840`, and apparently starting the kernel returns an `invalid argument` CUDA error.
`48840 < 48 * 1024` but apparently it's still above the limit somehow..?
Tested on A100
2023-08-28 10:00:22 -07:00
Tri Dao
a157cc8c9b
[FT] Implement MQA/GQA
2023-07-22 23:47:01 -07:00
Tri Dao
2800efc71f
[FT] rotary_cos/sin should have batch_size dimension
2023-07-06 15:33:33 -07:00
Tri Dao
3a9bfd076f
[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)
2023-07-03 09:41:04 -07:00
Tri Dao
62e9814466
[Rotary] Make sure frequency calculation is in fp32
2023-07-02 16:39:39 -07:00
Tri Dao
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
2023-05-30 13:38:34 -07:00
Tri Dao
311d6606bf
[Gen] Fix FT kernel smem size, CG when batch size changed
2023-04-20 17:03:13 -07:00
Tri Dao
f5d0fbd468
[FT] Fix FT's single query attention for bf16 hdim128 rotary
2023-03-28 21:27:00 -07:00
Tri Dao
dc08ea1c33
Support H100 for other CUDA extensions
2023-03-15 16:59:27 -07:00
Tri Dao
f1e01c27ba
[Gen] Pass qkv_stride to ft_attention kernel for batched generation
2023-01-15 15:20:01 -08:00
Tri Dao
7c2191542a
[Gen] Make generation work with Tensor Parallel
2023-01-15 11:34:27 -08:00
Tri Dao
be1afaa276
[Gen, FT] Use fp32 accum for FMA
2023-01-03 22:09:22 -08:00
Tri Dao
f266fc7262
[Gen, FT] Use tlength instead of params.timestep for rotary
2023-01-03 17:46:55 -08:00
Tri Dao
a01d1213d7
[Gen] Add kernel from FasterTransformer for benchmarking
2023-01-03 17:37:43 -08:00