Set block size to 64 x 64 for kvcache to avoid nvcc segfaults
This commit is contained in:
parent
8c8b4d36e1
commit
c984208ddb
@ -115,18 +115,12 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
|
||||
template<typename T, int Headdim>
|
||||
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
constexpr int kBlockM = 64; // Fixed for all head dimensions
|
||||
if (!is_sm8x) { // A100, H100
|
||||
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
|
||||
// and for headdim 192 with block size 64 x 128.
|
||||
constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
|
||||
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
|
||||
} else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
|
||||
constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
|
||||
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
|
||||
// and for headdim 192 with block size 64 x 128.
|
||||
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
|
||||
constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
|
||||
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
__version__ = "2.2.3"
|
||||
__version__ = "2.2.3.post1"
|
||||
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_func,
|
||||
|
||||
@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
|
||||
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
|
||||
|
||||
# Install FlashAttention
|
||||
RUN pip install flash-attn==2.2.3
|
||||
RUN pip install flash-attn==2.2.3.post1
|
||||
|
||||
# Install CUDA extensions for cross-entropy, fused dense, layer norm
|
||||
RUN git clone https://github.com/HazyResearch/flash-attention \
|
||||
&& cd flash-attention && git checkout v2.2.3 \
|
||||
&& cd flash-attention && git checkout v2.2.3.post1 \
|
||||
&& cd csrc/fused_softmax && pip install . && cd ../../ \
|
||||
&& cd csrc/rotary && pip install . && cd ../../ \
|
||||
&& cd csrc/layer_norm && pip install . && cd ../../ \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user