Support batch size > 64K by swapping grid.x and grid.y

Author: Tri Dao
Date: 2022-06-29 23:16:24 -07:00
parent 450b64fe44
commit f66603cb6f
9 changed files with 18 additions and 18 deletions
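Why the swap works: CUDA caps gridDim.y (and gridDim.z) at 65535 blocks, while gridDim.x can go up to 2^31 - 1 on compute capability 3.0 and newer. The batch index was previously read from blockIdx.y, so any batch size above 65535 could not launch; mapping the batch to blockIdx.x lifts that cap. The host-side dim3 grid(...) and the kernel-side bidb/bidh reads have to change together, which is what every hunk below does. A minimal standalone sketch of the new mapping (toy_kernel and its parameters are illustrative, not from this repo):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Batch on blockIdx.x (limit 2^31 - 1), head on blockIdx.y (limit 65535),
// mirroring the layout this commit switches to.
__global__ void toy_kernel(int h) {
    const int bidb = blockIdx.x;   // batch index; was blockIdx.y
    const int bidh = blockIdx.y;   // head index; was blockIdx.x
    const uint32_t bidx = bidb * h + bidh;
    if (threadIdx.x == 0 && bidx == 0) {
        printf("launched grid (%u, %u)\n", gridDim.x, gridDim.y);
    }
}

int main() {
    const int b = 100000;   // batch size > 65535 now fits in grid.x
    const int h = 12;       // heads stay well under the 65535 grid.y cap
    dim3 grid(b, h);        // was dim3 grid(h, b)
    toy_kernel<<<grid, 128>>>(h);
    printf("launch: %s\n", cudaGetErrorString(cudaPeekAtLastError()));
    cudaDeviceSynchronize();
    return 0;
}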

@@ -456,9 +456,9 @@ struct Gmem_summary_stats {
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;

@@ -45,7 +45,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
-dim3 grid(params.h, params.b);
+dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
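This launcher is where oversized batches used to fail: with the head count in grid.x and the batch in grid.y, any params.b above 65535 makes the launch itself invalid, and FMHA_CHECK_CUDA(cudaPeekAtLastError()) surfaces it. A standalone sketch of that failure mode (empty_kernel is illustrative, not from this repo):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void empty_kernel() {}

int main() {
    // Old layout: heads in grid.x, batch in grid.y. grid.y is capped at
    // 65535, so a batch of 70000 is rejected at launch time.
    dim3 grid(12, 70000);
    empty_kernel<<<grid, 128>>>();
    // Expected output: "invalid configuration argument"
    // (cudaErrorInvalidConfiguration).
    printf("%s\n", cudaGetErrorString(cudaPeekAtLastError()));
    return 0;
}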

@@ -118,9 +118,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
@@ -729,9 +729,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;

@@ -68,7 +68,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
return;
}
-dim3 grid(launch_params.params.h, launch_params.params.b);
+dim3 grid(launch_params.params.b, launch_params.params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);

@@ -497,9 +497,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
inline __device__ void device_block_1xN_loop(const Params &params) {
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;

@@ -44,7 +44,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
-dim3 grid(params.h, params.b);
+dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

@@ -119,9 +119,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
@@ -683,9 +683,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;

@@ -68,7 +68,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
return;
}
-dim3 grid(launch_params.params.h, launch_params.params.b);
+dim3 grid(launch_params.params.b, launch_params.params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);

@@ -621,9 +621,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
inline __device__ void device_1xN_loop(const Params &params) {
// The block index for the batch.
-const int bidb = blockIdx.y;
+const int bidb = blockIdx.x;
// The block index for the head.
-const int bidh = blockIdx.x;
+const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
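For reference, the per-dimension grid limits can be queried at runtime; on every device with compute capability 3.0 or newer this prints maxGridSize = (2147483647, 65535, 65535), which is exactly the asymmetry the swap exploits. A minimal check (assumes device 0):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    // maxGridSize[0] is gridDim.x's limit; [1] and [2] are y's and z's.
    printf("maxGridSize = (%d, %d, %d)\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}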