Reduce number of templates for headdim > 128
This commit is contained in:
parent
dd9a6fa45a
commit
1879e089c7
@ -63,7 +63,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst, IsEvenKConst>;
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
|
||||
@ -45,7 +45,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst && (!ReturnSoftmaxConst) && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
@ -78,7 +80,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst, IsEvenKConst, Split, Append_KV>;
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
|
||||
@ -371,6 +371,7 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
# or if we're switching from inference mode to training
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user