[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset

Tri Dao 2023-09-18 16:08:44 -07:00
parent 42832575d4
commit e6a8026489
3 changed files with 72 additions and 75 deletions
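For downstream callers, the change is purely a rename of the `InferenceParams` fields and constructor keywords. A minimal sketch of the updated calling convention (the import path and the placeholder `model` call are assumptions for illustration, not part of this commit):

from flash_attn.utils.generation import InferenceParams  # assumed import path

# New names: max_seqlen (was max_sequence_len) and seqlen_offset (was sequence_len_offset).
inference_params = InferenceParams(max_seqlen=256, max_batch_size=4)

# Prefill the prompt at seqlen_offset == 0, then advance the offset by however many
# tokens the model just consumed; each later decoding step advances it by 1.
prompt_len = 16
# logits = model(prompt_ids, inference_params=inference_params).logits  # placeholder model
inference_params.seqlen_offset += prompt_len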

View File

@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
             inference_params.max_batch_size,
-            inference_params.max_sequence_len,
+            inference_params.max_seqlen,
             2,
             num_heads,
             head_dim,
@@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
+    sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
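As an aside, the indices the two hunks above compute can be exercised standalone; a minimal sketch with made-up sizes (plain torch, not the repo's `_update_kv_cache`):

import torch

# Cache layout: (max_batch_size, max_seqlen, 2, nheads, head_dim); new keys/values are
# written at [batch_start:batch_end, seqlen_offset : seqlen_offset + kv.shape[1]].
max_batch_size, max_seqlen, nheads, head_dim = 2, 32, 4, 16
kv_cache = torch.empty(max_batch_size, max_seqlen, 2, nheads, head_dim)

kv = torch.randn(2, 5, 2, nheads, head_dim)  # (batch, seqlen_new, 2, nheads, head_dim)
seqlen_offset = 7                            # renamed from sequence_len_offset
batch_start, batch_end = 0, kv.shape[0]
sequence_start, sequence_end = seqlen_offset, seqlen_offset + kv.shape[1]
assert sequence_end <= kv_cache.shape[1]
kv_cache[batch_start:batch_end, sequence_start:sequence_end] = kv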
@@ -445,12 +445,12 @@ class MHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -460,7 +460,7 @@ class MHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -480,11 +480,11 @@ class MHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
         if (
-            inference_params.sequence_len_offset == 0
+            inference_params.seqlen_offset == 0
             or flash_attn_with_kvcache is None
             or not self.use_flash_attn
         ):
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -493,7 +493,7 @@ class MHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             return flash_attn_with_kvcache(
                 q,
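The `cache_seqlens` hunks above all follow the same selection; a hypothetical helper (not a function in this repo) mirroring it:

def pick_cache_seqlens(inference_params, batch):
    # Prefer per-sample lengths when they exist; otherwise fall back to the shared offset.
    if inference_params.lengths_per_sample is not None:
        return inference_params.lengths_per_sample[:batch]  # (batch,) int32 tensor
    return inference_params.seqlen_offset  # single offset shared by the whole batch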
@@ -561,12 +561,10 @@ class MHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -581,7 +579,7 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -632,7 +630,7 @@ class MHA(nn.Module):
                 ).contiguous()
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
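The same guard is repeated in several forward paths above; as a standalone predicate it is just the negation of the condition shown (hypothetical helper, for illustration only):

def can_take_kvcache_fast_path(inference_params, rotary_emb_dim, use_flash_attn):
    # True only when an inference cache exists, at least one token has already been
    # processed, rotary_emb_dim is a nonzero multiple of 16, and flash-attn is enabled.
    return not (
        inference_params is None
        or inference_params.seqlen_offset == 0
        or (rotary_emb_dim == 0 or rotary_emb_dim % 16 != 0)
        or not use_flash_attn
    )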
@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):

     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             context = flash_attn_with_kvcache(
                 q,
@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
             )
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):

View File

@@ -20,13 +20,20 @@ class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""

-    max_sequence_len: int
+    max_seqlen: int
     max_batch_size: int
-    sequence_len_offset: int = 0
+    seqlen_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
     lengths_per_sample: Optional[Tensor] = None

+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()


 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
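The new `reset()` makes a cached `InferenceParams` reusable across generate calls without reallocating its KV buffers; a small illustration (import path assumed, numbers made up):

import torch
from flash_attn.utils.generation import InferenceParams  # assumed import path

params = InferenceParams(max_seqlen=128, max_batch_size=2)
params.seqlen_offset = 57  # pretend a previous generation advanced the offset
params.lengths_per_sample = torch.zeros(2, dtype=torch.int32)

# reset() rewinds the per-generation state in place; key_value_memory_dict is kept,
# so previously allocated KV cache tensors can be reused by the next call.
params.reset(max_seqlen=256, max_batch_size=2)
assert params.seqlen_offset == 0 and params.max_seqlen == 256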
@@ -127,19 +134,16 @@ def decode(
             tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
-        inference_params.max_sequence_len = max_length
-        inference_params.max_batch_size = batch_size
-        inference_params.sequence_len_offset = 0
-        inference_params.lengths_per_sample.zero_()
+        inference_params.reset(max_length, batch_size)
     else:
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def get_logits(input_ids, inference_params):
-        decoding = inference_params.sequence_len_offset > 0
+        decoding = inference_params.seqlen_offset > 0
         if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
@@ -154,24 +158,24 @@ def decode(
             ).logits.squeeze(dim=1)
         else:
             logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
         return logits[..., :vocab_size] if vocab_size is not None else logits

     def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            token = teacher_outputs[:, inference_params.sequence_len_offset]
+            token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
         return token.unsqueeze(1)

     def should_stop(current_token, inference_params):
-        if inference_params.sequence_len_offset == 0:
+        if inference_params.seqlen_offset == 0:
             return False
         if eos_token_id is not None and (current_token == eos_token_id).all():
             return True
-        if inference_params.sequence_len_offset >= max_length - 1:
+        if inference_params.seqlen_offset >= max_length - 1:
             return True
         return False
@@ -185,7 +189,7 @@ def decode(
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        inference_params.seqlen_offset += sequences[-1].shape[1]
         sequences.append(sample_tokens(scores[-1], inference_params))
     if enable_timing:
         end.record()
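The loop above is the core of non-speculative decoding after the rename; a stripped-down sketch of the same offset bookkeeping (placeholder `get_logits` callable, greedy sampling, EOS handling omitted):

def decode_sketch(get_logits, prompt_ids, inference_params, max_length):
    # The first iteration feeds the whole prompt (prefill); later iterations feed one token.
    # seqlen_offset always advances by the number of tokens just run through the model.
    sequences = [prompt_ids]
    while inference_params.seqlen_offset < max_length - 1:
        logits = get_logits(sequences[-1], inference_params)  # (batch, vocab_size)
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(logits.argmax(dim=-1, keepdim=True))  # greedy; shape (batch, 1)
    return sequences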
@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
     return tokens, first_rejected_idx + 1


+@torch.inference_mode()
 def decode_speculative(
     input_ids,
     model,
@@ -303,15 +308,11 @@ def decode_speculative(
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
-        inference_params_draft.max_sequence_len = max_length
-        inference_params_draft.max_batch_size = batch_size
-        inference_params_draft.sequence_len_offset = 0
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft.reset(max_length, batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size
-        )
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -323,7 +324,7 @@ def decode_speculative(
             ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()

     logits_postprocess_fn = (
@@ -365,13 +366,13 @@ def decode_speculative(
         assert seqlen == 1
         position_ids = repeat(
             torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-            + inference_params.sequence_len_offset,
+            + inference_params.seqlen_offset,
             "s -> b s",
             b=batch_size,
         )
         # position_ids = torch.full(
         #     (batch_size, 1),
-        #     inference_params.sequence_len_offset,
+        #     inference_params.seqlen_offset,
         #     dtype=torch.long,
         #     device=input_ids.device,
         # )
@@ -380,7 +381,7 @@ def decode_speculative(
     logits = logits_postprocess_fn(
         logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
     )
-    inference_params.sequence_len_offset += input_ids.shape[1]
+    inference_params.seqlen_offset += input_ids.shape[1]
    scores = [logits]
    next_token = sample_fn(logits)
    sequences.append(next_token)
@@ -388,7 +389,7 @@ def decode_speculative(
         if i < num_tokens - 1 or last_token_logits:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params_draft.sequence_len_offset,
+                inference_params_draft.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
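During single-token decoding the position ids are just the current offset broadcast per sample, as in the hunk above; for illustration (sizes made up):

import torch

batch_size, seqlen_offset = 2, 12
# One new token per sequence, all at absolute position seqlen_offset.
position_ids = torch.full((batch_size, 1), seqlen_offset, dtype=torch.long)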
@@ -401,7 +402,7 @@ def decode_speculative(
                     cg=cg,
                 )
             )
-            inference_params.sequence_len_offset += 1
+            inference_params.seqlen_offset += 1
             scores.append(logits)
             if i < num_tokens - 1:
                 next_token = sample_fn(logits)
@@ -476,8 +477,8 @@ def decode_speculative(
         scores.append(logits[:1, : num_generated_tokens[0]])
         # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
         # that in the next time we call @model.
-        inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         if debug:
             cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
             scores_ref = model(
@@ -486,10 +487,10 @@ def decode_speculative(
             print((scores[-1] - scores_ref[:, :-1]).abs().max())

     while True:
-        # sequence_len_offset is total length generated - 1
-        if inference_params.sequence_len_offset >= max_length - 1:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
             break
-        if inference_params.sequence_len_offset >= max_length - 2:
+        if inference_params.seqlen_offset >= max_length - 2:
             # Don't do speculative sampling, just sample 1 token from the model
             tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
             sequences.append(tokens)
@@ -497,7 +498,7 @@ def decode_speculative(
             break
         # Sample from draft model
         n_spec_tokens = min(
-            speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
         )
         tokens_draft, scores_draft = sample_tokens_draft(
             sequences[-1][:, -1:], num_tokens=n_spec_tokens
@@ -510,9 +511,9 @@ def decode_speculative(
         # Evaluate the draft tokens with the model
         position_ids = repeat(
             torch.arange(
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 # 1 extra token from last time that hasn't been passed through model
-                inference_params.sequence_len_offset + n_spec_tokens + 1,
+                inference_params.seqlen_offset + n_spec_tokens + 1,
                 dtype=torch.long,
                 device=input_ids.device,
             ),
@@ -525,7 +526,7 @@ def decode_speculative(
             inference_params=inference_params,
         ).logits  # (batch, n_spec_tokens, vocab_size)
         logits = logits_postprocess_fn(logits)
-        inference_params.sequence_len_offset += 1
+        inference_params.seqlen_offset += 1
         if debug:
             logits_ref = model(
                 torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
@@ -539,8 +540,8 @@ def decode_speculative(
             print(num_generated_tokens)
         sequences.append(tokens[:1, : num_generated_tokens[0]])
         scores.append(logits[:1, : num_generated_tokens[0]])
-        inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         # breakpoint()
         if debug:
             cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
@@ -679,9 +680,9 @@ def update_graph_cache(
         )
     lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
     cache.inference_params = InferenceParams(
-        max_sequence_len=max_seqlen,
+        max_seqlen=max_seqlen,
         max_batch_size=batch_size,
-        sequence_len_offset=seqlen_og,
+        seqlen_offset=seqlen_og,
         key_value_memory_dict=inf_cache,
         lengths_per_sample=lengths_per_sample,
     )
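For the CUDA-graph path, the params object is created already pointing past the prompt, with one length entry per sample; a standalone illustration of the same construction (import path assumed, sizes made up):

import torch
from flash_attn.utils.generation import InferenceParams  # assumed import path

batch_size, seqlen_og, max_seqlen = 2, 16, 128
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32)
inference_params = InferenceParams(
    max_seqlen=max_seqlen,
    max_batch_size=batch_size,
    seqlen_offset=seqlen_og,  # decoding starts right after the prompt
    lengths_per_sample=lengths_per_sample,
)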
@@ -705,7 +706,7 @@ def update_graph_cache(
        )

    cache.run = dispatch
-    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache

@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    sequence_len_offset_og = inference_params.sequence_len_offset
+    seqlen_offset_og = inference_params.seqlen_offset
     # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
     # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - 1
     inference_params.lengths_per_sample[:] = max_seqlen - 1

     # Warmup before capture
@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
         graph.replay()
         return logits.clone()

-    inference_params.sequence_len_offset = sequence_len_offset_og
+    inference_params.seqlen_offset = seqlen_offset_og
     return run
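The capture path above pins `seqlen_offset` (and `lengths_per_sample`) to `max_seqlen - 1` while the CUDA graph is recorded, then restores the original offset. A hypothetical context manager expressing the same save/restore pattern (not part of this commit):

import contextlib

@contextlib.contextmanager
def pinned_seqlen_offset(inference_params, max_seqlen):
    # Pin the offset to the largest value the captured graph must support, then restore it.
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - 1
    if inference_params.lengths_per_sample is not None:
        inference_params.lengths_per_sample[:] = max_seqlen - 1
    try:
        yield inference_params
    finally:
        inference_params.seqlen_offset = seqlen_offset_og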

View File

@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):

     logits_ref = model(input_ids).logits
     # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
-    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
     logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
-    inference_params.sequence_len_offset += 10
+    inference_params.seqlen_offset += 10
     position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
     logits_1014 = model(
         input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
     ).logits
-    inference_params.sequence_len_offset += 4
+    inference_params.seqlen_offset += 4
     position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
     logits_1420 = model(
         input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
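The updated test exercises chunked prefill: feed the sequence in pieces, bump `seqlen_offset` by each chunk length, pass explicit `position_ids` for later chunks, and expect the concatenated logits to match a single full forward pass. A condensed sketch of that pattern (`model`, `input_ids`, and `device` are placeholders; import path assumed):

import torch
from flash_attn.utils.generation import InferenceParams  # assumed import path

def chunked_forward(model, input_ids, device, chunks=(10, 4, 6)):
    inference_params = InferenceParams(max_seqlen=sum(chunks), max_batch_size=1)
    logits_pieces, start = [], 0
    for length in chunks:
        position_ids = torch.arange(start, start + length, dtype=torch.long, device=device)
        piece = model(
            input_ids[:, start : start + length],
            position_ids=position_ids if start > 0 else None,  # first chunk starts at 0
            inference_params=inference_params,
        ).logits
        logits_pieces.append(piece)
        inference_params.seqlen_offset += length
        start += length
    return torch.cat(logits_pieces, dim=1)  # expected to be close to model(input_ids).logits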