Preprocessor switches to control functionality (#788)
For faster and smaller builds in some simple cases, provide switches to allow disabling -backward -alibi -uneven k -dropout -local attention Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
This commit is contained in:
parent
290596c544
commit
0658e320f6
@ -112,6 +112,9 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
|
||||
TORCH_CHECK(p_dropout < 1.f);
|
||||
#ifdef FLASHATTENTION_DISABLE_DROPOUT
|
||||
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
|
||||
#endif
|
||||
|
||||
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
||||
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
||||
@ -122,7 +125,16 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
params.window_size_left = window_size_left;
|
||||
params.window_size_right = window_size_right;
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
||||
TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
|
||||
"This flash attention build does not support local attention.");
|
||||
#endif
|
||||
|
||||
params.is_seqlens_k_cumulative = true;
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
|
||||
#endif
|
||||
}
|
||||
|
||||
void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
@ -282,6 +294,25 @@ void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size,
|
||||
}
|
||||
}
|
||||
|
||||
void set_params_alibi(Flash_fwd_params ¶ms, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
|
||||
#ifdef FLASHATTENTION_DISABLE_ALIBI
|
||||
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
#else
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
||||
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
} else {
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
@ -435,17 +466,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
||||
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
} else {
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
}
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
|
||||
if (seqlen_k > 0) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
@ -657,17 +678,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
||||
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
} else {
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
}
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
|
||||
if (max_seqlen_k > 0) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
@ -724,6 +735,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state) {
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
||||
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
||||
#endif
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
@ -903,17 +917,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
||||
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
} else {
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
}
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
|
||||
if (seqlen_q > 0) {
|
||||
launch(params, stream);
|
||||
@ -963,6 +967,10 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state) {
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
||||
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
||||
#endif
|
||||
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
@ -1158,17 +1166,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
||||
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
} else {
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
}
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
|
||||
if (max_seqlen_q > 0) {
|
||||
launch(params, stream);
|
||||
@ -1435,17 +1433,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
}
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
if (alibi_slopes_.has_value()) {
|
||||
auto alibi_slopes = alibi_slopes_.value();
|
||||
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
|
||||
CHECK_DEVICE(alibi_slopes);
|
||||
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
|
||||
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
|
||||
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
|
||||
} else {
|
||||
params.alibi_slopes_ptr = nullptr;
|
||||
}
|
||||
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
|
||||
|
||||
@ -69,9 +69,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream)
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
@ -100,7 +100,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream)
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
#ifndef FLASHATTENTION_DISABLE_BACKWARD
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@ -114,7 +116,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
@ -139,7 +141,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
|
||||
@ -184,7 +186,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
@ -210,7 +212,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
@ -243,7 +245,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
@ -263,7 +265,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 136 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
@ -275,7 +277,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 224;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
});
|
||||
}
|
||||
@ -291,7 +293,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 176 * 1024) { // H100
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else { // A100, we don't do double buffering to save smem
|
||||
|
||||
@ -42,10 +42,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
@ -83,11 +83,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
|
||||
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
|
||||
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If Is_local, set Is_causal to false
|
||||
@ -113,7 +113,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
|
||||
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
|
||||
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
if (params.num_splits <= 2) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 4) {
|
||||
@ -147,7 +147,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream)
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
@ -157,7 +157,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
@ -181,7 +181,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
if (is_sm8x) {
|
||||
@ -207,7 +207,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
@ -244,7 +244,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, H100, 128 x 32 is the fastest.
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
@ -272,7 +272,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
@ -300,7 +300,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
@ -331,7 +331,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
@ -25,6 +26,46 @@
|
||||
} \
|
||||
}()
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_DROPOUT
|
||||
#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define DROPOUT_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_ALIBI
|
||||
#define ALIBI_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define ALIBI_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
#define EVENK_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define EVENK_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
||||
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
}()
|
||||
#else
|
||||
#define LOCAL_SWITCH BOOL_SWITCH
|
||||
#endif
|
||||
|
||||
#define FP16_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user