Simplify writing softmax to gmem

Tri Dao 2024-01-13 00:25:04 -08:00
parent 8d1b169ed1
commit a7b66ae25a
4 changed files with 46 additions and 107 deletions

csrc/flash_attn/src/flash_fwd_kernel.h

@@ -56,23 +56,6 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
@@ -92,6 +75,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
auto seeds = at::cuda::philox::unpack(params.philox_args);
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
// Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
// exit early and no one saves the rng states.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
params.rng_state[0] = seed;
params.rng_state[1] = std::get<1>(seeds);
}
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@@ -107,13 +101,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV.
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
auto seeds = at::cuda::philox::unpack(params.philox_args);
params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
}
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
@@ -188,8 +175,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -197,7 +182,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -205,6 +189,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor tSgS = thr_mma.partition_C(gP);
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
//
@@ -310,16 +296,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}
auto seeds = at::cuda::philox::unpack(params.philox_args);
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
// Save seed and offset for backward.
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
params.rng_state[0] = seed;
params.rng_state[1] = std::get<1>(seeds);
}
clear(acc_o);
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
@@ -429,14 +405,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
cute::copy(tOrP, tOrP_copy);
Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
cute::copy(acc_s_f16, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
@@ -514,14 +490,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
cute::copy(tOrP, tOrP_copy);
Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
cute::copy(acc_s_f16, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
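For reference, the deleted write_softmax_to_gmem helper had to regroup the register fragment from (8, MMA_M, MMA_N) into (8, (MMA_M, MMA_N)) before looping its dedicated tiled copy over the folded mode; the simplified path above skips that entirely, since tSgS = thr_mma.partition_C(gP) already matches the accumulator layout and a single cute::copy(acc_s_f16, tSgS) suffices. The snippet below is a host-only sketch of that removed reshape, not code from this repo; it assumes the CUTLASS CuTe headers are on the include path and uses made-up extents MMA_M = 2, MMA_N = 4.

// Host-only CuTe layout sketch (assumption: CUTLASS's cute/ headers are available, C++17).
// It reproduces the mode folding from the removed helper: (8, MMA_M, MMA_N) -> (8, (MMA_M, MMA_N)),
// which is what let the old gmem tiled copy iterate a single combined mode.
#include <cute/layout.hpp>
#include <cstdio>

int main() {
    using namespace cute;
    auto l = make_layout(make_shape(Int<8>{}, Int<2>{}, Int<4>{}));            // stand-in for tOrP.layout()
    auto grouped = make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)));  // (8, (MMA_M, MMA_N))
    print(l);       std::printf("\n");
    print(grouped); std::printf("\n");
    // Only the grouping of modes changes; the element count stays the same.
    static_assert(decltype(size(grouped) == size(l))::value, "regrouping modes preserves the element count");
    return 0;
}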

csrc/flash_attn/src/kernel_traits.h

@@ -106,10 +106,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
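As a quick sanity check on the byte counts now folded inline above, here is a small standalone sketch with an assumed fp16 tile of kBlockM = kBlockN = 128 and kHeadDim = 64 (illustrative numbers, not taken from this diff); the real traits compute the element counts with cute::size(SmemLayoutQ{}) and cute::size(SmemLayoutKV{}) rather than hand-written products.

// Hedged sketch of the forward kernel's shared-memory accounting for one assumed config.
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int kBlockM = 128, kBlockN = 128, kHeadDim = 64;        // assumed tile shape
    constexpr int kElemBytes = 2;                                     // fp16 / bf16 element
    constexpr int kSmemQSize  = kBlockM * kHeadDim * kElemBytes;      // sQ tile: 16384 B
    constexpr int kSmemKVSize = 2 * kBlockN * kHeadDim * kElemBytes;  // sK + sV tiles: 32768 B
    constexpr bool Share_Q_K_smem = false;
    constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize)
                                             : kSmemQSize + kSmemKVSize;
    std::printf("Q: %d B, KV: %d B, total: %d B\n", kSmemQSize, kSmemKVSize, kSmemSize);
    return 0;
}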
@@ -139,15 +137,6 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
@@ -285,16 +274,12 @@ struct Flash_bwd_kernel_traits : public Base {
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)

flash_attn/flash_attn_interface.py

@@ -12,7 +12,7 @@ import flash_attn_2_cuda as flash_attn_cuda
# isort: on
def _get_block_size(device, head_dim, is_dropout, is_causal):
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
assert head_dim <= 256
major, minor = torch.cuda.get_device_capability(device)
@@ -20,27 +20,27 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
is_sm80 = major == 8 and minor == 0
is_sm90 = major == 9 and minor == 0
if head_dim <= 32:
return 128, 128
return 128
if head_dim <= 64:
return (128, 128) if not is_dropout else (128, 64)
return 128 if not is_dropout else 64
elif head_dim <= 96:
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
return 64
elif head_dim <= 128:
if is_sm8x:
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
return 64 if (not is_dropout and is_causal) else 32
else:
return 128, (64 if not is_dropout else 32)
return 64 if not is_dropout else 32
elif head_dim <= 160:
if is_sm8x:
return (128, 64) if not is_causal else (64, 64)
return 64
else:
return 128, 32
return 32
elif head_dim <= 192:
return (128, 64) if not is_dropout else (64, 64)
return 64
elif head_dim <= 224:
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
return 64
elif head_dim <= 256:
return (128, 64) if is_sm80 else (64, 64)
return 64
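A hedged usage sketch of the renamed helper (the head_dim, tensor shapes, and CUDA device below are assumptions for illustration): callers now get back a single n-direction block size instead of an (m, n) pair, which is how the test code further down consumes it.

# Sketch only: _get_block_size_n returns just the key-dimension block size.
import torch
from flash_attn.flash_attn_interface import _get_block_size_n

device = torch.device("cuda")
block_size_n = _get_block_size_n(device, head_dim=128, is_dropout=False, is_causal=True)
scores = torch.randn(2, 4, 128, 512, device=device)  # (batch, nheads, seqlen_q, seqlen_k)
scores_block = scores.split(block_size_n, dim=-1)     # one chunk per kernel n-block
print(block_size_n, [s.shape[-1] for s in scores_block])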
def _flash_attn_forward(

tests/test_flash_attn.py

@@ -14,7 +14,7 @@ from flash_attn import (
flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
@@ -406,29 +406,7 @@ def convert_flash_attn_S_to_softmax(
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(
S,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
S_converted = S
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
@@ -443,7 +421,7 @@ def convert_flash_attn_S_to_softmax(
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted.masked_fill_(local_mask, 0.0)
S_converted = S_converted.masked_fill(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
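A small illustration of the switch above from the in-place masked_fill_ to the out-of-place masked_fill; my reading is that S_converted now aliases the input S rather than a freshly rearranged copy, so the out-of-place call keeps the caller's tensor untouched (shapes below are arbitrary).

# Sketch: masked_fill returns a new tensor, so the original S is not mutated.
import torch

S = torch.randn(4, 4)
mask = torch.eye(4, dtype=torch.bool)
S_converted = S.masked_fill(mask, 0.0)   # diagonal zeroed in the copy only
assert (S_converted.diagonal() == 0).all()
assert not (S.diagonal() == 0).all()     # the input keeps its values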
@@ -504,7 +482,7 @@ def normalize_flash_attn_S(
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias.to(dtype=scores.dtype)
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
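For context, the two lines above rely on the block decomposition of logsumexp: the logsumexp of a full row equals the logsumexp of the per-block logsumexps, so splitting the key dimension into kernel-sized blocks leaves the softmax normalizer unchanged. A short standalone check with arbitrary shapes:

# Standalone numerical check of the blockwise logsumexp identity used above.
import torch

scores = torch.randn(2, 4, 128, 512)          # (batch, nheads, seqlen_q, seqlen_k)
block_size_n = 128
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
assert torch.allclose(lse, torch.logsumexp(scores, dim=-1), atol=1e-5)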