Change causal mask to be aligned to bottom-right instead of top-left
parent e07aa036db
commit 9e5e8bc91e

README.md (26 changed lines)
@@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```

## Changes in v2.1 (compared to v2.0)

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
v2.0:
1 0 0 0 0
1 1 0 0 0
v2.1:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
1 0
1 1
1 1
1 1
1 1
v2.1:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
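As a quick reference, the v2.1 alignment can be reproduced in a few lines of PyTorch. This is a minimal sketch (the helper name is ours, for illustration); it follows the same rule as the `construct_causal_mask` helper added to the tests in this commit, except that here 1/True means "keep", matching the tables above:

```python
import torch

def bottom_right_causal_mask(seqlen_q, seqlen_k):
    # Keep (q, k) iff k <= q + seqlen_k - seqlen_q, i.e. the causal diagonal
    # ends at the bottom-right corner of the (seqlen_q, seqlen_k) matrix.
    row = torch.arange(seqlen_q).unsqueeze(-1)
    col = torch.arange(seqlen_k)
    return col <= row + seqlen_k - seqlen_q

print(bottom_right_causal_mask(2, 5).int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
print(bottom_right_causal_mask(5, 2).int())
# tensor([[0, 0],
#         [0, 0],
#         [0, 0],
#         [1, 0],
#         [1, 1]])
```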
## Performance
@ -15,12 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
|
||||
# from triton.ops.flash_attention import attention as attention_triton
|
||||
|
||||
try:
|
||||
from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
|
||||
from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
|
||||
except ImportError:
|
||||
fav2_qkvpacked_func = None
|
||||
fav2_kvpacked_func = None
|
||||
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
|
||||
|
||||
try:
|
||||
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
|
||||
@ -80,8 +75,8 @@ def attention_megatron(qkv):
|
||||
|
||||
torch.manual_seed(0)
|
||||
repeats = 30
|
||||
batch_size = 2
|
||||
seqlen = 8192
|
||||
batch_size = 8
|
||||
seqlen = 2048
|
||||
nheads = 12
|
||||
headdim = 128
|
||||
# nheads = 24
|
||||
@ -90,8 +85,8 @@ headdim = 128
|
||||
# seqlen = 512
|
||||
# nheads = 8
|
||||
# headdim = 128
|
||||
dropout_p = 0.1
|
||||
causal = False
|
||||
dropout_p = 0.0
|
||||
causal = True
|
||||
dtype = torch.float16
|
||||
device = 'cuda'
|
||||
|
||||
@ -100,20 +95,20 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
|
||||
# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
|
||||
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
|
||||
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
|
||||
# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
|
||||
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
|
||||
# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
|
||||
# if fav2_qkvpacked_func is not None:
|
||||
# benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
|
||||
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
|
||||
benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
|
||||
pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
|
||||
|
||||
# for dropout_p in [0.1, 0.0]:
|
||||
# for causal in [False, True]:
|
||||
# print(f"### {dropout_p = }, {causal = } ###")
|
||||
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
|
||||
|
||||
|
||||
# nheads_k = 2
|
||||
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
||||
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
|
||||
@ -151,6 +146,7 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch
|
||||
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
|
||||
ideal_a100_time = flops / 312 / 1e9
|
||||
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
|
||||
exit(0)
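For reference, here is the ideal-time arithmetic from the benchmark spelled out for the settings used above (batch_size = 8, seqlen = 2048, nheads = 12, headdim = 128). Dividing by 312 / 1e9 treats the A100's roughly 312 TFLOPS fp16/bf16 peak as 312e9 FLOPs per millisecond:

```python
# Worked example of the ideal A100 timing printed above.
batch_size, seqlen, nheads, headdim = 8, 2048, 12, 128

# QK^T and PV are each 2 * seqlen^2 * headdim FLOPs per head per batch element,
# giving the 4 * b * s^2 * h * d total used by the benchmark (the full matrix,
# with no reduction for causal masking).
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim   # ~2.06e11 FLOPs

ideal_fwd_ms = flops / 312 / 1e9      # FLOPs / (FLOPs per ms) ~= 0.661 ms
ideal_bwd_ms = ideal_fwd_ms * 2.5     # backward does roughly 2.5x the matmul work
print(f"Ideal A100 fwd time: {ideal_fwd_ms:.3f}ms, bwd time: {ideal_bwd_ms:.3f}ms")
# Ideal A100 fwd time: 0.661ms, bwd time: 1.652ms
```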
|
||||
|
||||
|
||||
def time_fwd_bwd(func, *args, **kwargs):
|
||||
|
||||
@ -32,8 +32,8 @@ struct BlockInfo {
|
||||
|
||||
const int sum_s_q;
|
||||
const int sum_s_k;
|
||||
const uint32_t actual_seqlen_q;
|
||||
const uint32_t actual_seqlen_k;
|
||||
const int actual_seqlen_q;
|
||||
const int actual_seqlen_k;
|
||||
};
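One detail worth noting: actual_seqlen_q and actual_seqlen_k switch from uint32_t to int, presumably because the bottom-right alignment makes the kernels form differences such as actual_seqlen_k - actual_seqlen_q, which can be negative when seqlen_q > seqlen_k; with unsigned 32-bit arithmetic that difference would wrap around. A tiny illustration of the wraparound (plain Python, not part of the commit):

```python
seqlen_q, seqlen_k = 5, 2
print(seqlen_k - seqlen_q)             # -3: the value the signed (int) mask math needs
print((seqlen_k - seqlen_q) % 2**32)   # 4294967293: what uint32_t subtraction would yield
```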
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;
|
||||
|
||||
int m_block = m_block_max - 1;
|
||||
int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
|
||||
m_block_min = m_block_min < 0 ? 0 : m_block_min;
|
||||
|
||||
// We might need to exit early and write 0 to dK and dV.
|
||||
// Otherwise we get wrong result for the case where we don't enter the for loop.
|
||||
// And we might read OOB elements from gQ and gdO.
|
||||
// TODO: what if we're not parallelizing, do we need to compute dot_do_o?
|
||||
if (Is_causal && m_block < m_block_min) {
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
||||
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
||||
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
|
||||
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
|
||||
clear(tdKrdK);
|
||||
clear(tdVrdV);
|
||||
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
||||
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
return;
|
||||
}
|
||||
int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
|
||||
// We're guaranteed that m_block_min <= m_block:
|
||||
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
|
||||
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
|
||||
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
|
||||
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
|
||||
// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
|
||||
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
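The argument in the comments above can be sanity-checked on the host with a few lines of Python (a throwaway sketch; kBlockM and kBlockN below are just example tile sizes):

```python
def ceil_div(a, b):
    return (a + b - 1) // b

kBlockM, kBlockN = 64, 128   # example tile sizes
for seqlen_q in range(1, 257):
    for seqlen_k in range(1, 257):
        m_block_max = ceil_div(seqlen_q, kBlockM)
        for n_block in range(ceil_div(seqlen_k, kBlockN)):
            assert n_block * kBlockN < seqlen_k            # the precondition cited above
            m_block_min = max(0, (n_block * kBlockN + seqlen_q - seqlen_k) // kBlockM)
            # (C++ truncating division and Python floor division agree after the max(0, ...))
            assert m_block_min <= m_block_max - 1          # so the loop runs at least once
print("ok")
```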
|
||||
|
||||
if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
|
||||
tQsQ.data() = tQsQ.data() + size(sQ);
|
||||
@@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
// Using uint32_t row makes it 10us slower on d=128, not sure why.
|
||||
const int row = get<0>(taccScS_row(mi));
|
||||
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
|
||||
}
|
||||
@@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
|
||||
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
|
||||
// But we still want to mask out elements beyond actual_seqlen_k.
|
||||
if (m_block * kBlockM < (n_block + 1) * kBlockN
|
||||
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|
||||
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_q, binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_q,
|
||||
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
|
||||
AtomLayoutMS * 16);
|
||||
}
|
||||
@@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
// Compute the exponential value.
|
||||
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
|
||||
if (Is_dropout) {
|
||||
uint32_t warp_id = tidx / 32;
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
|
||||
int warp_id = tidx / 32;
|
||||
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
|
||||
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
|
||||
static_assert(MMA_N_SdP % 2 == 0);
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
|
||||
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
|
||||
Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
|
||||
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
|
||||
@@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
|
||||
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
// Using uint32_t row makes it 10us slower on d=128, not sure why.
|
||||
const int row = get<0>(taccScS_row(mi));
|
||||
lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
|
||||
}
|
||||
@@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
|
||||
// the corresponding values of K would be 0, so the result would still be correct.
|
||||
if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
AtomLayoutMS * 16);
|
||||
}
|
||||
// Compute the exponential value.
|
||||
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
|
||||
if (Is_dropout) {
|
||||
uint32_t warp_id = tidx / 32;
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
|
||||
int warp_id = tidx / 32;
|
||||
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
|
||||
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
|
||||
static_assert(MMA_N_SdP % 2 == 0);
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
|
||||
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
|
||||
Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
|
||||
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
|
||||
|
||||
@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
@@ -130,8 +130,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
// The global block index.
|
||||
const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
@@ -139,16 +137,60 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
|
||||
|
||||
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
||||
if (Is_causal) {
|
||||
n_block_max = std::min(n_block_max, cute::ceil_div(
|
||||
(m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
|
||||
n_block_max = std::min(n_block_max,
|
||||
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
||||
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
||||
// }
|
||||
// We exit early and write 0 to gO and gLSE.
|
||||
// Otherwise we might read OOB elements from gK and gV.
|
||||
if (n_block_max <= 0) {
|
||||
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
|
||||
// exit early and no one saves the rng state.
|
||||
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
||||
clear(tOrO);
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tOgO); ++m) {
|
||||
const int row = get<0>(tOcO(0, m, 0));
|
||||
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
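With the bottom-right alignment, an entire query block can sit in the all-zero rows of the mask; that is exactly when n_block_max clamps to zero and this early-exit path writes zero output and LSE = inf. A small host-side sketch of the n_block_max arithmetic (example tile sizes; plain Python, not the kernel code):

```python
def ceil_div(a, b):
    return (a + b - 1) // b

kBlockM = kBlockN = 128                  # example tile sizes
seqlen_q, seqlen_k = 1024, 2             # causal: only the last 2 query rows attend to anything

for m_block in range(ceil_div(seqlen_q, kBlockM)):
    n_block_max = min(ceil_div(seqlen_k, kBlockN),
                      ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN))
    if n_block_max <= 0:
        print(f"m_block {m_block}: early exit, O = 0, LSE = inf")
    else:
        print(f"m_block {m_block}: process {n_block_max} K/V block(s)")
# m_blocks 0-6 take the early exit; only m_block 7 (rows 896-1023) loads any K/V.
```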
|
||||
|
||||
// We iterate over the blocks in reverse order. This is because the last block is the only one
|
||||
@@ -275,8 +317,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
// // Copy rmem to smem
|
||||
@@ -298,8 +340,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||
flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
cute::cp_async_fence();
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
||||
// __syncthreads();
|
||||
@@ -317,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
// Save seed and offset for backward.
|
||||
if (block_id == 0 && tidx == 0) {
|
||||
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
||||
params.rng_state[0] = seed;
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
@@ -330,7 +372,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
|
||||
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
|
||||
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
|
||||
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
||||
constexpr int n_masking_steps = !Is_causal
|
||||
? 1
|
||||
: (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
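For example, with kBlockM = kBlockN = 128 and seqlen_k = 217, the last K block is only partially inside the sequence, so the loop needs masking for that partial block in addition to the block crossed by the causal diagonal; that is what the extra +1 buys when Is_even_MN is false. A tiny mirror of the expression (example tile sizes, plain Python):

```python
from math import ceil

def n_masking_steps(is_causal, is_even_mn, kBlockM=128, kBlockN=128):
    if not is_causal:
        return 1
    diag_blocks = ceil(kBlockM / kBlockN)                   # K blocks the causal diagonal can cross
    return diag_blocks if is_even_mn else diag_blocks + 1   # +1 for a partial last K block

print(n_masking_steps(is_causal=True, is_even_mn=True))    # 1
print(n_masking_steps(is_causal=True, is_even_mn=False))   # 2  (e.g. seqlen_k = 217)
```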
|
||||
#pragma unroll
|
||||
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
@@ -344,7 +390,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
@@ -363,7 +409,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
if (!Is_causal) {
|
||||
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
|
||||
@@ -376,9 +422,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
// Idk why it's get<1> and not get<0> of the stride.
|
||||
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
|
||||
// I can't get the stride from idx_row
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
|
||||
@@ -405,8 +452,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
int block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
@@ -468,8 +515,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
int block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
@@ -563,14 +610,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn(const Params &params) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
@@ -586,7 +633,7 @@ inline __device__ void compute_attn(const Params &params) {
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
|
||||
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -10,9 +10,9 @@
|
||||
#include "flash.h"
|
||||
#include "flash_fwd_kernel.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
@@ -26,17 +26,15 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_q as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
|
||||
@ -117,18 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
|
||||
const uint32_t col_idx_offset_ = 0) {
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
@ -141,28 +141,28 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
||||
const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
|
||||
const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset_,
|
||||
const int max_seqlen_q, const int warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||
const uint32_t row_idx_offset = row_idx_offset_;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
// const int row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||
const int row_idx_offset = row_idx_offset_;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const uint32_t row_idx = row_idx_base + i * 8;
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
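This col_idx_limit is the device-side counterpart of the bottom-right alignment documented in the README; a quick host-side check that it reproduces those mask tables (plain Python, not part of the kernel):

```python
def keep(row, col, seqlen_q, seqlen_k):
    # Mirrors: col_idx_limit = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q)
    return col < min(seqlen_k, row + 1 + seqlen_k - seqlen_q)

for sq, sk in [(2, 5), (5, 2)]:
    print(f"seqlen_q={sq}, seqlen_k={sk}")
    for r in range(sq):
        print(" ".join("1" if keep(r, c, sq, sk) else "0" for c in range(sk)))
# (2, 5) -> 1 1 1 1 0 / 1 1 1 1 1    (5, 2) -> 0 0 / 0 0 / 0 0 / 1 0 / 1 1
```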
|
||||
@ -180,7 +180,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
@ -189,7 +189,7 @@ inline __device__ void apply_mask_causal_w_idx(
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
|
||||
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
@ -207,8 +207,8 @@ inline __device__ void apply_mask_causal_w_idx(
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
|
||||
unsigned long long seed, unsigned long long offset,
|
||||
uint32_t block_row_start, uint32_t block_col_start,
|
||||
uint32_t block_row_stride) {
|
||||
int block_row_start, int block_col_start,
|
||||
int block_row_stride) {
|
||||
// tensor has shape (8, MMA_M, MMA_N / 2)
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
__version__ = "2.0.9"
|
||||
__version__ = "2.1.0"
|
||||
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_func,
|
||||
|
||||
@ -528,6 +528,18 @@ def flash_attn_kvpacked_func(
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
||||
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
||||
1 1 1 1 0
|
||||
1 1 1 1 1
|
||||
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
||||
0 0
|
||||
0 0
|
||||
0 0
|
||||
1 0
|
||||
1 1
|
||||
If the row of the mask is all zero, the output will be zero.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
kv: (batch_size, seqlen, 2, nheads_k, headdim)
|
||||
@ -559,6 +571,18 @@ def flash_attn_func(
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
||||
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
||||
1 1 1 1 0
|
||||
1 1 1 1 1
|
||||
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
||||
0 0
|
||||
0 0
|
||||
0 0
|
||||
1 0
|
||||
1 1
|
||||
If the row of the mask is all zero, the output will be zero.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k: (batch_size, seqlen, nheads_k, headdim)
|
||||
@ -645,6 +669,18 @@ def flash_attn_varlen_kvpacked_func(
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
||||
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
||||
1 1 1 1 0
|
||||
1 1 1 1 1
|
||||
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
||||
0 0
|
||||
0 0
|
||||
0 0
|
||||
1 0
|
||||
1 1
|
||||
If the row of the mask is all zero, the output will be zero.
|
||||
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
@ -703,6 +739,18 @@ def flash_attn_varlen_func(
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
||||
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
||||
1 1 1 1 0
|
||||
1 1 1 1 1
|
||||
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
||||
0 0
|
||||
0 0
|
||||
0 0
|
||||
1 0
|
||||
1 1
|
||||
If the row of the mask is all zero, the output will be zero.
|
||||
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
|
||||
@ -29,9 +29,11 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
|
||||
if mode == "full":
|
||||
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
|
||||
elif mode == "random":
|
||||
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
|
||||
lengths = torch.randint(
|
||||
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
|
||||
)
|
||||
elif mode == "third":
|
||||
lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
|
||||
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
|
||||
padding_mask = (
|
||||
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
|
||||
)
|
||||
@ -146,6 +148,23 @@ def generate_qkv(
|
||||
)
|
||||
|
||||
|
||||
def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
|
||||
device=None):
|
||||
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
||||
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
||||
sk = (
|
||||
seqlen_k
|
||||
if key_padding_mask is None
|
||||
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
|
||||
)
|
||||
sq = (
|
||||
seqlen_q
|
||||
if query_padding_mask is None
|
||||
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
|
||||
)
|
||||
return col_idx > row_idx + sk - sq
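With padding masks, sq and sk above become per-sample actual lengths, so the causal diagonal is anchored to the bottom-right of each sample's unpadded region rather than of the padded tensors. A small illustration, assuming the construct_causal_mask defined above and the imports of this test file; the example lengths are made up:

```python
import torch

seqlen_q = seqlen_k = 4
# Batch of 2: sample 0 has 4 valid keys, sample 1 only 2.
key_padding_mask = torch.tensor([[True, True, True, True],
                                 [True, True, False, False]])
keep = ~construct_causal_mask(seqlen_q, seqlen_k, None, key_padding_mask)
print(keep.int().squeeze(1))
# sample 0 (4 valid keys): ordinary lower-triangular mask
#   [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
# sample 1 (2 valid keys): diagonal shifted so the last query row sees both valid keys
#   [[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 0, 0]]
# (the padded key columns themselves are removed separately via key_padding_mask)
```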
|
||||
|
||||
|
||||
def attention_ref(
|
||||
q,
|
||||
k,
|
||||
@ -190,11 +209,16 @@ def attention_ref(
|
||||
if key_padding_mask is not None:
|
||||
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
|
||||
if causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
|
||||
# causal_mask = torch.triu(
|
||||
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
|
||||
# )
|
||||
causal_mask = construct_causal_mask(
|
||||
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
|
||||
)
|
||||
scores.masked_fill_(causal_mask, float("-inf"))
|
||||
attention = torch.softmax(scores, dim=-1)
|
||||
if causal: # Some rows are completely masked out so we fill them with zero instead of NaN
|
||||
attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
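The extra masked_fill after the softmax is what keeps fully-masked rows finite; such rows now exist whenever seqlen_q > seqlen_k with causal=True, and softmax over a row of all -inf produces NaN. A short demonstration (standalone, not part of the test file):

```python
import torch

scores = torch.tensor([[0.3, -0.1, 0.2]])
fully_masked = torch.tensor([[True, True, True]])      # a row the causal mask removes entirely
scores = scores.masked_fill(fully_masked, float("-inf"))
attn = torch.softmax(scores, dim=-1)
print(attn)                                            # tensor([[nan, nan, nan]]) without the fix
attn = attn.masked_fill(torch.all(fully_masked, dim=-1, keepdim=True), 0.0)
print(attn)                                            # tensor([[0., 0., 0.]])
```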
|
||||
dropout_scaling = 1.0 / (1 - dropout_p)
|
||||
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
|
||||
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
||||
@ -300,19 +324,19 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
|
||||
|
||||
|
||||
def convert_flash_attn_S_to_softmax(
|
||||
S, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
|
||||
S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
|
||||
):
|
||||
"""FlashAttention stores the S matrix in a different way.
|
||||
Arguments:
|
||||
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
|
||||
query_padding_mask: (batch_size, seqlen_q)
|
||||
key_padding_mask: (batch_size, seqlen_k)
|
||||
query_padding_mask: (batch_size, seqlen_q_rounded)
|
||||
key_padding_mask: (batch_size, seqlen_k_rounded)
|
||||
"""
|
||||
seqlen_q, seqlen_k = S.shape[-2:]
|
||||
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
|
||||
warps_n = 4
|
||||
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
|
||||
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
|
||||
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
|
||||
nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
|
||||
nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
|
||||
mmas_n = (blocksize_n + 16 - 1) // 16
|
||||
S_flat = rearrange(
|
||||
S,
|
||||
@ -331,37 +355,30 @@ def convert_flash_attn_S_to_softmax(
|
||||
c2=2,
|
||||
four=4,
|
||||
)
|
||||
|
||||
if causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1
|
||||
# causal_mask = torch.triu(
|
||||
# torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1
|
||||
# )
|
||||
causal_mask = construct_causal_mask(
|
||||
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
|
||||
)
|
||||
causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True)
|
||||
S_converted.masked_fill_(causal_mask, 0.0)
|
||||
|
||||
# Need to zero out things not in attention_mask in case S was initialized with random values
|
||||
# and some of those values aren't overwritten.
|
||||
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
|
||||
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
|
||||
if query_padding_mask is not None:
|
||||
if seqlen_q_og < seqlen_q:
|
||||
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
|
||||
else:
|
||||
query_padding_mask = query_padding_mask[:, :seqlen_q]
|
||||
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
|
||||
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
|
||||
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
|
||||
if key_padding_mask is not None:
|
||||
if seqlen_k_og < seqlen_k:
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
|
||||
else:
|
||||
key_padding_mask = key_padding_mask[:, :seqlen_k]
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
|
||||
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
|
||||
if seqlen_q_og < seqlen_q:
|
||||
S_converted = S_converted[:, :, :seqlen_q_og, :]
|
||||
else:
|
||||
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
|
||||
if seqlen_k_og < seqlen_k:
|
||||
S_converted = S_converted[:, :, :, :seqlen_k_og]
|
||||
else:
|
||||
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
|
||||
return S_converted
|
||||
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
|
||||
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
|
||||
return S_converted[:, :, :seqlen_q, :seqlen_k]
|
||||
|
||||
|
||||
def normalize_flash_attn_S(
|
||||
@ -390,20 +407,26 @@ def normalize_flash_attn_S(
|
||||
if key_padding_mask is not None:
|
||||
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
|
||||
if causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
|
||||
# causal_mask = torch.triu(
|
||||
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
|
||||
# )
|
||||
causal_mask = construct_causal_mask(
|
||||
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
|
||||
)
|
||||
scores.masked_fill_(causal_mask, float("-inf"))
|
||||
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
|
||||
scores_block = scores.split(block_size_n, dim=-1)
|
||||
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
|
||||
lse = torch.logsumexp(lse_block, dim=-1)
|
||||
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
|
||||
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
|
||||
lse[lse == float("-inf")] = float("inf")
|
||||
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
|
||||
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
|
||||
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
|
||||
attn_norm = torch.cat(
|
||||
[
|
||||
a / rearrange(torch.exp(lse - m), "b h s -> b h s 1")
|
||||
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
|
||||
for a, m in zip(attn_unnorm_block, cummax_block)
|
||||
],
|
||||
dim=-1,
|
||||
@ -428,8 +451,11 @@ def get_dropout_fraction(
|
||||
if key_padding_mask is not None:
|
||||
dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
|
||||
if causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
|
||||
# causal_mask = torch.triu(
|
||||
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
|
||||
# )
|
||||
causal_mask = construct_causal_mask(
|
||||
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device
|
||||
)
|
||||
dropped.masked_fill_(causal_mask, False)
|
||||
dropped_total = dropped.sum()
|
||||
@ -447,9 +473,9 @@ def get_dropout_fraction(
|
||||
numel_per_batch = query_lengths * key_lengths
|
||||
else:
|
||||
numel_per_batch = torch.where(
|
||||
query_lengths <= key_lengths,
|
||||
query_lengths * (query_lengths + 1) / 2,
|
||||
query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2),
|
||||
key_lengths <= query_lengths,
|
||||
key_lengths * (key_lengths + 1) / 2,
|
||||
query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
|
||||
)
|
||||
return dropped_total / (numel_per_batch.sum() * nheads)
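The new closed forms count the nonzero entries of the bottom-right-aligned causal mask: if key_length <= query_length, only the last key_length query rows attend, giving a triangle of key_length * (key_length + 1) / 2 entries; otherwise it is the full rectangle minus the query_length * (query_length - 1) / 2 triangle above the diagonal. A quick cross-check against an explicit mask (standalone sketch):

```python
def causal_numel(q_len, k_len):
    # Closed form used in get_dropout_fraction above (bottom-right-aligned causal mask).
    return k_len * (k_len + 1) // 2 if k_len <= q_len else q_len * k_len - q_len * (q_len - 1) // 2

def brute_force(q_len, k_len):
    return sum(1 for r in range(q_len) for c in range(k_len) if c <= r + k_len - q_len)

for q_len in range(1, 33):
    for k_len in range(1, 33):
        assert causal_numel(q_len, k_len) == brute_force(q_len, k_len), (q_len, k_len)
print(causal_numel(5, 2))   # 3, matching the three 1s in the README's seqlen_q=5, seqlen_k=2 mask
```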
|
||||
|
||||
@ -483,8 +509,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
|
||||
)
|
||||
if dropout_p > 0.0:
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask, None, None, d, dropout_p > 0.0, causal=causal
|
||||
)[:, :, :seqlen, :seqlen]
|
||||
S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
attn_unnorm = S_dmask_converted.abs()
|
||||
attn = normalize_flash_attn_S(
|
||||
@ -596,8 +622,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
|
||||
out = output_pad_fn(out_unpad)
|
||||
if dropout_p > 0.0:
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
|
||||
)[:, :, :seqlen, :seqlen]
|
||||
S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
attn_unnorm = S_dmask_converted.abs()
|
||||
attn = normalize_flash_attn_S(
|
||||
@ -665,19 +691,19 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kvpacked", [True, False])
|
||||
# @pytest.mark.parametrize('kvpacked', [False])
|
||||
# @pytest.mark.parametrize("kvpacked", [False])
|
||||
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize('mha_type', ["mha"])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
# @pytest.mark.parametrize('causal', [False])
|
||||
# @pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [56, 80])
|
||||
# @pytest.mark.parametrize('d', [64])
|
||||
# @pytest.mark.parametrize("d", [64])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
@ -693,9 +719,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
|
||||
(2048, 2048),
|
||||
],
|
||||
)
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
|
||||
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
|
||||
# @pytest.mark.parametrize('dropout_p', [0.0])
|
||||
# @pytest.mark.parametrize("dropout_p", [0.17])
|
||||
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
|
||||
if (
|
||||
max(seqlen_q, seqlen_k) >= 2048
|
||||
@ -732,8 +758,8 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
|
||||
)
|
||||
if dropout_p > 0.0:
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask, None, None, d, dropout_p > 0.0, causal=causal
|
||||
)[:, :, :seqlen_q, :seqlen_k]
|
||||
S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
attn_unnorm = S_dmask_converted.abs()
|
||||
if kvpacked:
|
||||
@ -969,8 +995,8 @@ def test_flash_attn_varlen_output(
|
||||
out = output_pad_fn(out_unpad)
|
||||
if dropout_p > 0.0:
|
||||
S_dmask_converted = convert_flash_attn_S_to_softmax(
|
||||
S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
|
||||
)[:, :, :seqlen_q, :seqlen_k]
|
||||
S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
|
||||
)
|
||||
dropout_mask = S_dmask_converted >= 0
|
||||
attn_unnorm = S_dmask_converted.abs()
|
||||
if kvpacked:
|
||||
@ -1101,53 +1127,314 @@ def test_flash_attn_varlen_output(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize('dtype', [torch.float16])
|
||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
||||
# @pytest.mark.parametrize('d', [56, 80])
|
||||
# @pytest.mark.parametrize("d", [64, 128])
|
||||
@pytest.mark.parametrize("swap_sq_sk", [False, True])
|
||||
# @pytest.mark.parametrize("swap_sq_sk", [True])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
(1, 239),
|
||||
(3, 799),
|
||||
(127, 512),
|
||||
(127, 513),
|
||||
(113, 203),
|
||||
(128, 217),
|
||||
(113, 211),
|
||||
(108, 256),
|
||||
(256, 512),
|
||||
(1023, 1024),
|
||||
],
|
||||
)
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
|
||||
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
|
||||
if (
|
||||
max(seqlen_q, seqlen_k) >= 2048
|
||||
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
|
||||
):
|
||||
pytest.skip() # Reference implementation OOM
|
||||
if swap_sq_sk:
|
||||
seqlen_q, seqlen_k = seqlen_k, seqlen_q
|
||||
device = "cuda"
|
||||
causal = True
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 16
|
||||
nheads = 9
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
out = flash_attn_func(q, k, v, 0.0, causal=causal)
|
||||
out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
|
||||
out_pt, attn_pt = attention_ref(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
None,
|
||||
causal=causal,
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
|
||||
g = torch.randn_like(out)
|
||||
do_o = (g.float() * out.float()).sum(-1)
|
||||
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
|
||||
(
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
) = torch.autograd.grad(out, (q, k, v), g)
|
||||
(
|
||||
dq_ref,
|
||||
dk_ref,
|
||||
dv_ref,
|
||||
) = torch.autograd.grad(out_ref, (q, k, v), g)
|
||||
(
|
||||
dq_pt,
|
||||
dk_pt,
|
||||
dv_pt,
|
||||
) = torch.autograd.grad(out_pt, (q, k, v), g)
|
||||
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
|
||||
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
|
||||
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
|
||||
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
|
||||
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
|
||||
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
|
||||
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
|
||||
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
|
||||
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
|
||||
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
|
||||
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
|
||||
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
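
# A minimal illustrative sketch (not part of this commit) of the bottom-right-aligned
# causal mask exercised by the causal tests above when seqlen_q != seqlen_k; the helper
# name is hypothetical, not a flash-attn API.
def bottom_right_causal_mask(seqlen_q, seqlen_k, device="cpu"):
    # Keep position (i, j) iff j <= i + seqlen_k - seqlen_q, so the last query row
    # attends to every key, and all-zero rows can occur when seqlen_q > seqlen_k.
    row = torch.arange(seqlen_q, device=device).unsqueeze(-1)
    col = torch.arange(seqlen_k, device=device)
    return col <= row + seqlen_k - seqlen_q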


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
    ],
)
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    out0, lse0, _ = flash_attn_qkvpacked_func(qkv, dropout_p, return_attn_probs=True, causal=causal)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (dqkv0,) = torch.autograd.grad(out0, qkv, g)
        (
            dq0,
            dk0,
            dv0,
        ) = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(200):
        torch.random.manual_seed(0)
        out, lse, S_dmask = flash_attn_qkvpacked_func(
            qkv, dropout_p, return_attn_probs=True, causal=causal
        )
    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
            (dqkv,) = torch.autograd.grad(out, qkv, g)
            dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                dq0 = dqkv0[:, :, 0]
                dq = dqkv[:, :, 0]
                print(
                    f"Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}"
                )
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal
            assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
            assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
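
# A condensed, hypothetical sketch of the determinism check performed above: with the RNG
# seed fixed, repeated runs must be bitwise identical, otherwise a race condition is likely.
# The helper below assumes `fn` returns a single tensor and is illustrative only.
def runs_deterministically(fn, *args, iters=10, seed=42, **kwargs):
    torch.random.manual_seed(seed)
    ref = fn(*args, **kwargs)
    for _ in range(iters):
        torch.random.manual_seed(seed)
        if not torch.equal(fn(*args, **kwargs), ref):
            return False
    return True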


@pytest.mark.parametrize("dtype", [torch.float16])

@ -89,7 +89,7 @@ RUN pip install flash-attn==2.0.9

# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
    && cd flash-attention && git checkout v2.0.9 \
    && cd flash-attention && git checkout v2.1.0 \
    && cd csrc/fused_softmax && pip install . && cd ../../ \
    && cd csrc/rotary && pip install . && cd ../../ \
    && cd csrc/xentropy && pip install . && cd ../../ \