[Gen] Accept cache_batch_idx to index into the KV cache

Tri Dao 2023-10-03 16:27:26 -07:00
parent 601b4dc48d
commit e279bf8ed9
5 changed files with 49 additions and 17 deletions

View File

@@ -1037,13 +1037,14 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
-                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
-                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
+                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
                 c10::optional<const at::Tensor> &k_,          // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_,          // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_,  // batch_size
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                 c10::optional<at::Tensor> &out_,              // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
@@ -1084,6 +1085,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int head_size_og = sizes[3];
     const int seqlen_k = kcache.size(1);
     const int num_heads_k = kcache.size(2);
+    const int batch_size_c = kcache.size(0);
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -1102,8 +1104,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
     at::Tensor q_padded, kcache_padded, vcache_padded;
     if (head_size_og % 8 != 0) {
@@ -1229,6 +1231,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.rotary_dim = 0;
     }
+    if (cache_batch_idx_.has_value()) {
+        auto cache_batch_idx = cache_batch_idx_.value();
+        CHECK_DEVICE(cache_batch_idx);
+        CHECK_CONTIGUOUS(cache_batch_idx);
+        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
+        params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
+    }
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
@@ -1248,8 +1257,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    // Only split kernel supports appending to KV cache
-    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
+    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());
     if (head_size_og % 8 != 0) {
         out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
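
In Python terms, the contract these new checks establish looks roughly like the sketch below. This is an illustrative helper, not code from the repository; the name check_cache_batch_idx, and the shape and range assertions beyond what the C++ checks enforce, are assumptions drawn from the docstring.

import torch

def check_cache_batch_idx(q, k_cache, cache_batch_idx):
    # Hypothetical helper: k_cache/v_cache may now have a leading dimension batch_size_c
    # that differs from q's batch_size; cache_batch_idx selects which cache rows to use.
    batch_size = q.shape[0]
    batch_size_c = k_cache.shape[0]  # batch dimension of the cache, may exceed batch_size
    if cache_batch_idx is not None:
        assert cache_batch_idx.dtype == torch.int32, "cache_batch_idx must have dtype int32"
        assert cache_batch_idx.is_contiguous() and cache_batch_idx.device == q.device
        # Documented shape is (batch_size,); indices are assumed to lie in [0, batch_size_c).
        assert cache_batch_idx.shape == (batch_size,)
        assert int(cache_batch_idx.max()) < batch_size_c
    else:
        # Identity mapping: batch element b uses cache row b.
        assert batch_size <= batch_size_c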

View File

@@ -95,6 +95,9 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ rotary_cos_ptr;
     void * __restrict__ rotary_sin_ptr;
+    // The indices to index into the KV cache.
+    int *__restrict__ cache_batch_idx;
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;

View File

@@ -668,9 +668,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
     // We move K and V to the last block.
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
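
The kernel change boils down to one indirection: the batch index used to address K and V (but not Q or the output) is remapped through cache_batch_idx when it is provided. A minimal Python sketch of that semantics, not the CUDA code:

def gather_kv(k_cache, v_cache, cache_batch_idx=None):
    # Mirrors bidb_cache = cache_batch_idx == nullptr ? bidb : cache_batch_idx[bidb]:
    # with no mapping, cache row b serves batch element b; otherwise row cache_batch_idx[b] does.
    if cache_batch_idx is None:
        return k_cache, v_cache
    return k_cache[cache_batch_idx], v_cache[cache_batch_idx]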

View File

@@ -928,6 +928,7 @@ def flash_attn_with_kvcache(
     rotary_cos=None,
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
@@ -978,8 +979,8 @@ def flash_attn_with_kvcache(
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
@@ -988,6 +989,10 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -1014,6 +1019,8 @@ def flash_attn_with_kvcache(
         cache_seqlens = torch.full(
             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
         )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+    cache_batch_idx = maybe_contiguous(cache_batch_idx)
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q,
         k_cache,
@@ -1023,6 +1030,7 @@ def flash_attn_with_kvcache(
         cache_seqlens,
         rotary_cos,
         rotary_sin,
+        cache_batch_idx,
         None,
         softmax_scale,
         causal,
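
Taken together, the new argument lets a small batch of requests attend to, and append new keys/values into, arbitrary rows of a larger pre-allocated cache pool. A hedged usage sketch: shapes, lengths, and indices here are illustrative, and the import path assumes the standard flash_attn package.

import torch
from flash_attn import flash_attn_with_kvcache

device, dtype = "cuda", torch.float16
batch_size, nheads, nheads_k, headdim = 2, 6, 6, 128
pool_size, seqlen_cache = 8, 512  # KV-cache pool holds more sequences than the current batch

q = torch.randn(batch_size, 1, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch_size, 1, nheads_k, headdim, device=device, dtype=dtype)
v = torch.randn_like(k)
k_cache = torch.zeros(pool_size, seqlen_cache, nheads_k, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

cache_seqlens = torch.tensor([17, 42], dtype=torch.int32, device=device)   # current length of each selected row
cache_batch_idx = torch.tensor([5, 2], dtype=torch.int32, device=device)   # which pool rows this batch uses

# k/v are appended to rows 5 and 2 at positions 17 and 42, and each query attends
# only to its own selected row of the cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v,
    cache_seqlens=cache_seqlens,
    cache_batch_idx=cache_batch_idx,
    causal=True,
)
print(out.shape)  # (batch_size, 1, nheads, headdim)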

View File

@@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("has_batch_idx", [True])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache(
     seqlen_q,
     seqlen_k,
     d,
+    has_batch_idx,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,
@@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache(
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 6
     # rotary_dim must be a multiple of 16, and must be <= d
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
@@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+    else:
+        cache_batch_idx = None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
@@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = k_cache.clone()
-    v_cache_ref = v_cache.clone()
+    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
+    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
@@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache(
         cos,
         sin,
         cache_seqlens,
+        cache_batch_idx,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,
@@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     if new_kv:
-        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
-        assert torch.equal(v_cache, v_cache_ref)
+        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
     assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5