diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 56aae44..919173e 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -294,7 +294,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      is_causal,
                      num_splits);
 
-    run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -307,7 +306,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    run_fmha_fp16_sm80(launch_params, /*configure=*/false);
+    run_fmha_fp16_sm80(launch_params);
 
    std::vector<at::Tensor> result = {softmax_lse};
    if (return_softmax) {result.push_back(s);}
@@ -453,9 +452,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());
 
-    // We're gonna reset the rng state in Python after this kernel, so the counter offset
-    // here doesn't matter at all. We just choose an arbitrary number.
-    int64_t counter_offset = 4;
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;
 
    if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h
index 48ac1c2..0f34fa7 100644
--- a/csrc/flash_attn/src/fmha.h
+++ b/csrc/flash_attn/src/fmha.h
@@ -191,7 +191,7 @@ struct Launch_params{
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
 
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index aa39992..21ad40b 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -65,22 +65,10 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
 }
 
 template<typename Kernel_traits>
-void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
-                              const bool configure) {
+void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
 
-    if (configure) {
-        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
-        constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
-        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
-        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
-        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
-        launch_params.elts_per_thread = elts_per_head;
-        return;
-    }
-
     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
@@ -123,38 +111,37 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     });
 }
 
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
-                        const bool configure) {
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
     FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k == 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else {
                 // TD [2022-05-15] 512 gives wrong results rn
                 // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 32) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 64) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 128) {
             // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
@@ -166,30 +153,30 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             // For causal=True, block size 128 seems always faster (for small & large batch size).
             // So we're just gonna use block size 128 for simplicity.
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         }
         // if (launch_params.params.d == 64) {
         //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
         //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         // }
         // if (launch_params.params.d == 64) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //     } else if( launch_params.params.seqlen_k >= 256 ) {
        //         if (dprops->major == 8 && dprops->minor >= 0) {
        //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //         } else if (dprops->major == 7 && dprops->minor == 5) {
        //             if (launch_params.is_dropout) { // Need to use the same block size as backward
        //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //             } else {
        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //             }
        //         }
        //     }
@@ -197,16 +184,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
        // if (launch_params.params.d == 128) {
        //     if( launch_params.params.seqlen_k == 128 ) {
        //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //     } else {
        //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
        //             // TD [2022-06-05] Keep K in registers to reduce register spilling
        //             // Gives about 6% speedup compared to using block size 128.
        //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //         } else { // Need to use the same block size as backward
        //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //         }
        //     }
        // }
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 13cb01d..8afb8da 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }
 
 template
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -250,6 +250,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // The thread index.
     const int tidx = threadIdx.x;
 
+    // How many steps to jump per iteration, which is the same as params.num_splits.
+    const int step_stride = gridDim.z;
+
     const BlockInfoPadded binfo(params, bidb, bidh, tidx);
     // if( binfo.stop_early() ) return;
     if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
@@ -683,14 +686,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
 
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_(params, bidb, bidh, STEPS, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_(params, bidb, bidh, STEPS, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
+            fmha::device_1xN_(params, bidb, bidh, STEPS, ph, loop_step_idx);
         }
-        fmha::device_1xN_(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
+        fmha::device_1xN_(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
     }
 }
 
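Note (not part of the patch): the diff does two things. First, the separate configure pass is removed and both mha_fwd and mha_bwd now advance the philox counter by batch_size * nheads * 32 per launch. Second, device_1xN_ no longer receives its step stride as an argument; each block derives it from gridDim.z, which equals params.num_splits. The standalone CUDA sketch below only illustrates these two ideas under assumed names; philox_counter_offset and strided_steps_kernel are hypothetical and do not appear in the repository.

// Illustration only -- not part of the diff above.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Mirrors the offset now used by both forward and backward: batch_size * nheads * 32.
static int64_t philox_counter_offset(int64_t batch_size, int64_t nheads) {
    return batch_size * nheads * 32;
}

// Toy kernel showing the stride pattern: each z-block starts at blockIdx.z and
// jumps by gridDim.z (the number of splits), so every step is handled exactly once.
__global__ void strided_steps_kernel(int *visited, int total_steps) {
    const int step_stride = gridDim.z;   // same role as params.num_splits
    if (threadIdx.x == 0) {
        for (int step = blockIdx.z; step < total_steps; step += step_stride) {
            visited[step] += 1;
        }
    }
}

int main() {
    const int total_steps = 64, num_splits = 4;
    int *visited = nullptr;
    cudaMallocManaged(&visited, total_steps * sizeof(int));
    for (int i = 0; i < total_steps; ++i) visited[i] = 0;

    // Launch with num_splits blocks along z, matching how the fprop grid sets gridDim.z.
    strided_steps_kernel<<<dim3(1, 1, num_splits), 32>>>(visited, total_steps);
    cudaDeviceSynchronize();

    int covered = 0;
    for (int i = 0; i < total_steps; ++i) covered += visited[i];
    printf("steps covered once: %d/%d\n", covered, total_steps);
    printf("philox offset for b=8, h=12: %lld\n",
           (long long)philox_counter_offset(8, 12));

    cudaFree(visited);
    return 0;
}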