Kirthi Shankar Sivamani | 45567a25a2 | only 1 thread writes to global mem in fprop (Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>) | 2023-04-15 06:09:41 +00:00
Kirthi Shankar Sivamani | 31018c5fa0 | Support CUDA graph capture (Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>) | 2023-04-12 16:53:22 -07:00
Tri Dao | 6998e0ecdb | Fix out-of-bound memory read | 2022-11-09 09:34:14 -08:00
Tri Dao | c422fee377 | Get rid of o_rows_are_valid since we don't have headdim=16 anymore | 2022-10-24 17:29:36 -07:00
Tri Dao | 46fd2a20b2 | Support all head dims that are multiples of 8, up to 128 | 2022-10-24 16:04:21 -07:00
Tri Dao | 871db47941 | Don't need to run configure for the forward pass | 2022-10-21 18:22:27 -07:00
Tri Dao | a44f48df5a | Split fwd on the seqlen_q dimension | 2022-10-21 12:04:27 -07:00
Tri Dao | 1aa6d7d9b6 | Rework dropout to decouple forward and backward. They don't have to have the same block size, number of threads, etc. | 2022-10-21 12:04:27 -07:00
Tri Dao | 5badfb7848 | Implement attention kernel that splits the batch into two | 2022-10-13 20:49:02 -07:00
Tri Dao | de19de7ab1 | Implement for bf16 | 2022-07-09 23:31:56 -07:00
Tri Dao | 6a77a6da10 | Refactor gemm_cl to template on either __half or __nv_bfloat16 | 2022-07-09 23:18:26 -07:00
Tri Dao | e518a4b327 | Refactor to template on __half, implement bf16 util functions | 2022-07-09 23:18:26 -07:00
Tri Dao | 2dc1b205f6 | Fix Illegal Memory Access bug in fwd when d=16 | 2022-07-09 23:17:14 -07:00
Tri Dao | 6c3a8c65af | Implement cross attention | 2022-07-03 17:48:12 -07:00
Tri Dao | f66603cb6f | Support batch size > 64K by swapping grid.x and grid.y | 2022-06-29 23:16:24 -07:00
Tri Dao | 5d07483bbc | Refactor Gmem code to store q, k, v pointers separately | 2022-06-12 16:37:32 -07:00
Tri Dao | d3e6440958 | Implement bwd for head dim 128 | 2022-06-11 17:52:36 -07:00
Tri Dao | 0d854692c6 | Implement fwd for head dim 128 | 2022-06-11 17:52:36 -07:00
Tri Dao | 14dc326e59 | Use Cutlass gemm as WarpMma | 2022-06-02 10:33:32 -07:00
Tri Dao | 9dbc491aa5 | Rename, add benchmarking script | 2022-05-26 13:57:38 -07:00