Tri Dao
713ea302d7
Allow headdim 128 in FlashMHA interface
2022-08-05 09:47:22 -07:00
Tri Dao
2ed471ecc4
Add tests for numerical error
2022-07-22 17:54:09 -04:00
Tri Dao
42f54d8840
Edit mention of Triton implementation
Phil Tillet suggests calling it "experimental".
2022-07-11 17:02:29 -07:00
Tri Dao
4577151ff8
Link to Triton implementation
2022-07-11 16:01:43 -07:00
Tri Dao
bc2c210254
Don't nest BOOL_SWITCH to work around gcc 7 bug
2022-07-11 10:28:46 -07:00
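Context for the commit above: a BOOL_SWITCH-style macro lifts a runtime boolean into a constexpr template parameter by compiling both branches. Below is a minimal sketch of the usual shape of such a macro, not necessarily the repo's exact definition, and the usage example uses hypothetical names:

```cuda
// Sketch of a BOOL_SWITCH-style macro: the runtime COND picks a branch, each
// branch declares CONST_NAME as a constexpr bool, so the lambda body can use
// it as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
    [&] {                                         \
        if (COND) {                               \
            constexpr bool CONST_NAME = true;     \
            return __VA_ARGS__();                 \
        } else {                                  \
            constexpr bool CONST_NAME = false;    \
            return __VA_ARGS__();                 \
        }                                         \
    }()

// Hypothetical single-level usage (illustrative names only):
template <bool IsDropout>
void run_fwd(float p_dropout) { /* launch the kernel templated on IsDropout */ }

void dispatch_fwd(float p_dropout) {
    const bool is_dropout = p_dropout > 0.f;
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
        run_fwd<IsDropoutConst>(p_dropout);
    });
}
```

Per the commit subject, the dispatch is kept at a single level; placing a second BOOL_SWITCH inside the lambda of the first is the nesting that triggered the gcc 7 bug mentioned there.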
Tri Dao
d1fc80a3bb
Link to IEEE Spectrum article on MLPerf
2022-07-10 12:11:46 -07:00
Tri Dao
1bbebccc0a
Edit README to mention bf16 support
2022-07-09 23:34:29 -07:00
Tri Dao
de19de7ab1
Implement for bf16
2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10
Refactor gemm_cl to template on either __half or __nv_bfloat16
2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327
Refactor to template on __half, implement bf16 util functions
2022-07-09 23:18:26 -07:00
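The two refactor commits just above move the kernels from hard-coded __half to code templated on the element type, so the same source can be instantiated for fp16 and bf16. A rough sketch of that approach follows; the trait and function names are illustrative assumptions, not the repo's actual utilities:

```cuda
#include <cuda_fp16.h>
#include <cuda_bf16.h>   // bf16 types/intrinsics require CUDA 11 or later

// Illustrative per-type traits: device code converts to/from fp32 through a
// specialization instead of hard-coding __half.
template <typename Elem> struct ElemOps;

template <> struct ElemOps<__half> {
    static __device__ float to_float(__half x) { return __half2float(x); }
    static __device__ __half from_float(float x) { return __float2half(x); }
};

template <> struct ElemOps<__nv_bfloat16> {
    static __device__ float to_float(__nv_bfloat16 x) { return __bfloat162float(x); }
    static __device__ __nv_bfloat16 from_float(float x) { return __float2bfloat16(x); }
};

// Any kernel templated on Elem then compiles for both fp16 and bf16.
template <typename Elem>
__global__ void scale_kernel(Elem* x, float alpha, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = ElemOps<Elem>::from_float(alpha * ElemOps<Elem>::to_float(x[i]));
}

template __global__ void scale_kernel<__half>(__half*, float, int);
template __global__ void scale_kernel<__nv_bfloat16>(__nv_bfloat16*, float, int);
```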
Tri Dao
2dc1b205f6
Fix Illegal Memory Access bug in fwd when d=16
2022-07-09 23:17:14 -07:00
Tri Dao
5b838a8bef
Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
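The two backward-pass precision commits above both act on the softmax-backward step. A sketch of the algebra in standard attention notation (the generic formulation, not the kernel's variable names): let $S = QK^\top/\sqrt{d}$, $P = \mathrm{softmax}(S)$ row-wise, $M$ the dropout keep mask, and $r = 1/(1 - p_{\mathrm{drop}})$ the rescale factor, so the forward output is $O = r\,(M \circ P)\,V$. Then:

```latex
\begin{align*}
\widetilde{dP} &= M \circ (dO\, V^\top), & dP &= r\, \widetilde{dP}, \\
\widetilde{dS}_{ij} &= P_{ij}\Bigl(\widetilde{dP}_{ij} - \sum\nolimits_k P_{ik}\, \widetilde{dP}_{ik}\Bigr), & dS &= r\, \widetilde{dS}, \\
dQ &= \tfrac{1}{\sqrt{d}}\, dS\, K = r \cdot \tfrac{1}{\sqrt{d}}\, \widetilde{dS}\, K, & dK &= \tfrac{1}{\sqrt{d}}\, dS^\top Q = r \cdot \tfrac{1}{\sqrt{d}}\, \widetilde{dS}^\top Q.
\end{align*}
```

Because $r$ is a scalar it factors through the whole chain, so it can be applied once to the fp32 accumulators of $dQ$ and $dK$ rather than folded into the fp16 $V$ operand; that equivalence is what the dropout-scaling commit relies on. The pointwise product $P \circ dP$ inside the row sum is the term the "Do P * dP in fp32" commit computes in fp32.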
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
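For reference, "cross attention" in the commit above means the queries come from one sequence and the keys/values from another, so the two sequence lengths may differ; in the same notation as above:

```latex
\[
O = \mathrm{softmax}\!\Bigl(\tfrac{1}{\sqrt{d}}\, Q K^\top\Bigr) V,
\qquad Q \in \mathbb{R}^{L_q \times d}, \quad K, V \in \mathbb{R}^{L_k \times d}, \quad L_q \ne L_k \ \text{allowed}.
\]
```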
Tri Dao
01947bc93b
Merge pull request #18 from gahdritz/main
Slightly improve installation process
2022-07-02 23:33:14 -07:00
Gustaf
af4a9ce024
Add missing __init__.py
2022-07-03 02:04:55 -04:00
Gustaf
440e9c49f2
Add einops installation to setup.py
2022-07-03 02:04:24 -04:00
Tri Dao
f66603cb6f
Support batch size > 64K by swapping grid.x and grid.y
2022-06-29 23:16:24 -07:00
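The commit above works around a CUDA launch limit: gridDim.y (and .z) are capped at 65535 blocks, while gridDim.x allows up to 2^31 - 1, so mapping the batch (x heads) dimension onto grid.x lifts the 64K cap. A hypothetical launch sketch (kernel and variable names are illustrative, not the repo's actual code):

```cuda
#include <cuda_runtime.h>

// Illustrative kernel (body omitted): after the swap, the batch*heads index
// is read from blockIdx.x (limit 2^31 - 1) rather than blockIdx.y (limit 65535).
__global__ void attn_fwd_kernel_sketch() {
    int bidb_h    = blockIdx.x;  // which (batch, head) pair this block handles
    int row_block = blockIdx.y;  // which tile of query rows it handles
    (void)bidb_h; (void)row_block;
}

void launch_sketch(int batch, int heads, int num_row_blocks, cudaStream_t stream) {
    // Before: dim3 grid(num_row_blocks, batch * heads);  // grid.y capped at 65535
    dim3 grid(batch * heads, num_row_blocks);             // grid.x allows up to 2^31 - 1
    dim3 block(128);
    attn_fwd_kernel_sketch<<<grid, block, 0, stream>>>();
}
```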
Tri Dao
450b64fe44
Add README section on issues
2022-06-27 13:50:16 -07:00
Tri Dao
c0daa62eaa
Add type check (fp16) in the forward pass
2022-06-26 11:41:30 -07:00
Tri Dao
ea38d3d261
Fix race condition in backward pass (smem_dq)
2022-06-25 18:02:30 -07:00
Tri Dao
eeca63a72a
Bug fix: wrong smem_o write pointer for d=16
2022-06-25 15:18:33 -07:00
Dan Fu
765741c1ee
More explanation
2022-06-14 11:55:14 -07:00
Dan Fu
2d5b2483b8
Speedup graph for A100, d128
2022-06-14 11:54:16 -07:00
Tri Dao
5d07483bbc
Refactor Gmem code to store q, k, v pointers separately
2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958
Implement bwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6
Implement fwd for head dim 128
2022-06-11 17:52:36 -07:00
Dan Fu
0a398dfc37
Broken link
2022-06-04 17:28:45 -07:00
Dan Fu
bd60750e0b
T4
2022-06-04 17:27:51 -07:00
Tri Dao
321c57d07d
Set block size of SM75 fwd to 256 if there's no dropout
This speeds up the fwd by 1.5x.
2022-06-04 16:51:28 -07:00
Tri Dao
f2d8d4104e
Edit README: support Turing (SM75)
2022-06-04 16:06:48 -07:00
Tri Dao
d380e87fb6
Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
2022-06-04 16:01:36 -07:00
Tri Dao
b17c6fe235
Reduce smem usage for Q and dO in the backward pass
From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each Q and dO
have 2 buffers)
2022-06-03 16:59:11 -07:00
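Spelling out the shared-memory saving claimed in the commit above (Q and dO each double-buffered, each buffer shrinking from 4 KB to 2 KB):

```latex
\[
\underbrace{2}_{\text{tensors: } Q,\, dO} \times \underbrace{2}_{\text{buffers each}} \times (4\,\mathrm{KB} - 2\,\mathrm{KB}) = 8\,\mathrm{KB}.
\]
```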
Tri Dao
2712aa4c8d
Support Turing mma instructions
2022-06-03 16:58:44 -07:00
Tri Dao
050873327e
Remove softmax fp16 max
2022-06-02 14:09:46 -07:00
Tri Dao
14dc326e59
Use Cutlass gemm as WarpMma
2022-06-02 10:33:32 -07:00
Tri Dao
e78e7c9553
Remove old backward
2022-06-02 10:13:44 -07:00
Tri Dao
512c98ee05
Add Cutlass as submodule
2022-06-02 09:54:16 -07:00
Dan Fu
ad6c694bb3
3090 speedup
2022-06-01 20:07:00 -07:00
Tri Dao
5a61cb7729
Rename src -> flash_attn
2022-06-01 18:50:26 -07:00
Tri Dao
c41479d66d
Support SM86 GPUs
2022-06-01 18:49:47 -07:00
Dan Fu
4b7cfb5f45
Citation
2022-05-30 13:29:04 -07:00
Dan Fu
963173fcb5
Jpg resolution
2022-05-30 11:47:42 -07:00
Dan Fu
cd04d29883
Fix jpg
2022-05-30 11:46:01 -07:00
Tri Dao
a78745189a
Add paper arXiv link
2022-05-29 18:15:43 -07:00
Tri Dao
d9fff84bd0
Edit roadmap
2022-05-29 15:44:18 -07:00
Tri Dao
e4ffe5d50e
Convert banner figure from pdf to jpg
2022-05-29 15:39:17 -07:00
Tri Dao
67c3779598
Reorganize directories, add banner figure
2022-05-29 15:34:22 -07:00
Dan Fu
7025a092d1
Make png images into jpg for dark mode
2022-05-28 22:46:49 +01:00
Dan Fu
4decc3c166
README typo
2022-05-27 22:38:20 +01:00