Don't need to run configure for the forward pass
parent 7fc39832e2
commit 871db47941
@@ -294,7 +294,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      is_causal,
                      num_splits);
 
-    run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -307,7 +306,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    run_fmha_fp16_sm80(launch_params, /*configure=*/false);
+    run_fmha_fp16_sm80(launch_params);
 
     std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
@@ -453,9 +452,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
 
-    // We're gonna reset the rng state in Python after this kernel, so the counter offset
-    // here doesn't matter at all. We just choose an arbitrary number.
-    int64_t counter_offset = 4;
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;
 
     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
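With the configure pre-pass gone, the philox counter offset is sized directly from the problem shape instead of being measured by an extra launch. A minimal sketch of that arithmetic, assuming a hypothetical batch size of 8 and 16 heads (values chosen only for illustration, not taken from the diff):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical shapes; in the extension these come from params.b and params.h.
    int64_t batch_size = 8;
    int64_t num_heads = 16;
    // The custom RNG advances the philox offset by batch_size * nheads * 32,
    // matching the comment in the hunk above.
    int64_t counter_offset = batch_size * num_heads * 32;
    std::printf("counter_offset = %lld\n", static_cast<long long>(counter_offset));
    return 0;
}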
@@ -191,7 +191,7 @@ struct Launch_params{
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
 
@@ -65,22 +65,10 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
 }
 
 template<typename Kernel_traits>
-void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
-                              const bool configure) {
+void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
 
-    if (configure) {
-        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
-        constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
-        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
-        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
-        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
-        launch_params.elts_per_thread = elts_per_head;
-        return;
-    }
-
     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
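For reference, the deleted configure branch above only estimated how many random numbers each thread draws. A standalone sketch of that elts_per_head formula, with the tile constants (M, blocksize_c, MMAS_M, MMAS_N) set to illustrative placeholders rather than values read from Kernel_traits:

#include <cstddef>
#include <cstdio>

// Re-expresses the removed formula: STEPS * MMAS_M * MMAS_N * 8 * loop_steps.
size_t elts_per_head(int seqlen_q, int seqlen_k, int M, int blocksize_c,
                     size_t MMAS_M, size_t MMAS_N) {
    size_t steps = (seqlen_q + M - 1) / M;                        // query tiles per head
    int loop_steps = (seqlen_k + blocksize_c - 1) / blocksize_c;  // key/value tiles
    return steps * MMAS_M * MMAS_N * 8 * loop_steps;              // factor of 8 carried over from the original formula
}

int main() {
    // Placeholder tile sizes, for illustration only.
    std::printf("%zu\n", elts_per_head(/*seqlen_q=*/512, /*seqlen_k=*/512,
                                       /*M=*/16, /*blocksize_c=*/128,
                                       /*MMAS_M=*/1, /*MMAS_N=*/4));
    return 0;
}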
@@ -123,38 +111,37 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     });
 }
 
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
-                        const bool configure) {
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
     FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k == 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else {
                 // TD [2022-05-15] 512 gives wrong results rn
                 // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 32) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 64) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 128) {
             // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
@@ -166,30 +153,30 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             // For causal=True, block size 128 seems always faster (for small & large batch size).
             // So we're just gonna use block size 128 for simplicity.
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         }
         // if (launch_params.params.d == 64) {
         //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
         //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         // }
         // if (launch_params.params.d == 64) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //     } else if( launch_params.params.seqlen_k >= 256 ) {
         //         if (dprops->major == 8 && dprops->minor >= 0) {
         //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //         } else if (dprops->major == 7 && dprops->minor == 5) {
         //             if (launch_params.is_dropout) { // Need to use the same block size as backward
         //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //             } else {
         //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //             }
         //         }
         //     }
@@ -197,16 +184,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
         // if (launch_params.params.d == 128) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //     } else {
         //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
         //             // TD [2022-06-05] Keep K in registers to reduce register spilling
         //             // Gives about 6% speedup compared to using block size 128.
         //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //         } else { // Need to use the same block size as backward
         //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //         }
         //     }
         // }
 
@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -250,6 +250,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // The thread index.
     const int tidx = threadIdx.x;
 
+    // How many steps to jump per iteration, which is the same as params.num_splits.
+    const int step_stride = gridDim.z;
+
     const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
     // if( binfo.stop_early() ) return;
     if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
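The two hunks above move the step stride out of the argument list: device_1xN_ now reads it from gridDim.z itself. A minimal sketch of that pattern in an unrelated, hypothetical kernel (stride_demo_kernel and its launch configuration are illustrative, not part of the FMHA code):

#include <cstdio>
#include <cuda_runtime.h>

// Each z-plane of the grid starts at its own step and jumps by gridDim.z,
// the same quantity the launcher chooses as the number of splits.
__global__ void stride_demo_kernel(int total_steps) {
    const int step_stride = gridDim.z;
    for (int step = blockIdx.z; step < total_steps; step += step_stride) {
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            printf("z-block %d handles step %d\n", blockIdx.z, step);
        }
    }
}

int main() {
    dim3 grid(1, 1, 4);                   // gridDim.z = 4 "splits" (illustrative choice)
    stride_demo_kernel<<<grid, 32>>>(16); // 16 steps shared across the 4 z-planes
    cudaDeviceSynchronize();
    return 0;
}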
@@ -683,14 +686,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
 
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
     }
 }