Use int64_t instead of uint32_t for index_t
This commit is contained in:
parent
e43a4ceaab
commit
000b67f5d8
@ -22,7 +22,7 @@ constexpr int D_DIM = 2;
|
|||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
struct Qkv_params {
|
struct Qkv_params {
|
||||||
using index_t = uint32_t;
|
using index_t = int64_t;
|
||||||
// The QKV matrices.
|
// The QKV matrices.
|
||||||
void *__restrict__ q_ptr;
|
void *__restrict__ q_ptr;
|
||||||
void *__restrict__ k_ptr;
|
void *__restrict__ k_ptr;
|
||||||
@ -99,7 +99,7 @@ struct Flash_fwd_params : public Qkv_params {
|
|||||||
void * __restrict__ rotary_sin_ptr;
|
void * __restrict__ rotary_sin_ptr;
|
||||||
|
|
||||||
// The indices to index into the KV cache.
|
// The indices to index into the KV cache.
|
||||||
int *__restrict__ cache_batch_idx;
|
int * __restrict__ cache_batch_idx;
|
||||||
|
|
||||||
// The dropout probability (probability of keeping an activation).
|
// The dropout probability (probability of keeping an activation).
|
||||||
float p_dropout;
|
float p_dropout;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user