Tri Dao
e45a46a5b7
[Rotary] Implement GPT-J style (interleaved) rotary
2023-03-14 14:35:53 -07:00
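For reference, GPT-J-style rotary rotates consecutive element pairs (2i, 2i+1) by position-dependent angles, rather than pairing element i with element i + d/2 as in the GPT-NeoX layout. A minimal scalar sketch of that math (plain C++ with an assumed base of 10000; the commit's actual implementation is a CUDA kernel):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// GPT-J style (interleaved) rotary embedding: each consecutive pair
// (x[2i], x[2i+1]) is rotated by angle pos * theta_i, where
// theta_i = base^(-2i/d). Scalar float sketch only; the commit's kernel
// runs on GPU in half precision.
std::vector<float> rotary_interleaved(const std::vector<float>& x,
                                      int pos, float base = 10000.f) {
  const int d = static_cast<int>(x.size());
  std::vector<float> out(d);
  for (int i = 0; i < d / 2; ++i) {
    const float theta = std::pow(base, -2.f * i / d);
    const float c = std::cos(pos * theta);
    const float s = std::sin(pos * theta);
    out[2 * i]     = x[2 * i] * c - x[2 * i + 1] * s;
    out[2 * i + 1] = x[2 * i] * s + x[2 * i + 1] * c;
  }
  return out;
}
```

At pos = 0 this is the identity; each rotation preserves the norm of its pair.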
Tri Dao
6b4a48218e
[FA] Remove unused variable rng_engine_inputs
2023-01-25 15:32:40 -08:00
Tri Dao
eb33e587e9
[LayerNorm] Rename x1 -> residual
2023-01-19 13:07:27 -08:00
Tri Dao
88173a1aaf
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
2023-01-17 18:12:27 -08:00
Tri Dao
f1e01c27ba
[Gen] Pass qkv_stride to ft_attention kernel for batched generation
2023-01-15 15:20:01 -08:00
Tri Dao
7c2191542a
[Gen] Make generation work with Tensor Parallel
2023-01-15 11:34:27 -08:00
Tri Dao
6738d9477d
[LayerNorm] Implement RMS Norm
2023-01-06 17:34:22 -08:00
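RMS Norm drops LayerNorm's mean subtraction and bias, normalizing by the root mean square alone. A plain C++ sketch of the formula (the commit's implementation is a fused CUDA kernel; the `eps` default here is an assumed value for illustration):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// RMS Norm: y_i = x_i / sqrt(mean(x^2) + eps) * g_i. Unlike LayerNorm
// there is no mean subtraction and no bias term.
std::vector<float> rms_norm(const std::vector<float>& x,
                            const std::vector<float>& g,
                            float eps = 1e-6f) {
  float ms = 0.f;
  for (float v : x) ms += v * v;
  ms /= static_cast<float>(x.size());
  const float inv_rms = 1.f / std::sqrt(ms + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] * inv_rms * g[i];
  }
  return y;
}
```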
Tri Dao
a1f49a2b92
[Compilation] Change BOOL_SWITCH to fix Windows compilation
Follow xFormers's DISPATCH_BOOL. Haven't tested it on Windows.
2023-01-06 14:40:58 -08:00
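The BOOL_SWITCH pattern lifts a runtime bool into a compile-time constant so a kernel body can be templated on it. A minimal sketch in the style of xFormers' DISPATCH_BOOL (names and exact form are illustrative; the repo's macro may differ):

```cpp
#include <cassert>

// Sketch of a BOOL_SWITCH-style dispatch macro (illustrative; the repo's
// exact macro may differ). The runtime condition becomes a constexpr bool
// in each branch, and the body lambda passed via __VA_ARGS__ is expanded
// textually inside that branch, so CONST_NAME is in scope there as a
// compile-time constant.
#define BOOL_SWITCH(COND, CONST_NAME, ...)  \
  [&] {                                     \
    if (COND) {                             \
      constexpr bool CONST_NAME = true;     \
      return __VA_ARGS__();                 \
    } else {                                \
      constexpr bool CONST_NAME = false;    \
      return __VA_ARGS__();                 \
    }                                       \
  }()

// Hypothetical stand-in for a kernel launcher templated on the flag.
template <bool IsCausal>
int launch_kernel_stub() { return IsCausal ? 1 : 0; }
```

A call site then reads, e.g., `BOOL_SWITCH(is_causal, IsCausalConst, [&] { return launch_kernel_stub<IsCausalConst>(); });`.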
Tri Dao
be1afaa276
[Gen, FT] Use fp32 accum for FMA
2023-01-03 22:09:22 -08:00
Tri Dao
f266fc7262
[Gen, FT] Use tlength instead of params.timestep for rotary
2023-01-03 17:46:55 -08:00
Tri Dao
a01d1213d7
[Gen] Add kernel from FasterTransformer for benchmarking
2023-01-03 17:37:43 -08:00
Tri Dao
a8cfe51551
Implement Tensor Parallel for transformer Block
2022-12-25 14:08:21 -08:00
Tri Dao
1e712ea8b0
Implement TensorParallel for MHA
2022-12-25 11:39:55 -08:00
Tri Dao
226a1b721d
Implement TensorParallel for FusedDense and FusedDenseGeluDense
2022-12-24 11:48:56 -08:00
Tri Dao
dff68c2b22
Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss
2022-12-23 14:51:08 -08:00
Tri Dao
e68ebbe89a
Simplify FusedDense
2022-12-22 21:25:31 -08:00
Tri Dao
5db330519a
[LayerNorm] Support taking subset of input or subset of output
2022-12-12 22:16:14 -08:00
Tri Dao
ae137ed17a
[LayerNorm] Fuse LayerScale
2022-12-10 23:28:23 -08:00
Tri Dao
8c6609ae1a
[LayerNorm] Support all dimensions up to 6k (if divisible by 8)
2022-12-09 02:06:22 -08:00
Tri Dao
8a2ece89f7
Simplify BOOL_SWITCH macro to fix compiling error on gcc 7
2022-12-06 14:38:32 -08:00
Tri Dao
0bf5e50038
Release training code
2022-11-28 17:34:40 -08:00
Tri Dao
9bc63d1e2d
Fix typo in comments
2022-11-25 16:35:08 -08:00
Tri Dao
d95ee1a95d
Speed up compilation by splitting into separate .cu files
2022-11-25 16:30:18 -08:00
Tri Dao
39ed597b28
[LayerNorm] Compile for both sm70 and sm80
2022-11-17 11:45:11 -08:00
Tri Dao
43ab0b5205
Mention that some CUDA extensions have only been tested on A100s
2022-11-15 07:10:25 -08:00
Tri Dao
e4d3013e15
[LayerNorm] Check cuda error after querying ctas_per_sm
2022-11-15 07:05:13 -08:00
Tri Dao
2e33fc8e36
Add GPT and ViT models
2022-11-13 22:30:23 -08:00
Tri Dao
fa6d1ce44f
Add fused_dense and dropout_add_layernorm CUDA extensions
2022-11-13 21:59:20 -08:00
Tri Dao
7c9953815a
Add fused cross entropy loss
2022-11-12 21:58:41 -08:00
Tri Dao
6998e0ecdb
Fix out-of-bound memory read
2022-11-09 09:34:14 -08:00
Tri Dao
557781933d
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
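The speedup rationale in this commit can be made concrete with a toy count of atomic adds (an illustrative model, not the kernel's actual code): whichever dimension a CUDA block loops over serially can be accumulated in registers, while gradient tiles shared across CUDA blocks need atomic accumulation.

```cpp
#include <cassert>

// Toy model of atomic-add traffic in the backward pass (illustrative
// only). n_q and n_k are the number of q-blocks and k-blocks; every
// (q_block, k_block) pair produces one tile update each for dQ, dK, dV.
struct AtomicCounts { long dq; long dk; long dv; };

// Parallelize over seqlen_k: one CUDA block per k-block, looping over all
// q-blocks. Its dK/dV tiles are privately owned (no atomics); every dQ
// tile update is an atomic add.
AtomicCounts parallelize_over_k(long n_q, long n_k) {
  return {n_q * n_k, 0, 0};
}

// Parallelize over seqlen_q (the previous scheme): one CUDA block per
// q-block. dQ is private, but both dK and dV need atomic adds.
AtomicCounts parallelize_over_q(long n_q, long n_k) {
  return {0, n_q * n_k, n_q * n_k};
}
```

For n_q = n_k, the seqlen_k split does half as many atomic adds (dq only) as the seqlen_q split (dk and dv), matching the commit's reasoning.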
Tri Dao
ca81f32e04
Implement rotary embedding in CUDA
2022-11-04 22:42:01 -07:00
Tri Dao
c422fee377
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
97e13de2b4
Cast q.get_device() to char to avoid compiler warning (narrowing)
2022-10-24 15:59:49 -07:00
Tri Dao
ed553e9238
Add Megatron attention implementation for benchmarking
2022-10-23 23:04:16 -07:00
Tri Dao
9e92a1f2d2
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
2022-10-23 16:22:43 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
871db47941
Don't need to run configure for the forward pass
2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2
Use block_size=128 for headdim=128 on SM80
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
YangShu
ff07250e8f
Fix typo in function mha_fwd
As titled.
2022-10-17 16:13:47 +08:00
Tri Dao
52fb4b729b
Fix #54: set device for multi-GPU case
2022-10-16 12:51:26 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Eric Engelhart
2211db5fab
Fixed switch statement, thanks @yocabon
2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7
Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates
2022-10-04 21:31:39 -04:00
Tri Dao
8166063a55
Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit
2022-09-12 14:21:29 -07:00
Tri Dao
bc2c210254
Don't nest BOOL_SWITCH to work around gcc 7 bug
2022-07-11 10:28:46 -07:00
Tri Dao
de19de7ab1
Implement for bf16
2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10
Refactor gemm_cl to template on either __half or __nv_bfloat16
2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327
Refactor to template on __half, implement bf16 util functions
2022-07-09 23:18:26 -07:00
Tri Dao
2dc1b205f6
Fix Illegal Memory Access bug in fwd when d=16
2022-07-09 23:17:14 -07:00
Tri Dao
5b838a8bef
Apply dropout scaling to dQ and dK instead of to V (in bwd)
Theoretically this might have lower numerical error since the scaling is done
in fp32 instead of fp16 (not sure; I haven't thought too carefully about it).
However, in practice the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
Tri Dao
f66603cb6f
Support batch size > 64K by swapping grid.x and grid.y
2022-06-29 23:16:24 -07:00
Tri Dao
c0daa62eaa
Add type check (fp16) in the forward pass
2022-06-26 11:41:30 -07:00
Tri Dao
ea38d3d261
Fix race condition in backward pass (smem_dq)
2022-06-25 18:02:30 -07:00
Tri Dao
eeca63a72a
Bug fix: wrong smem_o write pointer for d=16
2022-06-25 15:18:33 -07:00
Tri Dao
5d07483bbc
Refactor Gmem code to store q, k, v pointers separately
2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958
Implement bwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6
Implement fwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
321c57d07d
Set block size of SM75 fwd to 256 if there's no dropout
This speeds up the fwd by 1.5x.
2022-06-04 16:51:28 -07:00
Tri Dao
d380e87fb6
Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75.
2022-06-04 16:01:36 -07:00
Tri Dao
b17c6fe235
Reduce smem usage for Q and dO in the backward pass
From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each of Q
and dO has 2 buffers).
2022-06-03 16:59:11 -07:00
Tri Dao
2712aa4c8d
Support Turing mma instructions
2022-06-03 16:58:44 -07:00
Tri Dao
050873327e
Remove softmax fp16 max
2022-06-02 14:09:46 -07:00
Tri Dao
14dc326e59
Use Cutlass gemm as WarpMma
2022-06-02 10:33:32 -07:00
Tri Dao
e78e7c9553
Remove old backward
2022-06-02 10:13:44 -07:00
Tri Dao
512c98ee05
Add Cutlass as submodule
2022-06-02 09:54:16 -07:00
Tri Dao
5a61cb7729
Rename src -> flash_attn
2022-06-01 18:50:26 -07:00
Tri Dao
c41479d66d
Support SM86 GPUs
2022-06-01 18:49:47 -07:00
Tri Dao
9dbc491aa5
Rename, add benchmarking script
2022-05-26 13:57:38 -07:00
Tri Dao
1fcbe6f0d0
First release
2022-05-20 14:21:58 -07:00