diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu index 6e5d5a2..13306f7 100644 --- a/csrc/ft_attention/decoder_masked_multihead_attention.cu +++ b/csrc/ft_attention/decoder_masked_multihead_attention.cu @@ -31,9 +31,7 @@ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ auto kernel = mmha::masked_multihead_attention_kernel; \ - if (smem_sz >= 48 * 1024) { \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - } \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \ kernel<<>>(params)