flash-attention/csrc/layer_norm
Contents:

  • ln_api.cpp
  • ln.h, ln_kernel_traits.h, ln_utils.cuh, static_switch.h
  • ln_fwd_kernels.cuh, ln_bwd_kernels.cuh
  • ln_fwd_<dim>.cu and ln_bwd_<dim>.cu for dim in 256, 512, 768, 1024, 1280, 1536, 2048, 2560, 3072, 4096, 5120, 6144, 7168, 8192
  • ln_parallel_residual_fwd_kernels.cuh, ln_parallel_residual_bwd_kernels.cuh
  • ln_parallel_fwd_<dim>.cu and ln_parallel_bwd_<dim>.cu for the same dimensions
  • setup.py, README.md

This CUDA extension implements fused dropout + residual + LayerNorm, building on Apex's FastLayerNorm. Major changes (an unfused PyTorch sketch of the fused computation follows the list):

  • Add dropout and residual.
  • Make it work for both pre-norm and post-norm architectures.
  • Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
  • Implement RMSNorm as an option.
  • Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
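
For concreteness, here is a minimal, unfused PyTorch sketch of the computation these kernels fuse into a single pass. It is illustrative only: the function name, argument layout, and the options it omits (rowscale, layerscale, residual_in_fp32, returning the dropout mask) are assumptions made for the sketch, not this extension's API.

    import torch
    import torch.nn.functional as F

    def dropout_add_layer_norm_ref(x0, residual, weight, bias, dropout_p, eps,
                                   prenorm=False, is_rms_norm=False):
        """Unfused reference: dropout(x0) + residual, then LayerNorm or RMSNorm."""
        x = F.dropout(x0, p=dropout_p, training=True)
        if residual is not None:
            x = x + residual
        if is_rms_norm:
            # RMSNorm: scale by the root mean square; no mean subtraction, no bias.
            normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
        else:
            normed = F.layer_norm(x, (x.shape[-1],), weight, bias, eps)
        # Pre-norm blocks also need the un-normalized sum as the next residual stream;
        # post-norm blocks only need the normalized output.
        return (normed, x) if prenorm else normed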

If you want to use it for dimensions larger than 8k, please file an issue.

This extension has only been tested on A100s.

cd csrc/layer_norm && pip install .
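
Once built, the op is normally called through the Python wrapper in flash_attn/ops/layer_norm.py rather than the compiled module directly. The sketch below assumes a dropout_add_layer_norm wrapper with the argument order shown; the exact name and signature can vary between releases, so check the wrapper's source before relying on it.

    import torch
    from flash_attn.ops.layer_norm import dropout_add_layer_norm  # assumed wrapper location

    hidden = 4096  # must be divisible by 8 and at most 8192
    x0 = torch.randn(4, 1024, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x0)
    weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
    bias = torch.zeros_like(weight)

    # prenorm=True returns both the normalized output and the updated residual stream.
    out, new_residual = dropout_add_layer_norm(
        x0, residual, weight, bias, 0.1, 1e-5, prenorm=True
    )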

As of 2024-01-05, this extension is no longer used in the FlashAttention repo. We've instead switched to a Triton-based implementation.