Tri Dao
6998e0ecdb
Fix out-of-bound memory read
2022-11-09 09:34:14 -08:00
Tri Dao
557781933d
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
...
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
...
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
de19de7ab1
Implement for bf16
2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10
Refactor gemm_cl to template on either __half or __nv_bfloat16
2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327
Refactor to template on __half, implement bf16 util functions
2022-07-09 23:18:26 -07:00
Tri Dao
5b838a8bef
Apply dropout scaling to dQ and dK instead of to V (in bwd)
...
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
Tri Dao
f66603cb6f
Support batch size > 64K by swapping grid.x and grid.y
2022-06-29 23:16:24 -07:00
Tri Dao
ea38d3d261
Fix race condition in backward pass (smem_dq)
2022-06-25 18:02:30 -07:00
Tri Dao
5d07483bbc
Refactor Gmem code to store q, k, v pointers separately
2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958
Implement bwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
d380e87fb6
Don't use Smem_dp_sum in backward pass
...
To reduce smem usage for SM75
2022-06-04 16:01:36 -07:00
Tri Dao
14dc326e59
Use Cutlass gemm as WarpMma
2022-06-02 10:33:32 -07:00
Tri Dao
9dbc491aa5
Rename, add benchmarking script
2022-05-26 13:57:38 -07:00