Commit Graph

111 Commits

Author SHA1 Message Date
Tri Dao
321c57d07d Set block size of SM75 fwd to 256 if there's no dropout
This speeds up the fwd by 1.5x.
2022-06-04 16:51:28 -07:00
Tri Dao
d380e87fb6 Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
2022-06-04 16:01:36 -07:00
Tri Dao
b17c6fe235 Reduce smem usage for Q and dO in the backward pass
From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each Q and dO
have 2 buffers)
2022-06-03 16:59:11 -07:00
Tri Dao
2712aa4c8d Support Turing mma instructions 2022-06-03 16:58:44 -07:00
Tri Dao
050873327e Remove softmax fp16 max 2022-06-02 14:09:46 -07:00
Tri Dao
14dc326e59 Use Cutlass gemm as WarpMma 2022-06-02 10:33:32 -07:00
Tri Dao
e78e7c9553 Remove old backward 2022-06-02 10:13:44 -07:00
Tri Dao
512c98ee05 Add Cutlass as submodule 2022-06-02 09:54:16 -07:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00
Tri Dao
c41479d66d Support SM86 GPUs 2022-06-01 18:49:47 -07:00
Tri Dao
9dbc491aa5 Rename, add benchmarking script 2022-05-26 13:57:38 -07:00