diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index d0852b6..c38d0fd 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -22,7 +22,7 @@ constexpr int D_DIM = 2; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { - using index_t = uint32_t; + using index_t = int64_t; // The QKV matrices. void *__restrict__ q_ptr; void *__restrict__ k_ptr; @@ -99,7 +99,7 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ rotary_sin_ptr; // The indices to index into the KV cache. - int *__restrict__ cache_batch_idx; + int * __restrict__ cache_batch_idx; // The dropout probability (probability of keeping an activation). float p_dropout;