Speed up compilation by splitting into separate .cu files

This commit is contained in:
Tri Dao 2022-11-25 16:29:17 -08:00
parent b784ed73cf
commit d95ee1a95d
13 changed files with 251 additions and 318 deletions

View File

@ -176,6 +176,16 @@ void set_params_dgrad(FMHA_dgrad_params &params,
params.dsoftmax_sum = dsoftmax_sum_d;
}
void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
if (launch_params.params.d <= 32) {
run_fmha_fwd_hdim32(launch_params);
} else if (launch_params.params.d <= 64) {
run_fmha_fwd_hdim64(launch_params);
} else if (launch_params.params.d <= 128) {
run_fmha_fwd_hdim128(launch_params);
}
}
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@ -307,13 +317,22 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
run_fmha_fp16_sm80(launch_params);
run_fmha_fwd(launch_params);
std::vector<at::Tensor> result = {softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
}
void run_fmha_bwd(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
if (params.d <= 32) {
run_fmha_bwd_hdim32(params, stream, configure);
} else if (params.d <= 64) {
run_fmha_bwd_hdim64(params, stream, configure);
} else if (params.d <= 128) {
run_fmha_bwd_hdim128(params, stream, configure);
}
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
@ -341,7 +360,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
TORCH_CHECK(is_sm8x || is_sm75);
auto launch = &run_fmha_dgrad_fp16_sm80;
auto launch = &run_fmha_bwd;
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();

View File

@ -195,9 +195,13 @@ struct Launch_params{
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

View File

@ -0,0 +1,13 @@
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimentions to different files to speed up compilation.
#include "fmha_bwd_launch_template.h"
void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
// work around for MSVC issue
FP16_SWITCH(params.is_bf16, [&] {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
});
}

View File

@ -0,0 +1,18 @@
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimentions to different files to speed up compilation.
#include "fmha_bwd_launch_template.h"
void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
// work around for MSVC issue
FP16_SWITCH(params.is_bf16, [&] {
if (params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
} else if (params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
}
});
}

View File

@ -0,0 +1,31 @@
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimentions to different files to speed up compilation.
#include "fmha_bwd_launch_template.h"
void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
// work around for MSVC issue
FP16_SWITCH(params.is_bf16, [&] {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
} else if (params.seqlen_k >= 256) {
if (dprops->major == 8 && dprops->minor == 0) {
// Don't share smem for K & V, and don't keep V in registers
// This speeds things up by 2-3% by avoiding register spills, but it
// uses more shared memory, which is fine on A100 but not other GPUs.
// For other GPUs, we keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
} else if (dprops->major == 8 && dprops->minor > 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
} else if (dprops->major == 7 && dprops->minor == 5) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
}
}
});
}

View File

@ -1,5 +1,6 @@
/* Copyright (c) 2022, Tri Dao.
*/
// Copyright (c) 2022, Tri Dao.
#pragma once
#include "static_switch.h"
#include "fp16_switch.h"
@ -9,7 +10,7 @@
// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
// Parallelizing will have better occupancy, but has some overhead due to having to zero out
// dq_tmp and having to copy dq_tmp to dq.
int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
int blocksize, bool is_causal) {
float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
float eff_1 = n_waves_1 / ceil(n_waves_1);
@ -29,22 +30,22 @@ int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int
}
template<typename Kernel_traits>
__global__ void fmha_dgrad_dot_do_o_kernel(FMHA_dgrad_params params) {
__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
fmha::compute_dot_do_o<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
}
template<typename Kernel_traits>
void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@ -63,20 +64,20 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
auto kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
if (params.seqlen_k == blocksize_c) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
}
auto kernel_seqparallel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
: &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@ -104,7 +105,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
} else {
dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
fmha_dgrad_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c
dim3 grid(params.b, params.h, num_splits);
kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
@ -112,42 +113,3 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
// work around for MSVC issue
FP16_SWITCH(params.is_bf16, [&] {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (params.d <= 32) {
if (params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if (params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
} else if (params.d <= 64) {
if (params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if (params.seqlen_k >= 256) {
if (dprops->major == 8 && dprops->minor == 0) {
// Don't share smem for K & V, and don't keep V in registers
// This speeds things up by 2-3% by avoiding register spills, but it
// uses more shared memory, which is fine on A100 but not other GPUs.
// For other GPUs, we keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if (dprops->major == 8 && dprops->minor > 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if (dprops->major == 7 && dprops->minor == 5) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
}
} else if (params.d <= 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
});
}

View File

@ -1,153 +0,0 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "static_switch.h"
#include "fp16_switch.h"
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 95%
// of the best efficiency.
int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}
template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
auto kernel = launch_params.params.is_causal
? (launch_params.return_softmax
? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
: &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
: (launch_params.return_softmax
? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
: &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// Automatically set num_splits to maximize occupancy
if (launch_params.params.num_splits <= 0) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
auto dprops = at::cuda::getCurrentDeviceProperties();
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
constexpr int M = Kernel_traits::Cta_tile_p::M;
launch_params.params.num_splits = num_splits_heuristic_fwd(
launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
ctas_per_sm,
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
);
}
// printf("smem_size = %d\n", smem_size);
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
FP16_SWITCH(launch_params.params.is_bf16, [&] {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (launch_params.params.d <= 32) {
if (launch_params.params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
} else if (launch_params.params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
}
} else if (launch_params.params.d <= 64) {
if (launch_params.params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
} else if (launch_params.params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
}
} else if (launch_params.params.d <= 128) {
// TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
// to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
// reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
// some speedup (6-10%) for large batch size, but slows things down for smal batch size.
// Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
// batch size and only slightly slower (~3%) on large batch size.
// For causal=True, block size 128 seems always faster (for small & large batch size).
// So we're just gonna use block size 128 for simplicity.
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
}
// if (launch_params.params.d == 64) {
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
// // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// }
});
}

View File

@ -0,0 +1,12 @@
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimentions to different files to speed up compilation.
#include "fmha_fwd_launch_template.h"
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
FP16_SWITCH(launch_params.params.is_bf16, [&] {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fwd_loop<Kernel_traits>(launch_params);
});
}

View File

@ -0,0 +1,17 @@
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimentions to different files to speed up compilation.
#include "fmha_fwd_launch_template.h"
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
FP16_SWITCH(launch_params.params.is_bf16, [&] {
if (launch_params.params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fwd_loop<Kernel_traits>(launch_params);
} else if (launch_params.params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fwd_loop<Kernel_traits>(launch_params);
}
});
}

View File

@ -0,0 +1,17 @@
// Copyright (c) 2022, Tri Dao.
// Splitting the different head dimentions to different files to speed up compilation.
#include "fmha_fwd_launch_template.h"
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
FP16_SWITCH(launch_params.params.is_bf16, [&] {
if (launch_params.params.seqlen_k == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fwd_loop<Kernel_traits>(launch_params);
} else if (launch_params.params.seqlen_k >= 256) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fwd_loop<Kernel_traits>(launch_params);
}
});
}

