flash-attention/csrc/layer_norm

| File | Last commit | Date |
| --- | --- | --- |
| ln_api.cpp | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00 |
| ln_bwd_256.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_512.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_768.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_1024.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_1280.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_1536.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_2048.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_2560.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_3072.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_4096.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_5120.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_6144.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_bwd_kernels.cuh | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00 |
| ln_fwd_256.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_512.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_768.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_1024.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_1280.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_1536.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_2048.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_2560.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_3072.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_4096.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_5120.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_6144.cu | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| ln_fwd_kernels.cuh | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00 |
| ln_kernel_traits.h | [LayerNorm] Fuse LayerScale | 2022-12-10 23:28:23 -08:00 |
| ln_utils.cuh | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00 |
| ln.h | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00 |
| README.md | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00 |
| setup.py | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00 |
| static_switch.h | Add fused_dense and dropout_add_layernorm CUDA extensions | 2022-11-13 21:59:20 -08:00 |

This CUDA extension implements fused dropout + residual + LayerNorm, building on Apex's FastLayerNorm. We add the dropout and residual steps, make the kernel work for both pre-norm and post-norm architectures, support more hidden dimensions (any dimension divisible by 8, up to 6144), and implement RMSNorm as an option.
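
For reference, the unfused computation being replaced looks roughly like the PyTorch sketch below; the function names and the `prenorm` flag are illustrative, not this extension's actual API.

```python
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x0, residual, weight, bias, p, eps=1e-5, prenorm=False):
    # Dropout on the incoming branch, then the residual add.
    pre = F.dropout(x0, p=p, training=True) + residual
    # Normalize the sum. Post-norm only needs `out`; pre-norm architectures
    # also feed the un-normalized sum `pre` to the next block.
    out = F.layer_norm(pre, (pre.shape[-1],), weight, bias, eps)
    return (out, pre) if prenorm else out

def dropout_add_rms_norm_ref(x0, residual, weight, p, eps=1e-5, prenorm=False):
    # RMSNorm option: scale by the root mean square, with no mean subtraction.
    pre = F.dropout(x0, p=p, training=True) + residual
    rms = torch.sqrt(pre.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = pre / rms * weight
    return (out, pre) if prenorm else out
```

The fused kernel performs these steps in a single pass instead of launching separate dropout, add, and normalization kernels.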

If you want to use it for dimensions larger than 6k, please file an issue.
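
Concretely, a hidden dimension is currently covered when it is a positive multiple of 8 no larger than 6144; the helper below is only an illustration of that rule.

```python
def is_supported_hidden_dim(d: int) -> bool:
    # Covered by the current kernels: multiples of 8, up to 6144.
    return d > 0 and d % 8 == 0 and d <= 6144

assert is_supported_hidden_dim(2048)       # divisible by 8 and <= 6144
assert not is_supported_hidden_dim(1023)   # not divisible by 8
assert not is_supported_hidden_dim(8192)   # larger than 6k: please file an issue
```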

This extension has only been tested on A100s.

To install (from the repository root):

`cd csrc/layer_norm && pip install .`
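
After installation, a quick import check can confirm the build succeeded. The compiled module name `dropout_layer_norm` below is an assumption; confirm it against setup.py.

```python
# Smoke test after `pip install .`.
# NOTE: the module name `dropout_layer_norm` is assumed; check setup.py
# for the name the extension is actually built under.
import torch               # load PyTorch first so its shared libraries are available
import dropout_layer_norm  # assumed module name

print("extension loaded from:", dropout_layer_norm.__file__)
```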