[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset
parent 42832575d4
commit e6a8026489
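For code that uses the generation API, only the two public attribute names change; behavior is identical. A minimal migration sketch (the sizes below are illustrative, not taken from this commit):

    from flash_attn.utils.generation import InferenceParams

    # before: InferenceParams(max_sequence_len=512, max_batch_size=4)
    inference_params = InferenceParams(max_seqlen=512, max_batch_size=4)

    # before: inference_params.sequence_len_offset += 1
    inference_params.seqlen_offset += 1  # advance past tokens already in the KV cache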
flash_attn/modules/mha.py

@@ -300,7 +300,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
             inference_params.max_batch_size,
-            inference_params.max_sequence_len,
+            inference_params.max_seqlen,
             2,
             num_heads,
             head_dim,
@@ -313,7 +313,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
+    sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
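The two hunks above are pure renames; the cache layout is unchanged: a buffer of shape (max_batch_size, max_seqlen, 2, nheads, head_dim), with seqlen_offset marking where the next chunk of keys/values is written. A standalone sketch of that slicing logic, simplified from the function shown (not the committed code):

    import torch

    def update_kv_cache_sketch(kv, kv_cache, batch_size_offset, seqlen_offset):
        # kv: (batch, seqlen_new, 2, nheads, head_dim)
        # kv_cache: (max_batch_size, max_seqlen, 2, nheads, head_dim)
        batch_start = batch_size_offset
        batch_end = batch_start + kv.shape[0]
        seq_start = seqlen_offset
        seq_end = seq_start + kv.shape[1]
        assert batch_end <= kv_cache.shape[0] and seq_end <= kv_cache.shape[1]
        # Write the new keys/values at the current offset
        kv_cache[batch_start:batch_end, seq_start:seq_end] = kv
        # Return everything cached so far for these batch rows
        return kv_cache[batch_start:batch_end, :seq_end]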
@@ -445,12 +445,12 @@ class MHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -460,7 +460,7 @@ class MHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -480,11 +480,11 @@ class MHA(nn.Module):
     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
         if (
-            inference_params.sequence_len_offset == 0
+            inference_params.seqlen_offset == 0
             or flash_attn_with_kvcache is None
             or not self.use_flash_attn
         ):
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -493,7 +493,7 @@ class MHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             return flash_attn_with_kvcache(
                 q,
@@ -561,12 +561,10 @@ class MHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -581,7 +579,7 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -632,7 +630,7 @@ class MHA(nn.Module):
             ).contiguous()
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -772,12 +770,12 @@ class ParallelMHA(nn.Module):
         q: (batch_size, seqlen_q, nheads, head_dim)
         kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert inference_params is not None and inference_params.seqlen_offset > 0
         assert self.use_flash_attn
         if self.rotary_emb_dim > 0:
             assert self.rotary_emb.scale is None, "This code path does not support xPos"
             self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
             rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
         else:
@@ -787,7 +785,7 @@ class ParallelMHA(nn.Module):
         cache_seqlens = (
             inference_params.lengths_per_sample[:batch]
             if inference_params.lengths_per_sample is not None
-            else inference_params.sequence_len_offset
+            else inference_params.seqlen_offset
         )
         context = flash_attn_with_kvcache(
             q,
@@ -806,8 +804,8 @@ class ParallelMHA(nn.Module):

     def _update_kvcache_attention(self, q, kv, inference_params):
         """Write kv to inference_params, then do attention"""
-        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
         else:
@@ -816,7 +814,7 @@ class ParallelMHA(nn.Module):
             cache_seqlens = (
                 inference_params.lengths_per_sample[:batch]
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
             context = flash_attn_with_kvcache(
                 q,
@@ -847,17 +845,15 @@ class ParallelMHA(nn.Module):
             else (
                 inference_params.lengths_per_sample
                 if inference_params.lengths_per_sample is not None
-                else inference_params.sequence_len_offset
+                else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else None
-        )
+        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
@@ -892,7 +888,7 @@ class ParallelMHA(nn.Module):
             )
             if (
                 inference_params is None
-                or inference_params.sequence_len_offset == 0
+                or inference_params.seqlen_offset == 0
                 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                 or not self.use_flash_attn
             ):
flash_attn/utils/generation.py

@@ -20,13 +20,20 @@ class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficienly calculate and store the context during inference."""

-    max_sequence_len: int
+    max_seqlen: int
     max_batch_size: int
-    sequence_len_offset: int = 0
+    seqlen_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
     lengths_per_sample: Optional[Tensor] = None

+    def reset(self, max_seqlen, max_batch_size):
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()
+

 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
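The added reset() method lets decode() below reuse one InferenceParams object across generations instead of re-assigning each field by hand. A short usage sketch with the renamed fields (sizes are illustrative):

    from flash_attn.utils.generation import InferenceParams

    inference_params = InferenceParams(max_seqlen=1024, max_batch_size=2)
    # ... a generation pass advances the offset ...
    inference_params.seqlen_offset += 10

    # Reuse the same object (and its lengths_per_sample buffer, if any) for the next prompt
    inference_params.reset(1024, 2)
    assert inference_params.seqlen_offset == 0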
@@ -127,19 +134,16 @@ def decode(
             tensor_parallel=tensor_parallel,
         )
         inference_params = model._decoding_cache.inference_params
-        inference_params.max_sequence_len = max_length
-        inference_params.max_batch_size = batch_size
-        inference_params.sequence_len_offset = 0
-        inference_params.lengths_per_sample.zero_()
+        inference_params.reset(max_length, batch_size)
     else:
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def get_logits(input_ids, inference_params):
-        decoding = inference_params.sequence_len_offset > 0
+        decoding = inference_params.seqlen_offset > 0
         if decoding:
             position_ids = torch.full(
                 (batch_size, 1),
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 dtype=torch.long,
                 device=input_ids.device,
             )
@@ -154,24 +158,24 @@ def decode(
             ).logits.squeeze(dim=1)
         else:
             logits = model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()
         return logits[..., :vocab_size] if vocab_size is not None else logits

     def sample_tokens(logits, inference_params):
-        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            token = teacher_outputs[:, inference_params.sequence_len_offset]
+            token = teacher_outputs[:, inference_params.seqlen_offset]
         # return rearrange(token, "b -> b 1")
         return token.unsqueeze(1)

     def should_stop(current_token, inference_params):
-        if inference_params.sequence_len_offset == 0:
+        if inference_params.seqlen_offset == 0:
             return False
         if eos_token_id is not None and (current_token == eos_token_id).all():
             return True
-        if inference_params.sequence_len_offset >= max_length - 1:
+        if inference_params.seqlen_offset >= max_length - 1:
             return True
         return False

@@ -185,7 +189,7 @@ def decode(
     scores, sequences = [], [input_ids]
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
-        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        inference_params.seqlen_offset += sequences[-1].shape[1]
         sequences.append(sample_tokens(scores[-1], inference_params))
     if enable_timing:
         end.record()
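The loop above is the heart of decode(): each iteration runs the model on the newest token(s), advances seqlen_offset by the number of tokens just fed in, and samples the next token. A simplified, greedy version of that pattern (argmax stands in for the library's sample(); eos handling and CUDA-graph dispatch are omitted):

    import torch

    def greedy_decode_sketch(model, input_ids, inference_params, max_length):
        batch_size = input_ids.shape[0]
        sequences = [input_ids]
        while inference_params.seqlen_offset < max_length - 1:
            decoding = inference_params.seqlen_offset > 0
            # After the prompt pass, each step feeds a single token at the current offset
            position_ids = (
                torch.full((batch_size, 1), inference_params.seqlen_offset,
                           dtype=torch.long, device=input_ids.device)
                if decoding
                else None
            )
            logits = model(
                sequences[-1], position_ids=position_ids, inference_params=inference_params
            ).logits[:, -1]
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(logits.argmax(dim=-1, keepdim=True))
        return torch.cat(sequences, dim=1)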
@@ -256,6 +260,7 @@ def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, t
     return tokens, first_rejected_idx + 1


+
 @torch.inference_mode()
 def decode_speculative(
     input_ids,
     model,
@@ -303,15 +308,11 @@ def decode_speculative(
             tensor_parallel=tensor_parallel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
-        inference_params_draft.max_sequence_len = max_length
-        inference_params_draft.max_batch_size = batch_size
-        inference_params_draft.sequence_len_offset = 0
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft.reset(max_length, batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
     else:
-        inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size
-        )
-        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -323,7 +324,7 @@ def decode_speculative(
             ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
+                input_ids, position_ids, inference_params.seqlen_offset
             ).clone()

     logits_postprocess_fn = (
@@ -365,13 +366,13 @@ def decode_speculative(
             assert seqlen == 1
             position_ids = repeat(
                 torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                + inference_params.sequence_len_offset,
+                + inference_params.seqlen_offset,
                 "s -> b s",
                 b=batch_size,
             )
             # position_ids = torch.full(
             #     (batch_size, 1),
-            #     inference_params.sequence_len_offset,
+            #     inference_params.seqlen_offset,
             #     dtype=torch.long,
             #     device=input_ids.device,
             # )
@@ -380,7 +381,7 @@ def decode_speculative(
         logits = logits_postprocess_fn(
             logits_forward_fn(model, input_ids, position_ids, inference_params, cg=decoding and cg)
         )
-        inference_params.sequence_len_offset += input_ids.shape[1]
+        inference_params.seqlen_offset += input_ids.shape[1]
         scores = [logits]
         next_token = sample_fn(logits)
         sequences.append(next_token)
@@ -388,7 +389,7 @@ def decode_speculative(
            if i < num_tokens - 1 or last_token_logits:
                position_ids = torch.full(
                    (batch_size, 1),
-                    inference_params_draft.sequence_len_offset,
+                    inference_params_draft.seqlen_offset,
                    dtype=torch.long,
                    device=input_ids.device,
                )
@@ -401,7 +402,7 @@ def decode_speculative(
                        cg=cg,
                    )
                )
-                inference_params.sequence_len_offset += 1
+                inference_params.seqlen_offset += 1
                scores.append(logits)
                if i < num_tokens - 1:
                    next_token = sample_fn(logits)
@@ -476,8 +477,8 @@ def decode_speculative(
         scores.append(logits[:1, : num_generated_tokens[0]])
         # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
         # that in the next time we call @model.
-        inference_params.sequence_len_offset = seqlen_og + num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset = seqlen_og + num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         if debug:
             cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
             scores_ref = model(
@@ -486,10 +487,10 @@ def decode_speculative(
             print((scores[-1] - scores_ref[:, :-1]).abs().max())

     while True:
-        # sequence_len_offset is total length generated - 1
-        if inference_params.sequence_len_offset >= max_length - 1:
+        # seqlen_offset is total length generated - 1
+        if inference_params.seqlen_offset >= max_length - 1:
             break
-        if inference_params.sequence_len_offset >= max_length - 2:
+        if inference_params.seqlen_offset >= max_length - 2:
             # Don't do speculative sampling, just sample 1 token from the model
             tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
             sequences.append(tokens)
@@ -497,7 +498,7 @@ def decode_speculative(
             break
         # Sample from draft model
         n_spec_tokens = min(
-            speculative_lookahead, max_length - inference_params_draft.sequence_len_offset - 2
+            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
         )
         tokens_draft, scores_draft = sample_tokens_draft(
             sequences[-1][:, -1:], num_tokens=n_spec_tokens
@@ -510,9 +511,9 @@ def decode_speculative(
         # Evaluate the draft tokens with the model
         position_ids = repeat(
             torch.arange(
-                inference_params.sequence_len_offset,
+                inference_params.seqlen_offset,
                 # 1 extra token from last time that hasn't been passed through model
-                inference_params.sequence_len_offset + n_spec_tokens + 1,
+                inference_params.seqlen_offset + n_spec_tokens + 1,
                 dtype=torch.long,
                 device=input_ids.device,
             ),
@@ -525,7 +526,7 @@ def decode_speculative(
             inference_params=inference_params,
         ).logits  # (batch, n_spec_tokens, vocab_size)
         logits = logits_postprocess_fn(logits)
-        inference_params.sequence_len_offset += 1
+        inference_params.seqlen_offset += 1
         if debug:
             logits_ref = model(
                 torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
@@ -539,8 +540,8 @@ def decode_speculative(
             print(num_generated_tokens)
         sequences.append(tokens[:1, : num_generated_tokens[0]])
         scores.append(logits[:1, : num_generated_tokens[0]])
-        inference_params.sequence_len_offset += num_generated_tokens[0].item() - 1
-        inference_params_draft.sequence_len_offset = inference_params.sequence_len_offset
+        inference_params.seqlen_offset += num_generated_tokens[0].item() - 1
+        inference_params_draft.seqlen_offset = inference_params.seqlen_offset
         # breakpoint()
         if debug:
             cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
@@ -679,9 +680,9 @@ def update_graph_cache(
         )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
-            max_sequence_len=max_seqlen,
+            max_seqlen=max_seqlen,
             max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og,
+            seqlen_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
             lengths_per_sample=lengths_per_sample,
         )
@@ -705,7 +706,7 @@ def update_graph_cache(
         )

     cache.run = dispatch
-    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
     return cache


@@ -713,10 +714,10 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    sequence_len_offset_og = inference_params.sequence_len_offset
+    seqlen_offset_og = inference_params.seqlen_offset
     # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
     # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
-    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.seqlen_offset = max_seqlen - 1
     inference_params.lengths_per_sample[:] = max_seqlen - 1

     # Warmup before capture
@@ -755,5 +756,5 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
         graph.replay()
         return logits.clone()

-    inference_params.sequence_len_offset = sequence_len_offset_og
+    inference_params.seqlen_offset = seqlen_offset_og
     return run
tests/models/test_gpt.py

@@ -364,14 +364,14 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     logits_ref = model(input_ids).logits

     # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
-    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
     logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
-    inference_params.sequence_len_offset += 10
+    inference_params.seqlen_offset += 10
     position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
     logits_1014 = model(
         input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
     ).logits
-    inference_params.sequence_len_offset += 4
+    inference_params.seqlen_offset += 4
     position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
     logits_1420 = model(
         input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
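The updated test exercises the renamed fields end to end: feed a chunk, bump seqlen_offset by its length, and supply explicit position_ids for the next chunk. The same pattern, condensed into a helper (model construction omitted; assumes a flash-attn GPT-style model that accepts position_ids and inference_params):

    import torch
    from flash_attn.utils.generation import InferenceParams

    def chunked_forward(model, input_ids, chunk_sizes):
        # Run input_ids through the model in chunks, reusing the KV cache between chunks.
        batch_size, total_len = input_ids.shape
        inference_params = InferenceParams(max_seqlen=total_len, max_batch_size=batch_size)
        logits, start = [], 0
        for size in chunk_sizes:
            position_ids = torch.arange(
                start, start + size, dtype=torch.long, device=input_ids.device
            )
            out = model(
                input_ids[:, start : start + size],
                position_ids=position_ids,
                inference_params=inference_params,
            ).logits
            logits.append(out)
            inference_params.seqlen_offset += size
            start += size
        return torch.cat(logits, dim=1)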