Rework dropout to decouple forward and backward

The forward and backward passes no longer have to use the same block size, number of threads, etc.
Tri Dao 2022-10-18 23:04:01 -07:00
parent 1d0b41be3b
commit 1aa6d7d9b6
10 changed files with 161 additions and 161 deletions

View File

@@ -236,7 +236,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
int blocksize_c = (head_size == 128 && (!is_sm80)) ? 128 : 256;
// Need to round max_seqlen_k to multiples of blocksize_c
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
if( max_seqlen_k_ <= 128 ) {
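
For reference, the rounding above is the standard ceil-to-a-multiple idiom. A minimal standalone sketch, with assumed values:

#include <cstdio>

int main() {
    int blocksize_c = 256;
    int max_seqlen_k_ = 1000;
    // Round up to the next multiple of blocksize_c: 1000 -> 1024.
    int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
    printf("max_seqlen_k = %d\n", max_seqlen_k);  // prints 1024
    return 0;
}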

View File

@@ -63,7 +63,8 @@ struct Mask {
const bool col_valid = current_col < actual_seqlen_k;
// const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
// printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
// }
return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
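
A worked instance of the predicate above, with all values assumed for illustration:

// current_col is relative to the current column block of Cta_tile::N columns;
// the causal check compares the absolute column index against the row index.
int current_row = 300, current_col = 40;
int loop_step_idx = 1;                // second column block
const int Cta_tile_N = 256;           // stand-in for Cta_tile::N
int actual_seqlen_k = 512;
bool col_valid = current_col < actual_seqlen_k;                                      // true
bool keep = col_valid && (current_col + loop_step_idx * Cta_tile_N <= current_row);  // 296 <= 300: kept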

View File

@@ -1646,6 +1646,19 @@ struct Smem_tile_dp_sum {
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) {
float *smem_write = smem_;
// Extract the position in the warp.
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
int row = lane / 4;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
// Extract the position in the warp.
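
The +0 / +8 stores above follow from the MMA accumulator layout (assumed here to be the usual 16x16 tile split across a 32-lane warp): lane/4 selects one of 8 rows, and each thread holds the row sums for that row and the one 8 below it. A worked mapping:

int lane = 13;
int row = lane / 4;   // row = 3
// This thread writes sum[mi * 2 + 0] to smem row mi * ROWS_PER_MMA + 3
// and sum[mi * 2 + 1] to smem row mi * ROWS_PER_MMA + 11; across lanes
// 0..31 the rows 0..7 and 8..15 of each MMA tile are all covered.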

View File

@@ -277,73 +277,6 @@ struct Softmax_base {
// }
// }
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout(Philox &ph, uint32_t p_dropout_in_uint) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint4 tmp = ph();
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
elt_[mi][4 * ni + 0] =
encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
elt_[mi][4 * ni + 1] =
encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
elt_[mi][4 * ni + 2] =
encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
elt_[mi][4 * ni + 3] =
encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout(Philox &ph0, Philox &ph1, uint32_t p_dropout_in_uint) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
static_assert(MMAS_N % 2 == 0);
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint4 tmp = ph0();
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph0, Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
elt_[mi][4 * ni + 0] =
encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
elt_[mi][4 * ni + 1] =
encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
elt_[mi][4 * ni + 2] =
encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
elt_[mi][4 * ni + 3] =
encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
tmp = ph1();
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph1, Philox: %u, %u, %u, %u\n", ni + 1, tmp.x, tmp.y, tmp.z, tmp.w);
// }
elt_[mi][4 * (ni + 1) + 0] =
encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 0]);
elt_[mi][4 * (ni + 1) + 1] =
encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 1]);
elt_[mi][4 * (ni + 1) + 2] =
encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 2]);
elt_[mi][4 * (ni + 1) + 3] =
encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 3]);
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
// We encode the dropout pattern in the sign bit of the non-negative
@@ -356,9 +289,44 @@ struct Softmax_base {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint16_t tmp[8];
fmha::uint4_to_ushort8(ph(), tmp);
// fmha::uint4_to_ushort8(ph(), tmp);
uint4 tmp_32 = ph();
fmha::uint4_to_ushort8(tmp_32, tmp);
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
unsigned long long philox_subsequence) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
static_assert(MMAS_M == 1); // We're assuming 16x16 blocks.
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint16_t tmp[8];
// fmha::uint4_to_ushort8(ph(), tmp);
fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
// uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
// fmha::uint4_to_ushort8(tmp_32, tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
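
A host-side sketch of the 16-bit keep/drop test used by apply_dropout_16bits, assuming the threshold encodes the keep probability as p_keep * 65535:

#include <cstdint>
#include <cstdio>

int main() {
    float p_drop = 0.17f;
    uint16_t threshold = (uint16_t)((1.0f - p_drop) * 65535.0f);  // 54394
    uint16_t rng = 40000;   // one of the 8 ushorts unpacked from a Philox uint4
    bool keep = rng <= threshold;
    float val = 1.5f;
    // With encode_dropout_in_sign_bit=true, dropped values become -val rather
    // than 0, so downstream code can tell dropped elements from genuine zeros.
    float out = keep ? val : -val;
    printf("keep=%d out=%f\n", keep, out);
    return 0;
}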

View File

@@ -334,7 +334,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
if (Is_dropout) {
// softmax.apply_dropout(ph, params.p_dropout_in_uint);
// softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
// softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
unsigned int warp_idx = threadIdx.x / 32;
// TODO: this should change after we rearrange the warps (e.g. cutlass branch)
unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
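
Worked arithmetic for the subsequence computed above, all values assumed: the attention matrix is tiled into 16 x 16 blocks, and each block gets its own Philox subsequence.

unsigned int warp_idx = 2;            // threadIdx.x / 32
int loop_step_idx = 3;                // which block of Cta_tile_p::N columns
const int CTA_N = 128;                // stand-in for Cta_tile_p::N
int begin = 0, l = 5;                 // 16-row block index within this CTA's loop
int actual_seqlen_k = 1024;           // 64 column blocks of 16
unsigned int block_col_idx = loop_step_idx * CTA_N / 16 + warp_idx;            // 26
unsigned long long philox_subsequence =
    (unsigned long long)(begin + l) * (actual_seqlen_k / 16) + block_col_idx;  // 5*64 + 26 = 346
// Inside apply_dropout_16bits, each ni then adds Cta_tile::WARPS_N to step
// across the column blocks owned by this warp.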
@@ -676,9 +681,8 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
if (loop_steps == 1) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);

View File

@@ -115,25 +115,15 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.seqlen_k >= 256 ) {
if (dprops->major == 8 && dprops->minor >= 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if (dprops->major == 7 && dprops->minor == 5) {
if (launch_params.is_dropout) { // Need to use the same block size as backward
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
} else if (launch_params.params.d == 128) {
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
if (dprops->major == 8 && dprops->minor == 0) {
// TD [2022-06-05] Keep K in registers to reduce register spilling
// Gives about 6% speedup compared to using block size 128.
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;

View File

@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph, const int loop_step_idx) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using elem_type = typename Kernel_traits::elem_type;
@@ -470,9 +470,17 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
if (Is_dropout) {
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
// softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
unsigned int warp_idx = threadIdx.x / 32;
// TODO: this should change after we rearrange the warps (e.g. cutlass branch)
unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
// We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
// differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
// to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
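
The padding mismatch the comment above refers to, with assumed numbers:

int actual_seqlen_k = 304;
int fwd_seqlen_k = ((actual_seqlen_k + 255) / 256) * 256;   // 512 (fwd pads to 256)
int bwd_seqlen_k = ((actual_seqlen_k + 127) / 128) * 128;   // 384 (bwd pads to 128)
// Indexing 16x16 blocks by the padded seqlen_k would give fwd and bwd different
// strides; actual_seqlen_k / 16 = 19 is identical in both, keeping the RNG in sync.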
@@ -650,23 +658,28 @@ inline __device__ void device_1xN_loop(const Params &params) {
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
// them to have the same number of threads or to traverse the attention matrix
// in the same order.
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
// (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
constexpr int M = Kernel_traits::Cta_tile_p::M;
const int STEPS = (params.seqlen_q + M - 1) / M;
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
if (params.seqlen_k == blocksize_c) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph, 0);
} else {
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph, loop_step_idx);
}
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph, max_loop_steps - 1);
}
}
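
A minimal sketch (names assumed, not part of this diff) of the invariant the comment above describes: any kernel that knows the batch, head, lane, and the 16 x 16 block coordinates can regenerate the same dropout bits, regardless of its own launch configuration.

__device__ bool dropout_keep_bit(unsigned long long seed,
                                 unsigned long long base_offset,   // std::get<1>(seeds)
                                 int bidb, int num_heads, int bidh, int lane,
                                 unsigned long long block_row,
                                 unsigned long long block_col,
                                 unsigned long long n_col_blocks,  // actual_seqlen_k / 16
                                 uint16_t p_keep_in_uint16) {
    // Offset encodes (batch, head, lane); subsequence encodes the block position.
    Philox ph(seed, 0, base_offset + (bidb * num_heads + bidh) * 32 + lane);
    uint16_t rng[8];
    fmha::uint4_to_ushort8(ph(block_row * n_col_blocks + block_col), rng);
    return rng[0] <= p_keep_in_uint16;  // first of the 8 elements this draw covers
}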

View File

@@ -1,3 +1,4 @@
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
#pragma once
// Philox CUDA.
@@ -9,8 +10,7 @@ public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset)
: STATE(0)
, key(reinterpret_cast<const uint2&>(seed)) {
: key(reinterpret_cast<const uint2&>(seed)) {
//key.x = (unsigned int)seed;
//key.y = (unsigned int)(seed >> 32);
//counter = make_uint4(0, 0, 0, 0);
@@ -19,7 +19,6 @@ public:
//STATE = 0;
//incr_n(offset / 4);
// key = reinterpret_cast<const uint2&>(seed);
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset / 4;
tmp->y = subsequence;
@@ -27,34 +26,46 @@ public:
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ inline uint4 operator()() {
// if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
// 7-round philox
#pragma unroll
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
// output = single_round(counter_, key_);
uint4 output = single_round(counter_, key_);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// }
incr();
uint4 counter_ = counter;
uint2 key_ = key;
// 7-round philox
#pragma unroll
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
uint4 output = single_round(counter_, key_);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// }
incr();
return output;
}
__device__ inline uint4 operator()(const unsigned long long subsequence) {
uint4 counter_ = counter;
ull2 * tmp = reinterpret_cast<ull2*>(&counter_);
tmp->y = subsequence;
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w);
// }
uint2 key_ = key;
// 7-round philox
#pragma unroll
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
uint4 output = single_round(counter_, key_);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// }
// return a float4 directly
// unsigned long ret;
// switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
// STATE = (STATE + 1) % 4;
return output;
}
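
Usage note: this overload is stateless with respect to the internal counter. It overwrites only the subsequence word of a local copy and never calls incr(), so a given (seed, offset, subsequence) triple always reproduces the same 128 bits:

Philox ph(seed, /*subsequence=*/0, offset);
uint4 a = ph(42ULL);
uint4 b = ph(42ULL);   // identical to a
uint4 c = ph(43ULL);   // independent draw, e.g. for the next 16x16 block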
@@ -64,25 +75,23 @@ private:
uint64_t y;
};
uint4 counter;
// uint4 output;
const uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ uint4 incr128 (uint4 ctr)
{
// __device__ inline void incr_n(unsigned long long n) {
// unsigned int nlo = (unsigned int)(n);
// unsigned int nhi = (unsigned int)(n >> 32);
// counter.x += nlo;
// if (counter.x < nlo)
// nhi++;
// counter.y += nhi;
// if (nhi <= counter.y)
// return;
// if (++counter.z)
// return;
// ++counter.w;
// }
__device__ uint4 incr(uint4 ctr) {
uint4 res;
asm ("add.cc.u32 %0, %4, %8;\n\t"
"addc.cc.u32 %1, %5, %9;\n\t"
@@ -98,42 +107,46 @@ private:
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
counter = incr128(counter);
counter = incr(counter);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a * b;
}
__device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b)
{
// __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
// unsigned int *result_high) {
// *result_high = __umulhi(a, b);
// return a * b;
// }
__device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
uint2 *res;
unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t"
: "=l"(tmp)
: "r"(a), "r"(b));
: "=l"(tmp)
: "r"(a), "r"(b));
res = (uint2*)(&tmp);
return *res;
}
__device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
//unsigned int hi0;
//unsigned int hi1;
//unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
//unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
//uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
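
For reference, a portable equivalent of the PTX mulhilo32 above (a sanity-check sketch, not the code path used): mul.wide.u32 simply forms the full 64-bit product, and single_round then applies one round of the Philox-4x32 network using its two multipliers.

__device__ uint2 mulhilo32_portable(unsigned int a, unsigned int b) {
    unsigned long long prod = (unsigned long long)a * (unsigned long long)b;
    return make_uint2((unsigned int)prod, (unsigned int)(prod >> 32));
}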
// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
__device__ __inline__ float4 uniform4(const uint4 x) {

View File

@@ -7,12 +7,10 @@ import flash_attn_cuda
def _get_block_size(device, head_dim, is_dropout):
assert head_dim in [16, 32, 64, 128]
if head_dim in [16, 32]:
if head_dim in [16, 32, 64]:
return 256
elif head_dim == 64:
return 128 if (torch.cuda.get_device_capability(device) == (7, 5) and is_dropout) else 256
elif head_dim == 128:
return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128
return 256 if (torch.cuda.get_device_capability(device) == (8, 0)) else 128
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,

View File

@@ -621,7 +621,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
@pytest.mark.parametrize('seqlen', [512])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype):
def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'