[Gen] Accept cache_batch_idx to index into the KV cache
parent 601b4dc48d
commit e279bf8ed9
@@ -1037,13 +1037,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q,            // batch_size x seqlen_q x num_heads x head_size
-                const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size
-                const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size
+                const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size
+                const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size
                 c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                 c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,

@@ -1084,6 +1085,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     const int head_size_og = sizes[3];
     const int seqlen_k = kcache.size(1);
     const int num_heads_k = kcache.size(2);
+    const int batch_size_c = kcache.size(0);
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -1102,8 +1104,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }

     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);

     at::Tensor q_padded, kcache_padded, vcache_padded;
     if (head_size_og % 8 != 0) {

@@ -1229,6 +1231,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.rotary_dim = 0;
     }

+    if (cache_batch_idx_.has_value()) {
+        auto cache_batch_idx = cache_batch_idx_.value();
+        CHECK_DEVICE(cache_batch_idx);
+        CHECK_CONTIGUOUS(cache_batch_idx);
+        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
+        params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
+    }
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;

@@ -1248,8 +1257,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
     }

     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    // Only split kernel supports appending to KV cache
-    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
+    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());

     if (head_size_og % 8 != 0) {
         out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
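Taken together, these checks define the contract for the new argument: cache_batch_idx must be an int32, contiguous tensor on the same device as the cache, with one entry per element of the query batch, and supplying it forces the split-KV kernel. A minimal sketch of building a valid index tensor (the helper name and slot layout are illustrative, not part of this commit):

    import torch

    def make_cache_batch_idx(slot_ids, device):
        # Hypothetical helper: map each active sequence in the query batch to the
        # row it occupies in a preallocated KV cache. The extension checks that the
        # tensor is int32, contiguous, and on the same CUDA device as the cache.
        return torch.as_tensor(slot_ids, dtype=torch.int32, device=device).contiguous()

    # e.g. a query batch of 3 sequences stored in rows 5, 0, and 7 of a larger cache
    cache_batch_idx = make_cache_batch_idx([5, 0, 7], device="cuda")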
@@ -95,6 +95,9 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ rotary_cos_ptr;
     void * __restrict__ rotary_sin_ptr;

+    // The indices to index into the KV cache.
+    int *__restrict__ cache_batch_idx;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
@@ -668,9 +668,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
     // We move K and V to the last block.
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;

     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
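The kernel-side change is one extra level of indirection: before forming the K/V row offsets, the query batch index bidb is remapped to bidb_cache, so batch element b reads keys and values from cache row cache_batch_idx[b] instead of row b. A rough PyTorch sketch of just this indexing semantics (it ignores causal masking, cache_seqlens, rotary embedding, and GQA head repetition, so it is not a drop-in reference for the kernel):

    import math
    import torch

    def attn_with_cache_batch_idx_sketch(q, k_cache, v_cache, cache_batch_idx=None):
        # q: (batch_size, seqlen_q, nheads, d); k_cache/v_cache: (batch_size_c, seqlen_k, nheads, d)
        if cache_batch_idx is None:
            cache_batch_idx = torch.arange(q.shape[0], device=q.device)
        # This gather is what the bidb_cache remapping accomplishes inside the kernel.
        k = k_cache[cache_batch_idx.long()]
        v = v_cache[cache_batch_idx.long()]
        scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) / math.sqrt(q.shape[-1])
        attn = scores.softmax(dim=-1)
        return torch.einsum("bhqk,bkhd->bqhd", attn, v.float()).to(q.dtype)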
@@ -928,6 +928,7 @@ def flash_attn_with_kvcache(
     rotary_cos=None,
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window

@@ -978,8 +979,8 @@ def flash_attn_with_kvcache(

     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.

@@ -988,6 +989,10 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+            might come from any of the duplicate indices.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).

@@ -1014,6 +1019,8 @@ def flash_attn_with_kvcache(
         cache_seqlens = torch.full(
             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
         )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+    cache_batch_idx = maybe_contiguous(cache_batch_idx)
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q,
         k_cache,

@@ -1023,6 +1030,7 @@ def flash_attn_with_kvcache(
         cache_seqlens,
         rotary_cos,
         rotary_sin,
+        cache_batch_idx,
         None,
         softmax_scale,
         causal,
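With the interface change, a decoding loop can preallocate a KV cache with more slots than the active batch and address the slots by index. A hedged usage sketch (shapes, slot assignments, and the single-token decode step are illustrative; keyword names follow the signature in this diff):

    import torch
    from flash_attn import flash_attn_with_kvcache

    device, dtype = "cuda", torch.float16
    batch_size, nheads, d = 2, 6, 64
    cache_slots, seqlen_cache = 8, 512  # cache holds more slots than active sequences

    k_cache = torch.zeros(cache_slots, seqlen_cache, nheads, d, device=device, dtype=dtype)
    v_cache = torch.zeros_like(k_cache)

    # The two active sequences live in cache slots 5 and 1, with 100 and 37 tokens cached so far.
    cache_batch_idx = torch.tensor([5, 1], dtype=torch.int32, device=device)
    cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device=device)

    # Single-step decode: one new token per sequence, appended into the selected slots.
    q = torch.randn(batch_size, 1, nheads, d, device=device, dtype=dtype)
    k_new = torch.randn(batch_size, 1, nheads, d, device=device, dtype=dtype)
    v_new = torch.randn(batch_size, 1, nheads, d, device=device, dtype=dtype)

    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        k=k_new, v=v_new,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        causal=True,
    )

Because cache_batch_idx is supplied, the call is routed to the split-KV kernel whether or not new k/v are appended, matching the force_split_kernel change above.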
@@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])

@@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("has_batch_idx", [True])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])

@@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache(
     seqlen_q,
     seqlen_k,
     d,
+    has_batch_idx,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,

@@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache(
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 6
     # rotary_dim must be a multiple of 16, and must be <= d
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16

@@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough

@@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+    else:
+        cache_batch_idx = None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi

@@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = k_cache.clone()
-    v_cache_ref = v_cache.clone()
+    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
+    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:

@@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache(
         cos,
         sin,
         cache_seqlens,
+        cache_batch_idx,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,

@@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     if new_kv:
-        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
-        assert torch.equal(v_cache, v_cache_ref)
+        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
     assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5