Commit Graph

11 Commits

Author SHA1 Message Date
Tri Dao
e518a4b327 Refactor to template on __half, implement bf16 util functions 2022-07-09 23:18:26 -07:00
Tri Dao
5b838a8bef Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
f66603cb6f Support batch size > 64K by swapping grid.x and grid.y 2022-06-29 23:16:24 -07:00
Tri Dao
ea38d3d261 Fix race condition in backward pass (smem_dq) 2022-06-25 18:02:30 -07:00
Tri Dao
5d07483bbc Refactor Gmem code to store q, k, v pointers separately 2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958 Implement bwd for head dim 128 2022-06-11 17:52:36 -07:00
Tri Dao
d380e87fb6 Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
2022-06-04 16:01:36 -07:00
Tri Dao
14dc326e59 Use Cutlass gemm as WarpMma 2022-06-02 10:33:32 -07:00
Tri Dao
9dbc491aa5 Rename, add benchmarking script 2022-05-26 13:57:38 -07:00