From c3f2a632aa558e15e97a23d18580dc1cfebf0398 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Mon, 28 Aug 2023 19:00:22 +0200 Subject: [PATCH] [ft_attention] Fix for seqlen=8136 (#488) When seqlen=8136, `smem_sz = 48840`, and launching the kernel returns an `invalid argument` CUDA error. `48840 < 48 * 1024` (49152), yet the launch is still rejected, most likely because the kernel's statically allocated shared memory counts toward the same default 48 KB per-block limit, so static + dynamic shared memory together exceed it. Always opt in with `cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz)` instead of only when `smem_sz >= 48 * 1024`. Tested on A100 --- csrc/ft_attention/decoder_masked_multihead_attention.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu index 6e5d5a2..13306f7 100644 --- a/csrc/ft_attention/decoder_masked_multihead_attention.cu +++ b/csrc/ft_attention/decoder_masked_multihead_attention.cu @@ -31,9 +31,7 @@ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ auto kernel = mmha::masked_multihead_attention_kernel; \ - if (smem_sz >= 48 * 1024) { \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - } \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \ kernel<<>>(params)