YangShu
ff07250e8f
fix typo in function mha_fwd
...
as title.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b
Fix #54 : set device for multi-GPU case
2022-10-16 12:51:26 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Tri Dao
de19de7ab1
Implement for bf16
2022-07-09 23:31:56 -07:00
Tri Dao
5b838a8bef
Apply dropout scaling to dQ and dK instead of to V (in bwd)
...
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
Tri Dao
c0daa62eaa
Add type check (fp16) in the forward pass
2022-06-26 11:41:30 -07:00
Tri Dao
eeca63a72a
Bug fix: wrong smem_o write pointer for d=16
2022-06-25 15:18:33 -07:00
Tri Dao
5d07483bbc
Refactor Gmem code to store q, k, v pointers separately
2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958
Implement bwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6
Implement fwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
321c57d07d
Set block size of SM75 fwd to 256 if there's no dropout
...
This speeds up the fwd by 1.5x.
2022-06-04 16:51:28 -07:00
Tri Dao
2712aa4c8d
Support Turing mma instructions
2022-06-03 16:58:44 -07:00
Tri Dao
9dbc491aa5
Rename, add benchmarking script
2022-05-26 13:57:38 -07:00