Support batch size > 64K by swapping grid.x and grid.y
parent 450b64fe44
commit f66603cb6f
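The rationale, as I read it (the diff itself does not spell it out): CUDA caps gridDim.y and gridDim.z at 65535 blocks each, while gridDim.x can be up to 2^31 - 1. With the batch mapped to grid.y, batch sizes above 64K could not be launched; moving the batch to grid.x and the head to grid.y lifts that cap, provided every kernel reads blockIdx.x for the batch and blockIdx.y for the head, which is exactly what the hunks below change on both the launch side and the kernel side. Below is a minimal, self-contained CUDA sketch of the same launch pattern; index_demo, b, and h are hypothetical names used for illustration only, not code from this commit.

#include <cstdio>
#include <cuda_runtime.h>

// Sketch: batch on grid.x (no 64K limit), head on grid.y (65535 is plenty).
__global__ void index_demo(int h) {
    const int bidb = blockIdx.x;        // batch index, matches the new layout
    const int bidh = blockIdx.y;        // head index
    const int bidx = bidb * h + bidh;   // flattened (batch, head) block index
    if (threadIdx.x == 0 && bidb == gridDim.x - 1 && bidh == h - 1) {
        printf("last block: bidb=%d bidh=%d bidx=%d\n", bidb, bidh, bidx);
    }
}

int main() {
    const int b = 100000;  // hypothetical batch size > 64K; would not fit in grid.y
    const int h = 16;      // hypothetical number of heads
    dim3 grid(b, h);       // batch on x, head on y, mirroring the swapped launch
    index_demo<<<grid, 32>>>(h);
    cudaDeviceSynchronize();
    return 0;
}

Compile with nvcc; on devices of compute capability 3.0 and later the grid.x limit is 2^31 - 1, so the batch dimension is no longer the bottleneck.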
@@ -456,9 +456,9 @@ struct Gmem_summary_stats {
         : ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
 
         // The block index for the batch.
-        const int bidb = blockIdx.y;
+        const int bidb = blockIdx.x;
         // The block index for the head.
-        const int bidh = blockIdx.x;
+        const int bidh = blockIdx.y;
         // The block index.
         // size_t bidx = bidb * params.h + bidh;
         uint32_t bidx = bidb * params.h + bidh;
@@ -45,7 +45,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-    dim3 grid(params.h, params.b);
+    dim3 grid(params.b, params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -118,9 +118,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
 
 
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -729,9 +729,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
     constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
 
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -68,7 +68,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
         return;
     }
 
-    dim3 grid(launch_params.params.h, launch_params.params.b);
+    dim3 grid(launch_params.params.b, launch_params.params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
 
@@ -497,9 +497,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
 inline __device__ void device_block_1xN_loop(const Params &params) {
 
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -44,7 +44,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-    dim3 grid(params.h, params.b);
+    dim3 grid(params.b, params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
@@ -119,9 +119,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
 
 
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -683,9 +683,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
     constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
 
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
 
@@ -68,7 +68,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
         return;
     }
 
-    dim3 grid(launch_params.params.h, launch_params.params.b);
+    dim3 grid(launch_params.params.b, launch_params.params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
 
@@ -621,9 +621,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
 inline __device__ void device_1xN_loop(const Params &params) {
 
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
 