Bump to v0.2.6
parent 63670fd84a
commit a6ec1782dc
@@ -436,7 +436,7 @@ class MHA(nn.Module):
                 kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                 # If we're processing the prompt, causal=None (use self.causal).
                 # If we're decoding, then causal=False.
-                causal = False if inference_params.sequence_len_offset == 0 else None
+                causal = None if inference_params.sequence_len_offset == 0 else False
                 context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
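For reference, a minimal sketch of the flag selection this hunk corrects, written as a standalone helper (pick_causal is a hypothetical name, and a plain integer stands in for inference_params.sequence_len_offset):

    def pick_causal(seq_offset):
        # Prompt pass (offset == 0): return None so the attention call falls
        # back to the module's own self.causal setting.
        # Decode pass (offset > 0): the single new query may attend to every
        # cached key/value, so causal masking is turned off.
        return None if seq_offset == 0 else False

    assert pick_causal(0) is None    # processing the prompt
    assert pick_causal(17) is False  # decoding one token at a time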
@@ -40,7 +40,7 @@ def greedy_decode(input_ids, model, max_length):
     inference_params.sequence_len_offset = seqlen_og
     while True:
         position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
-                                  dtype=torch.long, device=input_ids.device)
+                                  dtype=torch.long, device=input_ids.device)
         logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                        inference_params=inference_params).logits[:, -1]
         scores.append(logits)
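For context on the decode loop in this hunk, a small sketch of the position-id bookkeeping per step (batch_size, device, and the offset value are assumptions made up for the example, not taken from the repo):

    import torch

    batch_size, device = 2, "cpu"
    offset = 5  # tokens already in the KV cache, i.e. sequence_len_offset

    # One position id per sequence for the single token decoded this step,
    # mirroring torch.full((batch_size, 1), ...) in the diff above.
    position_ids = torch.full((batch_size, 1), offset, dtype=torch.long, device=device)
    print(position_ids)  # tensor([[5], [5]])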