Tri Dao
|
557781933d
|
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
|
2022-11-05 16:26:17 -07:00 |
|
Tri Dao
|
46fd2a20b2
|
Support all head dims that are multiples of 8, up to 128
|
2022-10-24 16:04:21 -07:00 |
|
Tri Dao
|
9e92a1f2d2
|
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
|
2022-10-23 16:22:43 -07:00 |
|
Tri Dao
|
a5a8806d1a
|
Split bwd on the seqlen_q dimension
|
2022-10-23 11:35:15 -07:00 |
|
Tri Dao
|
1aa6d7d9b6
|
Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
|
2022-10-21 12:04:27 -07:00 |
|
Tri Dao
|
de19de7ab1
|
Implement for bf16
|
2022-07-09 23:31:56 -07:00 |
|
Tri Dao
|
6a77a6da10
|
Refactor gemm_cl to template on either __half or __nv_bfloat16
|
2022-07-09 23:18:26 -07:00 |
|
Tri Dao
|
e518a4b327
|
Refactor to template on __half, implement bf16 util functions
|
2022-07-09 23:18:26 -07:00 |
|
Tri Dao
|
6c3a8c65af
|
Implement cross attention
|
2022-07-03 17:48:12 -07:00 |
|
Tri Dao
|
f66603cb6f
|
Support batch size > 64K by swapping grid.x and grid.y
|
2022-06-29 23:16:24 -07:00 |
|
Tri Dao
|
eeca63a72a
|
Bug fix: wrong smem_o write pointer for d=16
|
2022-06-25 15:18:33 -07:00 |
|
Tri Dao
|
5d07483bbc
|
Refactor Gmem code to store q, k, v pointers separately
|
2022-06-12 16:37:32 -07:00 |
|
Tri Dao
|
d3e6440958
|
Implement bwd for head dim 128
|
2022-06-11 17:52:36 -07:00 |
|
Tri Dao
|
0d854692c6
|
Implement fwd for head dim 128
|
2022-06-11 17:52:36 -07:00 |
|
Tri Dao
|
b17c6fe235
|
Reduce smem usage for Q and dO in the backward pass
From 4KB per buffer to 2KB per buffer. This saves 8KB of smem, since Q and dO
each have 2 buffers.
|
2022-06-03 16:59:11 -07:00 |
|
Tri Dao
|
2712aa4c8d
|
Support Turing mma instructions
|
2022-06-03 16:58:44 -07:00 |
|
Tri Dao
|
050873327e
|
Remove softmax fp16 max
|
2022-06-02 14:09:46 -07:00 |
|
Tri Dao
|
14dc326e59
|
Use Cutlass gemm as WarpMma
|
2022-06-02 10:33:32 -07:00 |
|
Tri Dao
|
9dbc491aa5
|
Rename, add benchmarking script
|
2022-05-26 13:57:38 -07:00 |
|