diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index e5bcac7..886b1ef 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -456,9 +456,9 @@ struct Gmem_summary_stats { : ptr_(reinterpret_cast(ptr)), tidx_(tidx) { // The block index for the batch. - const int bidb = blockIdx.y; + const int bidb = blockIdx.x; // The block index for the head. - const int bidh = blockIdx.x; + const int bidh = blockIdx.y; // The block index. // size_t bidx = bidb * params.h + bidh; uint32_t bidx = bidb * params.h + bidh; diff --git a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu index d836bc4..706f898 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu @@ -45,7 +45,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_ FMHA_CHECK_CUDA(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); } - dim3 grid(params.h, params.b); + dim3 grid(params.b, params.h); kernel<<>>(params); FMHA_CHECK_CUDA(cudaPeekAtLastError()); } diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index a6dc8ff..714d02a 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -118,9 +118,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // The block index for the batch. - const int bidb = blockIdx.y; + const int bidb = blockIdx.x; // The block index for the head. - const int bidh = blockIdx.x; + const int bidh = blockIdx.y; // The thread index. const int tidx = threadIdx.x; @@ -729,9 +729,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params ¶ms) { constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N; // The block index for the batch. - const int bidb = blockIdx.y; + const int bidb = blockIdx.x; // The block index for the head. - const int bidh = blockIdx.x; + const int bidh = blockIdx.y; // The thread index. const int tidx = threadIdx.x; diff --git a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu index 82c6ab3..6fdb9a2 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu @@ -68,7 +68,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params>>( launch_params.params); diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h index 5549d4b..fe1920b 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h @@ -497,9 +497,9 @@ template>>(params); FMHA_CHECK_CUDA(cudaPeekAtLastError()); } diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 2185dc1..fca0b35 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -119,9 +119,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // The block index for the batch. - const int bidb = blockIdx.y; + const int bidb = blockIdx.x; // The block index for the head. - const int bidh = blockIdx.x; + const int bidh = blockIdx.y; // The thread index. const int tidx = threadIdx.x; @@ -683,9 +683,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) { constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N; // The block index for the batch. - const int bidb = blockIdx.y; + const int bidb = blockIdx.x; // The block index for the head. - const int bidh = blockIdx.x; + const int bidh = blockIdx.y; // The thread index. const int tidx = threadIdx.x; diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index a62d4d7..f591fe3 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -68,7 +68,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params>>( launch_params.params); diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 6502878..09849c3 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -621,9 +621,9 @@ template