55 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Tri Dao | 1aa6d7d9b6 | Rework dropout to decouple forward and backward. They don't have to have the same block size, number of threads, etc. | 2022-10-21 12:04:27 -07:00 |
| Tri Dao | 52fb4b729b | Fix #54: set device for multi-GPU case | 2022-10-16 12:51:26 -07:00 |
| Tri Dao | 5badfb7848 | Implement attention kernel that splits the batch into two | 2022-10-13 20:49:02 -07:00 |
| Tri Dao | 0c01568daf | Only run backward test for d=128 on A100 | 2022-10-04 18:06:08 -07:00 |
| Tri Dao | 2ed471ecc4 | Add tests for numerical error | 2022-07-22 17:54:09 -04:00 |