Commit Graph

41 Commits

Author SHA1 Message Date
Tri Dao
c422fee377 Get rid of o_rows_are_valid since we don't have headdim=16 anymore 2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2 Support all head dims that are multiples of 8, up to 128 2022-10-24 16:04:21 -07:00
Tri Dao
97e13de2b4 Cast q.get_device() to char to avoid compiler warning (narrowing) 2022-10-24 15:59:49 -07:00
Tri Dao
9e92a1f2d2 Attempt to use atomicCAS to replace atomicAdd(bfloat16) 2022-10-23 16:22:43 -07:00
Tri Dao
a5a8806d1a Split bwd on the seqlen_q dimension 2022-10-23 11:35:15 -07:00
Tri Dao
871db47941 Don't need to run configure for the forward pass 2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2 Use block_size=128 for headdim=128 on SM80
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a Split fwd on the seqlen_q dimension 2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
YangShu
ff07250e8f
fix typo in function mha_fwd
as title.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b Fix #54: set device for multi-GPU case 2022-10-16 12:51:26 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Eric Engelhart
2211db5fab Fixed switch statement, thanks @yocabon 2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7 Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates 2022-10-04 21:31:39 -04:00
Tri Dao
8166063a55 Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit 2022-09-12 14:21:29 -07:00
Tri Dao
bc2c210254 Don't nest BOOL_SWITCH to work around gcc 7 bug 2022-07-11 10:28:46 -07:00
Tri Dao
de19de7ab1 Implement for bf16 2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10 Refactor gemm_cl to template on either __half or __nv_bfloat16 2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327 Refactor to template on __half, implement bf16 util functions 2022-07-09 23:18:26 -07:00
Tri Dao
2dc1b205f6 Fix Illegal Memory Access bug in fwd when d=16 2022-07-09 23:17:14 -07:00
Tri Dao
5b838a8bef Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
f66603cb6f Support batch size > 64K by swapping grid.x and grid.y 2022-06-29 23:16:24 -07:00
Tri Dao
c0daa62eaa Add type check (fp16) in the forward pass 2022-06-26 11:41:30 -07:00
Tri Dao
ea38d3d261 Fix race condition in backward pass (smem_dq) 2022-06-25 18:02:30 -07:00
Tri Dao
eeca63a72a Bug fix: wrong smem_o write pointer for d=16 2022-06-25 15:18:33 -07:00
Tri Dao
5d07483bbc Refactor Gmem code to store q, k, v pointers separately 2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958 Implement bwd for head dim 128 2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6 Implement fwd for head dim 128 2022-06-11 17:52:36 -07:00
Tri Dao
321c57d07d Set block size of SM75 fwd to 256 if there's no dropout
This speeds up the fwd by 1.5x.
2022-06-04 16:51:28 -07:00
Tri Dao
d380e87fb6 Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
2022-06-04 16:01:36 -07:00
Tri Dao
b17c6fe235 Reduce smem usage for Q and dO in the backward pass
From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each Q and dO
have 2 buffers)
2022-06-03 16:59:11 -07:00
Tri Dao
2712aa4c8d Support Turing mma instructions 2022-06-03 16:58:44 -07:00
Tri Dao
050873327e Remove softmax fp16 max 2022-06-02 14:09:46 -07:00
Tri Dao
14dc326e59 Use Cutlass gemm as WarpMma 2022-06-02 10:33:32 -07:00
Tri Dao
e78e7c9553 Remove old backward 2022-06-02 10:13:44 -07:00
Tri Dao
512c98ee05 Add Cutlass as submodule 2022-06-02 09:54:16 -07:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00
Tri Dao
c41479d66d Support SM86 GPUs 2022-06-01 18:49:47 -07:00
Tri Dao
9dbc491aa5 Rename, add benchmarking script 2022-05-26 13:57:38 -07:00