Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation

This commit is contained in:
Tri Dao 2023-09-18 11:46:15 -07:00
parent fa3ddcbaaa
commit 43617deab9
3 changed files with 9 additions and 5 deletions

View File

@ -62,7 +62,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
@ -101,7 +102,8 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(

View File

@ -44,7 +44,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Will only return softmax if dropout, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
@ -76,7 +77,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {

View File

@ -87,7 +87,7 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
RUN pip install flash-attn==2.2.3.post2
# Install CUDA extensions for cross-entropy, fused dense, layer norm
# Install CUDA extensions for fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
&& cd flash-attention && git checkout v2.2.3.post2 \
&& cd csrc/layer_norm && pip install . && cd ../../ \