Tri Dao
|
ea38d3d261
|
Fix race condition in backward pass (smem_dq)
|
2022-06-25 18:02:30 -07:00 |
|
Tri Dao
|
eeca63a72a
|
Bug fix: wrong smem_o write pointer for d=16
|
2022-06-25 15:18:33 -07:00 |
|
Tri Dao
|
5d07483bbc
|
Refactor Gmem code to store q, k, v pointers separately
|
2022-06-12 16:37:32 -07:00 |
|
Tri Dao
|
d3e6440958
|
Implement bwd for head dim 128
|
2022-06-11 17:52:36 -07:00 |
|
Tri Dao
|
0d854692c6
|
Implement fwd for head dim 128
|
2022-06-11 17:52:36 -07:00 |
|
Tri Dao
|
321c57d07d
|
Set block size of SM75 fwd to 256 if there's no dropout
This speeds up the fwd by 1.5x.
|
2022-06-04 16:51:28 -07:00 |
|
Tri Dao
|
d380e87fb6
|
Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
|
2022-06-04 16:01:36 -07:00 |
|
Tri Dao
|
b17c6fe235
|
Reduce smem usage for Q and dO in the backward pass
From 4KB per buffer to 2KB per buffer. This saves 8KB of smem in total (Q and dO
each have 2 buffers).
|
2022-06-03 16:59:11 -07:00 |
|
Tri Dao
|
2712aa4c8d
|
Support Turing mma instructions
|
2022-06-03 16:58:44 -07:00 |
|
Tri Dao
|
050873327e
|
Remove softmax fp16 max
|
2022-06-02 14:09:46 -07:00 |
|
Tri Dao
|
14dc326e59
|
Use Cutlass gemm as WarpMma
|
2022-06-02 10:33:32 -07:00 |
|
Tri Dao
|
e78e7c9553
|
Remove old backward
|
2022-06-02 10:13:44 -07:00 |
|
Tri Dao
|
c41479d66d
|
Support SM86 GPUs
|
2022-06-01 18:49:47 -07:00 |
|
Tri Dao
|
9dbc491aa5
|
Rename, add benchmarking script
|
2022-05-26 13:57:38 -07:00 |
|