Bump to v0.2.6

Tri Dao 2022-12-27 21:18:45 -08:00
parent 63670fd84a
commit a6ec1782dc
3 changed files with 3 additions and 3 deletions


@@ -436,7 +436,7 @@ class MHA(nn.Module):
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
- causal = False if inference_params.sequence_len_offset == 0 else None
+ causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
if not self.return_residual:
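
The fix above swaps the two branches of the conditional so the code matches its comment: on the prompt pass (sequence_len_offset == 0) causal stays None and the attention call falls back to self.causal, while on a decoding step the single new query may attend to every cached key, so causal masking is disabled. A minimal sketch of that selection logic, using a hypothetical pick_causal helper that is not part of flash_attn:

# Sketch only: pick_causal is a hypothetical helper, not a flash_attn API.
def pick_causal(sequence_len_offset):
    # Prompt pass: return None so the caller falls back to its own self.causal.
    # Decoding pass: the one new query may attend to all cached keys/values,
    # so causal masking is explicitly turned off.
    return None if sequence_len_offset == 0 else False

assert pick_causal(0) is None    # processing the prompt
assert pick_causal(12) is False  # decoding with 12 tokens already cached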


@@ -40,7 +40,7 @@ def greedy_decode(input_ids, model, max_length):
inference_params.sequence_len_offset = seqlen_og
while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
- dtype=torch.long, device=input_ids.device)
+ dtype=torch.long, device=input_ids.device)
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
scores.append(logits)
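
For context, this loop advances generation one token at a time: position_ids broadcasts the current offset to every sequence in the batch, and only the newly sampled token is fed to the model because earlier keys and values are served from the cache tracked by inference_params. The following self-contained sketch shows just the per-step tensor shapes, with a dummy callable standing in for the model and no KV cache involved:

# Sketch of one greedy step; dummy_model is a stand-in, not the real flash_attn model.
import torch

def decode_step(model, next_token, sequence_len_offset):
    # next_token: (batch,) tensor holding the last sampled token id per sequence.
    batch_size = next_token.shape[0]
    # Every sequence sits at the same position, so one offset covers the whole batch.
    position_ids = torch.full((batch_size, 1), sequence_len_offset,
                              dtype=torch.long, device=next_token.device)
    logits = model(next_token.unsqueeze(1), position_ids)  # (batch, 1, vocab)
    return logits[:, -1].argmax(dim=-1)                     # greedy pick, shape (batch,)

vocab_size = 16
dummy_model = lambda tokens, position_ids: torch.randn(tokens.shape[0], tokens.shape[1], vocab_size)
print(decode_step(dummy_model, torch.tensor([3, 7]), sequence_len_offset=5))  # e.g. tensor([4, 9])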


@@ -156,7 +156,7 @@ ext_modules.append(
setup(
name="flash_attn",
version="0.2.5",
version="0.2.6-1",
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
),
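
After reinstalling, the bumped version can be confirmed without assuming any attribute on the package itself; importlib.metadata is standard library, and under PEP 440 the "0.2.6-1" string normalizes to "0.2.6.post1":

from importlib.metadata import PackageNotFoundError, version

try:
    print(version("flash_attn"))  # expected to print 0.2.6.post1 once this bump is installed
except PackageNotFoundError:
    print("flash_attn is not installed in this environment")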