View File

@ -0,0 +1,92 @@
// Copyright (c) 2022, Tri Dao.
#pragma once
#include <vector>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "static_switch.h"
#include "fp16_switch.h"
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 95%
// of the best efficiency.
// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error.
inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) {
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}
template<typename Kernel_traits>
void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
auto kernel = launch_params.params.is_causal
? (launch_params.return_softmax
? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
: &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
: (launch_params.return_softmax
? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
: &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// Automatically set num_splits to maximize occupancy
if (launch_params.params.num_splits <= 0) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
auto dprops = at::cuda::getCurrentDeviceProperties();
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
constexpr int M = Kernel_traits::Cta_tile_p::M;
launch_params.params.num_splits = num_splits_heuristic_fwd(
launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
ctas_per_sm,
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
);
}
// printf("smem_size = %d\n", smem_size);
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}

View File

@ -75,107 +75,4 @@ struct BlockInfoPadded {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
template<typename Block_info>
inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
: bidc_(bidc) {
const int seqlen = binfo.actual_seqlen;
const int steps = (seqlen + STEP - 1) / STEP;
const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
const int step_begin = bidc_ * steps_per_chunk;
const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
const int actual_steps = max(0, step_end - step_begin);
loop_offset_ = step_begin;
num_steps_ = actual_steps;
}
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const uint32_t bidc_;
int loop_offset_;
int num_steps_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits>
std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
const int num_full_heads = heads_total / total_ctas;
const int heads_last_wave = heads_total % total_ctas;
int num_main_groups = 0;
int main_steps = 0;
int rest_steps = 0;
if( heads_last_wave > 0 ) {
// Number of CTA groups that process within heads.
num_main_groups = total_ctas / heads_last_wave;
// Remaining CTAs that process between heads.
const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
if(rest_ctas == 0) {
// We have exactly "num_main_groups" CTAs to process each of the remaining heads.
main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0
rest_steps = STEPS_PER_HEAD % main_steps;
} else {
// Ideal number of steps if we could load-balance as evenly as possible.
const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
// Iterations that a "rest" CTA has to do at most.
const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
// Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
main_steps = steps_ideal;
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
const int max_rest_total_steps = rest_steps * max_rest_iters;
if( max_rest_total_steps < main_steps )
break;
}
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
}
}
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
const int elts_per_thread = max_steps * elts_per_thread_per_step;
return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha

View File

@ -119,8 +119,12 @@ ext_modules.append(
name="flash_attn_cuda",
sources=[
"csrc/flash_attn/fmha_api.cpp",
"csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
"csrc/flash_attn/src/fmha_fwd_hdim32.cu",
"csrc/flash_attn/src/fmha_fwd_hdim64.cu",
"csrc/flash_attn/src/fmha_fwd_hdim128.cu",
"csrc/flash_attn/src/fmha_bwd_hdim32.cu",
"csrc/flash_attn/src/fmha_bwd_hdim64.cu",
"csrc/flash_attn/src/fmha_bwd_hdim128.cu",
"csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
],
@ -152,7 +156,7 @@ ext_modules.append(
setup(
name="flash_attn",
version="0.2.1",
version="0.2.2",
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
),