[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset
parent 42832575d4
commit e6a8026489
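The diff below renames two fields on InferenceParams (max_sequence_len -> max_seqlen, sequence_len_offset -> seqlen_offset), updates every call site that reads them, and adds a reset() helper that replaces the manual re-initialization previously done in decode() and decode_speculative(). A minimal caller-side sketch of the renamed API follows; the import path is an assumption based on where InferenceParams appears to live, not something this diff shows:

    # Assumed import path for the dataclass edited below.
    from flash_attn.utils.generation import InferenceParams

    # New field names: max_seqlen / seqlen_offset
    # (formerly max_sequence_len / sequence_len_offset).
    params = InferenceParams(max_seqlen=20, max_batch_size=1)

    # Callers advance seqlen_offset by however many tokens they just fed to the
    # model, exactly as the updated test does (10 prompt tokens, then 4, then 6).
    params.seqlen_offset += 10

    # The new reset() zeroes seqlen_offset (and lengths_per_sample, if allocated)
    # so a cached InferenceParams can be reused for the next generation.
    params.reset(max_seqlen=20, max_batch_size=1)
    assert params.seqlen_offset == 0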
@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
             inference_params.max_batch_size,
-            inference_params.max_sequence_len,
+            inference_params.max_seqlen,
             2,
             num_heads,
             head_dim,
@@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
+    sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
@@ -445,12 +445,12 @@ class MHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -460,7 +460,7 @@ class MHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -480,11 +480,11 @@ class MHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
         if (
-            inference_params.sequence_len_offset == 0
+            inference_params.seqlen_offset == 0
             or flash_attn_with_kvcache is None
             or not self.use_flash_attn
         ):
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -493,7 +493,7 @@ class MHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             return flash_attn_with_kvcache(
                 q,
@@ -561,12 +561,10 @@ class MHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -581,7 +579,7 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -632,7 +630,7 @@ class MHA(nn.Module):
                 ).contiguous()
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):
 
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             context = flash_attn_with_kvcache(
                 q,
@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
             )
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -20,13 +20,20 @@ class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""
 
-    max_sequence_len: int
+    max_seqlen: int
     max_batch_size: int
-    sequence_len_offset: int = 0
+    seqlen_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
     lengths_per_sample: Optional[Tensor] = None
 
+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()
+
 
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
@@ -127,19 +134,16 @@ def decode(
             tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
-        inference_params.max_sequence_len = max_length
-        inference_params.max_batch_size = batch_size
-        inference_params.sequence_len_offset = 0
-        inference_params.lengths_per_sample.zero_()
+        inference_params.reset(max_length, batch_size)
     else:
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
 
     def get_logits(input_ids, inference_params):
-        decoding = inference_params.sequence_len_offset > 0
+        decoding = inference_params.seqlen_offset > 0
         if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
@@ -154,24 +158,24 @@ def decode(
             ).logits.squeeze(dim=1)
         else:
             logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
         return logits[..., :vocab_size] if vocab_size is not None else logits
 
     def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            token = teacher_outputs[:, inference_params.sequence_len_offset]
+            token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
         return token.unsqueeze(1)
 
     def should_stop(current_token, inference_params):
-        if inference_params.sequence_len_offset == 0:
+        if inference_params.seqlen_offset == 0:
             return False
         if eos_token_id is not None and (current_token == eos_token_id).all():
             return True
-        if inference_params.sequence_len_offset >= max_length - 1:
+        if inference_params.seqlen_offset >= max_length - 1:
             return True
         return False
 
@@ -185,7 +189,7 @@ def decode(
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        inference_params.seqlen_offset += sequences[-1].shape[1]
         sequences.append(sample_tokens(scores[-1], inference_params))
     if enable_timing:
         end.record()
@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
     return tokens, first_rejected_idx + 1
 
 
+@torch.inference_mode()
 def decode_speculative(
     input_ids,
     model,
@@ -303,15 +308,11 @@ def decode_speculative(
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
-        inference_params_draft.max_sequence_len = max_length
-        inference_params_draft.max_batch_size = batch_size
-        inference_params_draft.sequence_len_offset = 0
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft.reset(max_length, batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size
-        )
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
 
     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -323,7 +324,7 @@ def decode_speculative(
             ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
 
     logits_postprocess_fn = (
@@ -365,13 +366,13 @@ def decode_speculative(
             assert seqlen == 1
             position_ids = repeat(
                 torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                + inference_params.sequence_len_offset,
+                + inference_params.seqlen_offset,
                 "s -> b s",
                 b=batch_size,
             )
             # position_ids = torch.full(
             #     (batch_size, 1),
-            #     inference_params.sequence_len_offset,
+            #     inference_params.seqlen_offset,
             #     dtype=torch.long,
             #     device=input_ids.device,
             # )
@@ -380,7 +381,7 @@ def decode_speculative(
         logits = logits_postprocess_fn(
             logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
         )
-        inference_params.sequence_len_offset += input_ids.shape[1]
+        inference_params.seqlen_offset += input_ids.shape[1]
         scores = [logits]
         next_token = sample_fn(logits)
         sequences.append(next_token)
@@ -388,7 +389,7 @@ def decode_speculative(
             if i < num_tokens - 1 or last_token_logits:
                 position_ids = torch.full(
                     (batch_size, 1),
-                    inference_params_draft.sequence_len_offset,
+                    inference_params_draft.seqlen_offset,
                     dtype=torch.long,
                     device=input_ids.device,
                 )
@@ -401,7 +402,7 @@ def decode_speculative(
                         cg=cg,
                     )
                 )
-                inference_params.sequence_len_offset += 1
+                inference_params.seqlen_offset += 1
                 scores.append(logits)
                 if i < num_tokens - 1:
                     next_token = sample_fn(logits)
@@ -476,8 +477,8 @@ def decode_speculative(
    scores.append(logits[:1, : num_generated_tokens[0]])
    # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
    # that in the next time we call @model.
-    inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
-    inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+    inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+    inference_params_draft.seqlen_offset = inference_params.seqlen_offset
    if debug:
        cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
        scores_ref = model(
@@ -486,10 +487,10 @@ def decode_speculative(
        print((scores[-1] - scores_ref[:, :-1]).abs().max())
 
    while True:
-        # sequence_len_offset is total length generated - 1
-        if inference_params.sequence_len_offset >= max_length - 1:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
            break
-        if inference_params.sequence_len_offset >= max_length - 2:
+        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
@@ -497,7 +498,7 @@ def decode_speculative(
            break
        # Sample from draft model
        n_spec_tokens = min(
-            speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -1:], num_tokens=n_spec_tokens
@@ -510,9 +511,9 @@ def decode_speculative(
        # Evaluate the draft tokens with the model
        position_ids = repeat(
            torch.arange(
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                # 1 extra token from last time that hasn't been passed through model
-                inference_params.sequence_len_offset + n_spec_tokens + 1,
+                inference_params.seqlen_offset + n_spec_tokens + 1,
                dtype=torch.long,
                device=input_ids.device,
            ),
@@ -525,7 +526,7 @@ def decode_speculative(
            inference_params=inference_params,
        ).logits  # (batch, n_spec_tokens, vocab_size)
        logits = logits_postprocess_fn(logits)
-        inference_params.sequence_len_offset += 1
+        inference_params.seqlen_offset += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
@@ -539,8 +540,8 @@ def decode_speculative(
            print(num_generated_tokens)
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
-        inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
        # breakpoint()
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
@@ -679,9 +680,9 @@ def update_graph_cache(
         )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
-            max_sequence_len=max_seqlen,
+            max_seqlen=max_seqlen,
             max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og,
+            seqlen_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
             lengths_per_sample=lengths_per_sample,
         )
@@ -705,7 +706,7 @@ def update_graph_cache(
         )
 
     cache.run = dispatch
-    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
     return cache
 
 
@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    sequence_len_offset_og = inference_params.sequence_len_offset
+    seqlen_offset_og = inference_params.seqlen_offset
     # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
     # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - 1
     inference_params.lengths_per_sample[:] = max_seqlen - 1
 
     # Warmup before capture
@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
         graph.replay()
         return logits.clone()
 
-    inference_params.sequence_len_offset = sequence_len_offset_og
+    inference_params.seqlen_offset = seqlen_offset_og
     return run
@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     logits_ref = model(input_ids).logits
 
     # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
-    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
     logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
-    inference_params.sequence_len_offset += 10
+    inference_params.seqlen_offset += 10
     position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
     logits_1014 = model(
         input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
     ).logits
-    inference_params.sequence_len_offset += 4
+    inference_params.seqlen_offset += 4
     position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
     logits_1420 = model(
         input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params