Set block size to 64 x 64 for kvcache to avoid nvcc segfaults

commit c984208ddb (parent 8c8b4d36e1)
Author: Tri Dao
Date:   2023-09-17 16:14:58 -07:00

3 changed files with 8 additions and 14 deletions

@@ -115,18 +115,12 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    auto dprops = at::cuda::getCurrentDeviceProperties();
-    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     constexpr int kBlockM = 64;  // Fixed for all head dimensions
-    if (!is_sm8x) {  // A100, H100
-        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-        // and for headdim 192 with block size 64 x 128.
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    }
+    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+    // and for headdim 192 with block size 64 x 128.
+    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
+    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>

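As an aside (not part of this commit), here is a minimal standalone sketch of the tile-size selection that run_mha_fwd_splitkv_dispatch now applies uniformly across architectures. The ternary mirrors the one in the diff; select_kBlockN is a hypothetical helper name and the main() printout is only illustrative.

// block_size_sketch.cpp -- hypothetical example, compile with: g++ -std=c++17 block_size_sketch.cpp
#include <cstdio>

// Mirrors the selection in the new dispatch code: wide KV tiles only for
// small head dims, 64 x 64 tiles above head dim 128 to avoid the nvcc
// segfaults noted in the comment.
template <int Headdim>
constexpr int select_kBlockN() {
    return Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
}

int main() {
    constexpr int kBlockM = 64;  // fixed for all head dimensions
    std::printf("hdim  64 -> %d x %d\n", kBlockM, select_kBlockN<64>());
    std::printf("hdim  96 -> %d x %d\n", kBlockM, select_kBlockN<96>());
    std::printf("hdim 128 -> %d x %d\n", kBlockM, select_kBlockN<128>());
    std::printf("hdim 160 -> %d x %d\n", kBlockM, select_kBlockN<160>());
    std::printf("hdim 192 -> %d x %d\n", kBlockM, select_kBlockN<192>());
    std::printf("hdim 256 -> %d x %d\n", kBlockM, select_kBlockN<256>());
    return 0;
}

Head dimensions above 128 fall back to 64 x 64 tiles, which is the combination the commit title refers to.
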
@@ -1,4 +1,4 @@
-__version__ = "2.2.3"
+__version__ = "2.2.3.post1"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,

@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.2.3
+RUN pip install flash-attn==2.2.3.post1
 
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.2.3 \
+    && cd flash-attention && git checkout v2.2.3.post1 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/layer_norm && pip install . && cd ../../ \