diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index da95634..50ff5aa 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -236,7 +236,7 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

-    int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
+    int blocksize_c = (head_size == 128 && (!is_sm80)) ? 128 : 256;
     // Need to round max_seqlen_k to multiples of blocksize_c
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {
diff --git a/csrc/flash_attn/src/fmha/mask.h b/csrc/flash_attn/src/fmha/mask.h
index dadb665..6c80929 100644
--- a/csrc/flash_attn/src/fmha/mask.h
+++ b/csrc/flash_attn/src/fmha/mask.h
@@ -63,7 +63,8 @@ struct Mask {
         const bool col_valid = current_col < actual_seqlen_k;
         // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
             //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
-        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        // bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
         //     printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
         // }
         return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
diff --git a/csrc/flash_attn/src/fmha/smem_tile.h b/csrc/flash_attn/src/fmha/smem_tile.h
index 4be3809..491253b 100644
--- a/csrc/flash_attn/src/fmha/smem_tile.h
+++ b/csrc/flash_attn/src/fmha/smem_tile.h
@@ -1646,6 +1646,19 @@ struct Smem_tile_dp_sum {
         }
     }

+    inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) {
+        float *smem_write = smem_;
+        // Extract the position in the warp.
+        int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
+        int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
+        int row = lane / 4;
+        #pragma unroll
+        for (int mi = 0; mi < MMAS_M; ++mi) {
+            smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
+            smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
+        }
+    }
+
     inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
         float *smem_write = smem_ + buffer_idx * ROWS;
         // Extract the position in the warp.
diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h
index c4783ee..bd87437 100644
--- a/csrc/flash_attn/src/fmha/softmax.h
+++ b/csrc/flash_attn/src/fmha/softmax.h
@@ -277,73 +277,6 @@ struct Softmax_base {
     //     }
     // }

-    template <bool encode_dropout_in_sign_bit=true>
-    inline __device__ void apply_dropout(Philox &ph, uint32_t p_dropout_in_uint) {
-        // We encode the dropout pattern in the sign bit of the non-negative
-        // softmax to distinguish from pre-existing zeros
-        auto encode_dropout = [](bool keep, float val) {
-            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
-        };
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ni++ ) {
-                uint4 tmp = ph();
-                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
-                // }
-                elt_[mi][4 * ni + 0] =
-                    encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
-                elt_[mi][4 * ni + 1] =
-                    encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
-                elt_[mi][4 * ni + 2] =
-                    encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
-                elt_[mi][4 * ni + 3] =
-                    encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
-            }
-        }
-    }
-
-    template <bool encode_dropout_in_sign_bit=true>
-    inline __device__ void apply_dropout(Philox &ph0, Philox &ph1, uint32_t p_dropout_in_uint) {
-        // We encode the dropout pattern in the sign bit of the non-negative
-        // softmax to distinguish from pre-existing zeros
-        auto encode_dropout = [](bool keep, float val) {
-            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
-        };
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
-            static_assert(MMAS_N % 2 == 0);
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ni += 2 ) {
-                uint4 tmp = ph0();
-                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph0, Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
-                // }
-                elt_[mi][4 * ni + 0] =
-                    encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
-                elt_[mi][4 * ni + 1] =
-                    encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
-                elt_[mi][4 * ni + 2] =
-                    encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
-                elt_[mi][4 * ni + 3] =
-                    encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
-                tmp = ph1();
-                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph1, Philox: %u, %u, %u, %u\n", ni + 1, tmp.x, tmp.y, tmp.z, tmp.w);
-                // }
-                elt_[mi][4 * (ni + 1) + 0] =
-                    encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 0]);
-                elt_[mi][4 * (ni + 1) + 1] =
-                    encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 1]);
-                elt_[mi][4 * (ni + 1) + 2] =
-                    encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 2]);
-                elt_[mi][4 * (ni + 1) + 3] =
-                    encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 3]);
-            }
-        }
-    }
-
     template <bool encode_dropout_in_sign_bit=true>
     inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
         // We encode the dropout pattern in the sign bit of the non-negative
@@ -356,9 +289,44 @@ struct Softmax_base {
             #pragma unroll
             for( int ni = 0; ni < MMAS_N; ni++ ) {
                 uint16_t tmp[8];
-                fmha::uint4_to_ushort8(ph(), tmp);
+                // fmha::uint4_to_ushort8(ph(), tmp);
+                uint4 tmp_32 = ph();
+                fmha::uint4_to_ushort8(tmp_32, tmp);
+                // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                //     printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
+                // }
+                #pragma unroll
+                for (int ii = 0; ii < 2; ++ii) {
+                    #pragma unroll
+                    for (int jj = 0; jj < 4; ++jj) {
+                        elt_[mi * 2 + ii][4 * ni + jj] =
+                            encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
+                    }
+                }
+            }
+        }
+    }
+
+    template <bool encode_dropout_in_sign_bit=true>
+    inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
+                                                unsigned long long philox_subsequence) {
+        // We encode the dropout pattern in the sign bit of the non-negative
+        // softmax to distinguish from pre-existing zeros
+        auto encode_dropout = [](bool keep, float val) {
+            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
+        };
+        static_assert(MMAS_M == 1);  // We're assuming 16x16 blocks.
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
+            #pragma unroll
+            for( int ni = 0; ni < MMAS_N; ni++ ) {
+                uint16_t tmp[8];
+                // fmha::uint4_to_ushort8(ph(), tmp);
+                fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
+                // uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
+                // fmha::uint4_to_ushort8(tmp_32, tmp);
                 // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
+                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
                 // }
                 #pragma unroll
                 for (int ii = 0; ii < 2; ++ii) {
diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
index 8fc0791..6af2e16 100644
--- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -334,7 +334,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         if (Is_dropout) {
             // softmax.apply_dropout(ph, params.p_dropout_in_uint);
             // softmax.template apply_dropout<true>(ph, params.p_dropout_in_uint);
-            softmax.template apply_dropout_16bits<true>(ph, params.p_dropout_in_uint16_t);
+            // softmax.template apply_dropout_16bits<true>(ph, params.p_dropout_in_uint16_t);
+            unsigned int warp_idx = threadIdx.x / 32;
+            // TODO: this should change after we rearrange the warps (e.g. cutlass branch)
+            unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
+            unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
+            softmax.template apply_dropout_16bits<true>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
         }

         using Frag_p = fmha::Fragment_a<fmha::Row>;
@@ -676,9 +681,8 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {

     // The thread index.
     const int tidx = threadIdx.x;

-    const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
+    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);

     if (loop_steps == 1) {
         compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
index 32e793c..8f41505 100644
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -115,25 +115,15 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else if( launch_params.params.seqlen_k >= 256 ) {
-            if (dprops->major == 8 && dprops->minor >= 0) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-            } else if (dprops->major == 7 && dprops->minor == 5) {
-                if (launch_params.is_dropout) { // Need to use the same block size as backward
-                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-                } else {
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-                }
-            }
+            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         }
     } else if (launch_params.params.d == 128) {
         if( launch_params.params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
         } else {
-            if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
+            if (dprops->major == 8 && dprops->minor == 0) {
                 // TD [2022-06-05] Keep K in registers to reduce register spilling
                 // Gives about 6% speedup compared to using block size 128.
                 using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index f156018..f2fa8ad 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph, const int loop_step_idx) {

 #if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -470,9 +470,17 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         constexpr bool encode_dropout_in_sign_bit = Return_softmax;
         if (Is_dropout) {
-            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
-            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
-            softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
+            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
+            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
+            // softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
+            unsigned int warp_idx = threadIdx.x / 32;
+            // TODO: this should change after we rearrange the warps (e.g. cutlass branch)
+            unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
+            // We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
+            // differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
+            // to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
+            unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
+            softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
         }

         using Frag_p = fmha::Fragment_a<fmha::Row>;
@@ -650,23 +658,28 @@ inline __device__ void device_1xN_loop(const Params &params) {

     // The thread index.
     const int tidx = threadIdx.x;

-    const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
+    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
+    // them to have the same number of threads or have to traverse the attention matrix
+    // in the same order.
+    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
+    // (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
+    // the attention matrix. This way, as long as we have the batch, head, and the location of
+    // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
-    Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
+    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);

     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;

     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph, max_loop_steps - 1);
     }
 }
diff --git a/csrc/flash_attn/src/philox.cuh b/csrc/flash_attn/src/philox.cuh
index e5af22e..a1e4c64 100644
--- a/csrc/flash_attn/src/philox.cuh
+++ b/csrc/flash_attn/src/philox.cuh
@@ -1,3 +1,4 @@
+// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh
 // Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
 #pragma once
 // Philox CUDA.
@@ -9,8 +10,7 @@ public:
     __device__ inline Philox(unsigned long long seed,
                              unsigned long long subsequence,
                              unsigned long long offset)
-        : STATE(0)
-        , key(reinterpret_cast<const uint2&>(seed)) {
+        : key(reinterpret_cast<const uint2&>(seed)) {
         //key.x = (unsigned int)seed;
         //key.y = (unsigned int)(seed >> 32);
         //counter = make_uint4(0, 0, 0, 0);
@@ -19,7 +19,6 @@ public:
         //STATE = 0;
         //incr_n(offset / 4);

-        // key = reinterpret_cast<const uint2&>(seed);
         ull2 * tmp = reinterpret_cast<ull2*>(&counter);
         tmp->x = offset / 4;
         tmp->y = subsequence;
@@ -27,34 +26,46 @@ public:
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
         //     printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
         // }
     }
+
     __device__ inline uint4 operator()() {
-        // if (STATE == 0) {
-        uint4 counter_ = counter;
-        uint2 key_ = key;
-        // 7-round philox
-        #pragma unroll
-        for (int i = 0; i < 6; i++) {
-            counter_ = single_round(counter_, key_);
-            key_.x += (kPhilox10A);
-            key_.y += (kPhilox10B);
-        }
-        // output = single_round(counter_, key_);
-        uint4 output = single_round(counter_, key_);
-        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-        //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
-        //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
-        // }
-        incr();
+        uint4 counter_ = counter;
+        uint2 key_ = key;
+        // 7-round philox
+        #pragma unroll
+        for (int i = 0; i < 6; i++) {
+            counter_ = single_round(counter_, key_);
+            key_.x += (kPhilox10A);
+            key_.y += (kPhilox10B);
+        }
+        uint4 output = single_round(counter_, key_);
+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
+        //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
+        // }
+        incr();
+        return output;
+    }
+
+    __device__ inline uint4 operator()(const unsigned long long subsequence) {
+        uint4 counter_ = counter;
+        ull2 * tmp = reinterpret_cast<ull2*>(&counter_);
+        tmp->y = subsequence;
+        // if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w);
+        // }
+        uint2 key_ = key;
+        // 7-round philox
+        #pragma unroll
+        for (int i = 0; i < 6; i++) {
+            counter_ = single_round(counter_, key_);
+            key_.x += (kPhilox10A);
+            key_.y += (kPhilox10B);
+        }
+        uint4 output = single_round(counter_, key_);
+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
+        //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
         // }
-        // return a float4 directly
-        // unsigned long ret;
-        // switch(STATE) {
-        //   case 0: ret = output.x; break;
-        //   case 1: ret = output.y; break;
-        //   case 2: ret = output.z; break;
-        //   case 3: ret = output.w; break;
-        //}
-        // STATE = (STATE + 1) % 4;
         return output;
     }

@@ -64,25 +75,23 @@ private:
         uint64_t y;
     };
     uint4 counter;
-    // uint4 output;
     const uint2 key;
-    unsigned int STATE;

-    __device__ inline void incr_n(unsigned long long n) {
-        unsigned int nlo = (unsigned int)(n);
-        unsigned int nhi = (unsigned int)(n >> 32);
-        counter.x += nlo;
-        if (counter.x < nlo)
-            nhi++;
-        counter.y += nhi;
-        if (nhi <= counter.y)
-            return;
-        if (++counter.z)
-            return;
-        ++counter.w;
-    }
-    __device__ uint4 incr128 (uint4 ctr)
-    {
+    // __device__ inline void incr_n(unsigned long long n) {
+    //     unsigned int nlo = (unsigned int)(n);
+    //     unsigned int nhi = (unsigned int)(n >> 32);
+    //     counter.x += nlo;
+    //     if (counter.x < nlo)
+    //         nhi++;
+    //     counter.y += nhi;
+    //     if (nhi <= counter.y)
+    //         return;
+    //     if (++counter.z)
+    //         return;
+    //     ++counter.w;
+    // }
+
+    __device__ uint4 incr(uint4 ctr) {
         uint4 res;
         asm ("add.cc.u32 %0, %4, %8;\n\t"
             "addc.cc.u32 %1, %5, %9;\n\t"
@@ -98,42 +107,46 @@
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
         //     printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
         // }
-        counter = incr128(counter);
+        counter = incr(counter);
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
         // }
     }

-    __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
-                                      unsigned int *result_high) {
-        *result_high = __umulhi(a, b);
-        return a * b;
-    }
-    __device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b)
-    {
+    // __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
+    //                                   unsigned int *result_high) {
+    //     *result_high = __umulhi(a, b);
+    //     return a * b;
+    // }
+
+    __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
         uint2 *res;
         unsigned long long tmp;
         asm ("mul.wide.u32 %0, %1, %2;\n\t"
-             : "=l"(tmp)
-             : "r"(a), "r"(b));
+              : "=l"(tmp)
+              : "r"(a), "r"(b));
         res = (uint2*)(&tmp);
         return *res;
     }
+
     __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
         //unsigned int hi0;
         //unsigned int hi1;
         //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
         //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
         //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
-        uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
-        uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
+        uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
+        uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
         uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
         return ret;
     }
+
     static const unsigned long kPhilox10A = 0x9E3779B9;
     static const unsigned long kPhilox10B = 0xBB67AE85;
     static const unsigned long kPhiloxSA = 0xD2511F53;
     static const unsigned long kPhiloxSB = 0xCD9E8D57;
 };
+
 // Inverse of 2^32.
 constexpr float M_RAN_INVM32 = 2.3283064e-10f;
 __device__ __inline__ float4 uniform4(const uint4 x) {
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index a6d8b4f..9a8244a 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -7,12 +7,10 @@ import flash_attn_cuda

 def _get_block_size(device, head_dim, is_dropout):
     assert head_dim in [16, 32, 64, 128]
-    if head_dim in [16, 32]:
+    if head_dim in [16, 32, 64]:
         return 256
-    elif head_dim == 64:
-        return 128 if (torch.cuda.get_device_capability(device) == (7, 5) and is_dropout) else 256
     elif head_dim == 128:
-        return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128
+        return 256 if (torch.cuda.get_device_capability(device) == (8, 0)) else 128


 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 78885c0..1f4ab56 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -621,7 +621,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('seqlen', [512])
 @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype):
+def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
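
Note on the RNG scheme: the patch works because Philox is a counter-based generator, so each 4 x 32-bit draw is a pure function of (seed, offset, subsequence), with no per-thread state that must advance in lockstep between the fwd and bwd kernels. Below is a minimal host-side C++ sketch of that property; it is not part of the patch, and the U2/U4 structs, the philox() helper, and the concrete batch/head/block numbers are illustrative stand-ins, assuming only the counter layout from philox.cuh (offset / 4 in the low 64 bits, subsequence in the high 64 bits) and the offset/subsequence encoding from the kernel comments.

#include <cstdint>
#include <cstdio>

// Same Philox round constants as csrc/flash_attn/src/philox.cuh.
static const uint32_t kPhilox10A = 0x9E3779B9u, kPhilox10B = 0xBB67AE85u;
static const uint32_t kPhiloxSA  = 0xD2511F53u, kPhiloxSB  = 0xCD9E8D57u;

struct U2 { uint32_t x, y; };        // host stand-in for uint2
struct U4 { uint32_t x, y, z, w; };  // host stand-in for uint4

// {lo, hi} of the 32x32 -> 64-bit product, like the mul.wide.u32 asm above.
static U2 mulhilo32(uint32_t a, uint32_t b) {
    uint64_t t = (uint64_t)a * b;
    return { (uint32_t)t, (uint32_t)(t >> 32) };
}

static U4 single_round(U4 c, U2 k) {
    U2 r0 = mulhilo32(kPhiloxSA, c.x);
    U2 r1 = mulhilo32(kPhiloxSB, c.z);
    return { r1.y ^ c.y ^ k.x, r1.x, r0.y ^ c.w ^ k.y, r0.x };
}

// One stateless draw: the counter is laid out exactly as the constructor
// (tmp->x = offset / 4) and operator()(subsequence) (tmp->y = subsequence)
// set it up, and the key is the seed.
static U4 philox(uint64_t seed, uint64_t offset, uint64_t subsequence) {
    uint64_t ctr = offset / 4;
    U4 c = { (uint32_t)ctr, (uint32_t)(ctr >> 32),
             (uint32_t)subsequence, (uint32_t)(subsequence >> 32) };
    U2 k = { (uint32_t)seed, (uint32_t)(seed >> 32) };
    for (int i = 0; i < 6; ++i) {  // 6 rounds in the loop...
        c = single_round(c, k);
        k.x += kPhilox10A;
        k.y += kPhilox10B;
    }
    return single_round(c, k);     // ...plus a final one: 7 total
}

int main() {
    // Offset encodes (batch, head, lane), mirroring
    //   Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
    const uint64_t seed = 0x12345678u;
    const int bidb = 1, num_heads = 12, bidh = 3, lane = 17;  // made-up indices
    uint64_t offset = (uint64_t)(bidb * num_heads + bidh) * 32 + lane;

    // Subsequence encodes the position of a 16 x 16 block, mirroring
    //   philox_subsequence = (begin + l) * (actual_seqlen_k / 16) + block_col_idx;
    const int row_blocks = 4, col_blocks = 8;
    U4 fwd[row_blocks][col_blocks];

    for (int r = 0; r < row_blocks; ++r)        // "fwd": row-major traversal
        for (int c = 0; c < col_blocks; ++c)
            fwd[r][c] = philox(seed, offset, (uint64_t)r * col_blocks + c);

    for (int c = 0; c < col_blocks; ++c)        // "bwd": column-major traversal
        for (int r = 0; r < row_blocks; ++r) {
            U4 b = philox(seed, offset, (uint64_t)r * col_blocks + c);
            U4 f = fwd[r][c];
            if (b.x != f.x || b.y != f.y || b.z != f.z || b.w != f.w) {
                std::printf("mismatch at block (%d, %d)\n", r, c);
                return 1;
            }
        }
    std::printf("fwd and bwd reproduce identical dropout bits\n");
    return 0;
}

Because the per-block subsequence is computed from binfo.actual_seqlen_k rather than the padded seqlen_k, the fwd and bwd kernels index the same 16 x 16 blocks even when they round the sequence length to different multiples (e.g. 256 vs 128 for d=128 on A100), which is what lets the backward pass regenerate the forward pass's dropout mask while traversing the attention matrix in a different order.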