[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset

Tri Dao 2023-09-18 16:08:44 -07:00
parent 42832575d4
commit e6a8026489
3 changed files with 72 additions and 75 deletions
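For downstream code that builds InferenceParams by hand, the rename itself is a pure find-and-replace of the constructor keywords and attribute names (the diff also adds a reset() helper, shown in the second file). A minimal migration sketch; the import path is assumed from the upstream package layout and is not shown in this diff:

from flash_attn.utils.generation import InferenceParams  # assumed location of the dataclass

# Before this commit:
#   params = InferenceParams(max_sequence_len=2048, max_batch_size=4)
#   params.sequence_len_offset += 16

# After this commit:
params = InferenceParams(max_seqlen=2048, max_batch_size=4)
params.seqlen_offset += 16  # advance by the number of tokens just fed to the model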

View File

@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size,
- inference_params.max_sequence_len,
+ inference_params.max_seqlen,
2,
num_heads,
head_dim,
@@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
- sequence_start = inference_params.sequence_len_offset
+ sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
@@ -445,12 +445,12 @@ class MHA(nn.Module):
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
- assert inference_params is not None and inference_params.sequence_len_offset > 0
+ assert inference_params is not None and inference_params.seqlen_offset > 0
assert self.use_flash_attn
if self.rotary_emb_dim > 0:
assert self.rotary_emb.scale is None, "This code path does not support xPos"
self.rotary_emb._update_cos_sin_cache(
- inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
@@ -460,7 +460,7 @@ class MHA(nn.Module):
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
- else inference_params.sequence_len_offset
+ else inference_params.seqlen_offset
)
context = flash_attn_with_kvcache(
q,
@@ -480,11 +480,11 @@ class MHA(nn.Module):
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention"""
if (
- inference_params.sequence_len_offset == 0
+ inference_params.seqlen_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
- # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
@@ -493,7 +493,7 @@ class MHA(nn.Module):
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
- else inference_params.sequence_len_offset
+ else inference_params.seqlen_offset
)
return flash_attn_with_kvcache(
q,
@@ -561,12 +561,10 @@ class MHA(nn.Module):
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
- else inference_params.sequence_len_offset
+ else inference_params.seqlen_offset
)
)
- rotary_max_seqlen = (
- inference_params.max_sequence_len if inference_params is not None else None
- )
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
batch, seqlen = x.shape[:2]
if not self.cross_attn and self.num_heads_kv == self.num_heads:
assert x_kv is None and mixer_subset is None
@@ -581,7 +579,7 @@ class MHA(nn.Module):
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
if (
inference_params is None
- or inference_params.sequence_len_offset == 0
+ or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
@@ -632,7 +630,7 @@ class MHA(nn.Module):
).contiguous()
if (
inference_params is None
- or inference_params.sequence_len_offset == 0
+ or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
- assert inference_params is not None and inference_params.sequence_len_offset > 0
+ assert inference_params is not None and inference_params.seqlen_offset > 0
assert self.use_flash_attn
if self.rotary_emb_dim > 0:
assert self.rotary_emb.scale is None, "This code path does not support xPos"
self.rotary_emb._update_cos_sin_cache(
- inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
- else inference_params.sequence_len_offset
+ else inference_params.seqlen_offset
)
context = flash_attn_with_kvcache(
q,
@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention"""
- if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
- # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+ if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
- else inference_params.sequence_len_offset
+ else inference_params.seqlen_offset
)
context = flash_attn_with_kvcache(
q,
@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
- else inference_params.sequence_len_offset
+ else inference_params.seqlen_offset
)
)
- rotary_max_seqlen = (
- inference_params.max_sequence_len if inference_params is not None else None
- )
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
if self.num_heads_kv == self.num_heads:
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
if (
inference_params is None
- or inference_params.sequence_len_offset == 0
+ or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
)
if (
inference_params is None
- or inference_params.sequence_len_offset == 0
+ or inference_params.seqlen_offset == 0
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
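Taken together, the guards touched in this file all encode the same dispatch: the fused flash_attn_with_kvcache path (the one that asserts inference_params.seqlen_offset > 0 near the top of the file) is only reachable when the KV cache already holds tokens, the rotary dimension is a nonzero multiple of 16, and use_flash_attn is enabled; otherwise rotary embedding and the KV-cache update take the regular path. A small illustrative helper, not part of the module and with a name of my choosing, that mirrors the repeated condition:

def takes_kvcache_fast_path(inference_params, rotary_emb_dim, use_flash_attn):
    # Mirrors the guard repeated in MHA.forward / ParallelMHA.forward after the rename.
    # Returns True when the fused flash_attn_with_kvcache branch would be taken.
    return not (
        inference_params is None
        or inference_params.seqlen_offset == 0
        or (rotary_emb_dim == 0 or rotary_emb_dim % 16 != 0)
        or not use_flash_attn
    )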

View File

@@ -20,13 +20,20 @@ class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
- max_sequence_len: int
+ max_seqlen: int
max_batch_size: int
- sequence_len_offset: int = 0
+ seqlen_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
lengths_per_sample: Optional[Tensor] = None
+ def reset(self, max_seqlen, max_batch_size):
+ self.max_seqlen = max_seqlen
+ self.max_batch_size = max_batch_size
+ self.seqlen_offset = 0
+ if self.lengths_per_sample is not None:
+ self.lengths_per_sample.zero_()
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
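The new reset method packages the bookkeeping that callers previously did field by field: it records the new max_seqlen and max_batch_size, rewinds seqlen_offset to 0, and zeroes lengths_per_sample if it was allocated. A minimal usage sketch (sizes are illustrative, import path assumed as before) matching the decode change below:

from flash_attn.utils.generation import InferenceParams  # assumed path

params = InferenceParams(max_seqlen=1024, max_batch_size=8)
# ... a first generation runs, advancing params.seqlen_offset and filling
# params.key_value_memory_dict with per-layer KV buffers ...

# Instead of re-assigning max_seqlen / max_batch_size / seqlen_offset one by one,
# a single call rewinds the object for the next prompt while keeping the allocated buffers:
params.reset(max_seqlen=1024, max_batch_size=8)
assert params.seqlen_offset == 0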
@@ -127,19 +134,16 @@ def decode(
tensor_parallel=tensor_parallel,
)
inference_params = model._decoding_cache.inference_params
- inference_params.max_sequence_len = max_length
- inference_params.max_batch_size = batch_size
- inference_params.sequence_len_offset = 0
- inference_params.lengths_per_sample.zero_()
+ inference_params.reset(max_length, batch_size)
else:
- inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
def get_logits(input_ids, inference_params):
- decoding = inference_params.sequence_len_offset > 0
+ decoding = inference_params.seqlen_offset > 0
if decoding:
position_ids = torch.full(
(batch_size, 1),
- inference_params.sequence_len_offset,
+ inference_params.seqlen_offset,
dtype=torch.long,
device=input_ids.device,
)
@@ -154,24 +158,24 @@ def decode(
).logits.squeeze(dim=1)
else:
logits = model._decoding_cache.run(
- input_ids, position_ids, inference_params.sequence_len_offset
+ input_ids, position_ids, inference_params.seqlen_offset
).clone()
return logits[..., :vocab_size] if vocab_size is not None else logits
def sample_tokens(logits, inference_params):
- if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
- token = teacher_outputs[:, inference_params.sequence_len_offset]
+ token = teacher_outputs[:, inference_params.seqlen_offset]
# return rearrange(token, "b -> b 1")
return token.unsqueeze(1)
def should_stop(current_token, inference_params):
- if inference_params.sequence_len_offset == 0:
+ if inference_params.seqlen_offset == 0:
return False
if eos_token_id is not None and (current_token == eos_token_id).all():
return True
- if inference_params.sequence_len_offset >= max_length - 1:
+ if inference_params.seqlen_offset >= max_length - 1:
return True
return False
@@ -185,7 +189,7 @@ def decode(
scores, sequences = [], [input_ids]
while not should_stop(sequences[-1], inference_params):
scores.append(get_logits(sequences[-1], inference_params))
- inference_params.sequence_len_offset += sequences[-1].shape[1]
+ inference_params.seqlen_offset += sequences[-1].shape[1]
sequences.append(sample_tokens(scores[-1], inference_params))
if enable_timing:
end.record()
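The loop above is also the pattern for driving a model manually outside of decode: after every forward pass, seqlen_offset is advanced by the number of tokens just fed in, and subsequent single-token steps pass position_ids equal to the current offset (compare get_logits above and the multi-token test in the last file of this diff). A hedged sketch, where model and input_ids are placeholders for any model whose forward accepts position_ids and inference_params, and a (batch, prompt_len) LongTensor prompt:

import torch
from flash_attn.utils.generation import InferenceParams  # assumed path

params = InferenceParams(max_seqlen=256, max_batch_size=input_ids.shape[0])
tokens, position_ids = input_ids, None  # the first step consumes the whole prompt
for _ in range(32):
    logits = model(tokens, position_ids=position_ids, inference_params=params).logits[:, -1]
    params.seqlen_offset += tokens.shape[1]  # advance past the tokens just processed
    tokens = logits.argmax(dim=-1, keepdim=True)  # greedy next token, fed alone next step
    position_ids = torch.full_like(tokens, params.seqlen_offset)  # as in get_logits above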
@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
return tokens, first_rejected_idx + 1
@torch.inference_mode()
def decode_speculative(
input_ids,
model,
@@ -303,15 +308,11 @@ def decode_speculative(
tensor_parallel=tensor_parallel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
- inference_params_draft.max_sequence_len = max_length
- inference_params_draft.max_batch_size = batch_size
- inference_params_draft.sequence_len_offset = 0
- inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+ inference_params_draft.reset(max_length, batch_size)
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
else:
- inference_params_draft = InferenceParams(
- max_sequence_len=max_length, max_batch_size=batch_size
- )
- inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+ inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
if not cg:
@@ -323,7 +324,7 @@ def decode_speculative(
).logits.squeeze(dim=1)
else:
return model._decoding_cache.run(
- input_ids, position_ids, inference_params.sequence_len_offset
+ input_ids, position_ids, inference_params.seqlen_offset
).clone()
logits_postprocess_fn = (
@@ -365,13 +366,13 @@ def decode_speculative(
assert seqlen == 1
position_ids = repeat(
torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
- + inference_params.sequence_len_offset,
+ + inference_params.seqlen_offset,
"s -> b s",
b=batch_size,
)
# position_ids = torch.full(
# (batch_size, 1),
- # inference_params.sequence_len_offset,
+ # inference_params.seqlen_offset,
# dtype=torch.long,
# device=input_ids.device,
# )
@@ -380,7 +381,7 @@ def decode_speculative(
logits = logits_postprocess_fn(
logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
)
- inference_params.sequence_len_offset += input_ids.shape[1]
+ inference_params.seqlen_offset += input_ids.shape[1]
scores = [logits]
next_token = sample_fn(logits)
sequences.append(next_token)
@@ -388,7 +389,7 @@ def decode_speculative(
if i < num_tokens - 1 or last_token_logits:
position_ids = torch.full(
(batch_size, 1),
- inference_params_draft.sequence_len_offset,
+ inference_params_draft.seqlen_offset,
dtype=torch.long,
device=input_ids.device,
)
@@ -401,7 +402,7 @@ def decode_speculative(
cg=cg,
)
)
- inference_params.sequence_len_offset += 1
+ inference_params.seqlen_offset += 1
scores.append(logits)
if i < num_tokens - 1:
next_token = sample_fn(logits)
@@ -476,8 +477,8 @@ def decode_speculative(
scores.append(logits[:1, : num_generated_tokens[0]])
# Note that @model has not evaluated the last sampled token yet, so we'll need to pass
# that in the next time we call @model.
- inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
- inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+ inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+ inference_params_draft.seqlen_offset = inference_params.seqlen_offset
if debug:
cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
scores_ref = model(
@@ -486,10 +487,10 @@ def decode_speculative(
print((scores[-1] - scores_ref[:, :-1]).abs().max())
while True:
- # sequence_len_offset is total length generated - 1
- if inference_params.sequence_len_offset >= max_length - 1:
+ # seqlen_offset is total length generated - 1
+ if inference_params.seqlen_offset >= max_length - 1:
break
- if inference_params.sequence_len_offset >= max_length - 2:
+ if inference_params.seqlen_offset >= max_length - 2:
# Don't do speculative sampling, just sample 1 token from the model
tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
sequences.append(tokens)
@@ -497,7 +498,7 @@ def decode_speculative(
break
# Sample from draft model
n_spec_tokens = min(
- speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+ speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
)
tokens_draft, scores_draft = sample_tokens_draft(
sequences[-1][:, -1:], num_tokens=n_spec_tokens
@@ -510,9 +511,9 @@ def decode_speculative(
# Evaluate the draft tokens with the model
position_ids = repeat(
torch.arange(
- inference_params.sequence_len_offset,
+ inference_params.seqlen_offset,
# 1 extra token from last time that hasn't been passed through model
- inference_params.sequence_len_offset + n_spec_tokens + 1,
+ inference_params.seqlen_offset + n_spec_tokens + 1,
dtype=torch.long,
device=input_ids.device,
),
@@ -525,7 +526,7 @@ def decode_speculative(
inference_params=inference_params,
).logits # (batch, n_spec_tokens, vocab_size)
logits = logits_postprocess_fn(logits)
- inference_params.sequence_len_offset += 1
+ inference_params.seqlen_offset += 1
if debug:
logits_ref = model(
torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
@@ -539,8 +540,8 @@ def decode_speculative(
print(num_generated_tokens)
sequences.append(tokens[:1, : num_generated_tokens[0]])
scores.append(logits[:1, : num_generated_tokens[0]])
- inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
- inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+ inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+ inference_params_draft.seqlen_offset = inference_params.seqlen_offset
# breakpoint()
if debug:
cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
@@ -679,9 +680,9 @@ def update_graph_cache(
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
- max_sequence_len=max_seqlen,
+ max_seqlen=max_seqlen,
max_batch_size=batch_size,
- sequence_len_offset=seqlen_og,
+ seqlen_offset=seqlen_og,
key_value_memory_dict=inf_cache,
lengths_per_sample=lengths_per_sample,
)
@@ -705,7 +706,7 @@ def update_graph_cache(
)
cache.run = dispatch
- cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
return cache
@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
device = next(iter(model.parameters())).device
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
- sequence_len_offset_og = inference_params.sequence_len_offset
+ seqlen_offset_og = inference_params.seqlen_offset
# TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
# used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
- inference_params.sequence_len_offset = max_seqlen - 1
+ inference_params.seqlen_offset = max_seqlen - 1
inference_params.lengths_per_sample[:] = max_seqlen - 1
# Warmup before capture
@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
graph.replay()
return logits.clone()
- inference_params.sequence_len_offset = sequence_len_offset_og
+ inference_params.seqlen_offset = seqlen_offset_og
return run

View File

@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
logits_ref = model(input_ids).logits
# Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
- inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+ inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
- inference_params.sequence_len_offset += 10
+ inference_params.seqlen_offset += 10
position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
logits_1014 = model(
input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
).logits
- inference_params.sequence_len_offset += 4
+ inference_params.seqlen_offset += 4
position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
logits_1420 = model(
input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params