diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index cdc7608..be5fd77 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -24,7 +24,7 @@ struct Flash_kernel_traits { #endif using ElementAccum = float; - using index_t = uint32_t; + using index_t = int64_t; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 using MMA_Atom_Arch = std::conditional_t<