flash-attention/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu
2022-05-26 13:57:38 -07:00

90 lines
5.2 KiB
Plaintext

/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_block_fprop_kernel_1xN.h"
// Kernel entry point for the block forward pass. It does no work of its own:
// it forwards `params` to fmha::device_block_1xN_loop, instantiated with the
// same four template arguments this kernel was instantiated with.
//
// Template arguments:
//   Kernel_traits  - tile/thread configuration type passed through to the device routine.
//   Is_dropout, Is_causal, Return_softmax - compile-time switches passed through unchanged;
//       presumably they enable dropout, causal masking and writing back the softmax
//       matrix respectively -- confirm against fmha_block_fprop_kernel_1xN.h.
//
// In this file the kernel is launched by run_fmha_block_fp16_sm80_loop_ with a
// (params.h, params.b) grid, Kernel_traits::THREADS threads per block and a
// dynamic shared-memory size computed on the host.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_block_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}
// Host-side driver for fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, ...>.
//
// Steps, in order:
//   1. Pick the kernel instantiation matching the runtime dropout / causal /
//      return-softmax flags.
//   2. Compute the dynamic shared-memory size and, when it reaches 48 KiB, raise the
//      kernel's dynamic shared-memory limit to that size.
//   3. If `configure` is true: store the per-thread element count in
//      launch_params.elts_per_thread and return WITHOUT launching.
//   4. Otherwise launch on launch_params.stream with a (h, b) grid and check for a
//      launch error.
//
// Parameters:
//   launch_params - launch description; read for the flags, params and stream, and
//                   written (elts_per_thread) only in configure mode.
//   configure     - true: size-query only; false: actually launch the kernel.
template<typename Kernel_traits>
void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
                                    const bool configure) {
    using Kernel_fn = void (*)(Fused_multihead_attention_fprop_params);

    // All eight instantiations, indexed by a 3-bit code:
    //   bit 2 = dropout, bit 1 = causal, bit 0 = return softmax.
    const Kernel_fn kernel_table[8] = {
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>,
        &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true>,
    };

    bool is_causal = launch_params.params.is_causal;
    const int variant = (launch_params.is_dropout ? 4 : 0)
                      | (is_causal ? 2 : 0)
                      | (launch_params.return_softmax ? 1 : 0);
    Kernel_fn kernel = kernel_table[variant];

    // Number of N-wide tiles needed to cover the sequence length s (rounded up).
    constexpr int kTileN = Kernel_traits::Cta_tile_p::N;
    const int loop_steps = (launch_params.params.s + kTileN - 1) / kTileN;

    // The extra Smem_dp_sum tile is only required when the kernel iterates more than once.
    constexpr int kLseBytes = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    const int extra_bytes = loop_steps > 1 ? kLseBytes : 0;
    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>() + extra_bytes;

    // Large dynamic shared-memory requests must be opted into per kernel.
    if (smem_size >= 48 * 1024) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    if (configure) {
        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr int kTileM = Kernel_traits::Cta_tile_p::M;
        constexpr size_t kMmasM = Mma_tile_p::MMAS_M;
        constexpr size_t kMmasN = Mma_tile_p::MMAS_N;
        // Number of M-tall tiles needed to cover the sequence length s (rounded up).
        const size_t row_steps = (launch_params.params.s + kTileM - 1) / kTileM;
        // NOTE(review): the factor 8 is carried over unchanged; presumably it is the
        // number of elements each thread handles per MMA -- confirm.
        launch_params.elts_per_thread = row_steps * kMmasM * kMmasN * 8 * loop_steps;
        return;
    }

    // grid.x = params.h, grid.y = params.b.
    const dim3 grid(launch_params.params.h, launch_params.params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params);
    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
// Public entry point: selects the FMHA_kernel_traits specialisation whose second
// template argument equals launch_params.params.d (16, 32 or 64) and forwards to
// run_fmha_block_fp16_sm80_loop_ with the same arguments.
//
// Any other value of d is a silent no-op: nothing is configured and nothing is
// launched, so callers are expected to validate d beforehand.
void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
                              const bool configure) {
    switch (launch_params.params.d) {
        case 16: {
            using Traits_d16 = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
            run_fmha_block_fp16_sm80_loop_<Traits_d16>(launch_params, configure);
            break;
        }
        case 32: {
            using Traits_d32 = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
            run_fmha_block_fp16_sm80_loop_<Traits_d32>(launch_params, configure);
            break;
        }
        case 64: {
            using Traits_d64 = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
            run_fmha_block_fp16_sm80_loop_<Traits_d64>(launch_params, configure);
            break;
        }
        default:
            // Unsupported d: intentionally do nothing.
            break;
    }
}