From a6ec1782dc69b1d1a9ed94e2323c3ed5ba56cc13 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 27 Dec 2022 21:18:45 -0800
Subject: [PATCH] Bump to v0.2.6

---
 flash_attn/modules/mha.py      | 2 +-
 flash_attn/utils/generation.py | 2 +-
 setup.py                       | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 97755bc..a6f3ec7 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -436,7 +436,7 @@ class MHA(nn.Module):
                 kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                 # If we're processing the prompt, causal=None (use self.causal).
                 # If we're decoding, then causal=False.
-                causal = False if inference_params.sequence_len_offset == 0 else None
+                causal = None if inference_params.sequence_len_offset == 0 else False
                 context = self.inner_cross_attn(q, kv, causal=causal)
         else:
             if not self.return_residual:
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 13f7fca..c578f37 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -40,7 +40,7 @@ def greedy_decode(input_ids, model, max_length):
     inference_params.sequence_len_offset = seqlen_og
     while True:
         position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
-                                   dtype=torch.long, device=input_ids.device)
+                                  dtype=torch.long, device=input_ids.device)
         logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                        inference_params=inference_params).logits[:, -1]
         scores.append(logits)
diff --git a/setup.py b/setup.py
index 6c97046..ee5ded3 100644
--- a/setup.py
+++ b/setup.py
@@ -156,7 +156,7 @@ ext_modules.append(
 
 setup(
     name="flash_attn",
-    version="0.2.5",
+    version="0.2.6",
    packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),
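
Note (illustration, not part of the patch above): the mha.py hunk corrects the causal flag so that the prompt pass (sequence_len_offset == 0) passes causal=None, deferring to the module's own self.causal, while single-token decoding passes causal=False, since one new query attending to the full KV cache needs no causal mask. Below is a minimal runnable sketch of that selection logic; InferenceParams and pick_causal here are hypothetical stand-ins for illustration, not flash_attn's own API.

# Sketch of the prompt-vs-decode causal selection.
# InferenceParams is a hypothetical stand-in, not flash_attn's class.
from dataclasses import dataclass

@dataclass
class InferenceParams:
    sequence_len_offset: int = 0

def pick_causal(inference_params):
    # Prompt pass: nothing cached yet, so return None and let the caller
    # fall back to its default (self.causal in the patched module).
    # Decode pass: a single new token attends to the whole KV cache, so no
    # causal mask is needed.
    return None if inference_params.sequence_len_offset == 0 else False

assert pick_causal(InferenceParams(sequence_len_offset=0)) is None   # prompt
assert pick_causal(InferenceParams(sequence_len_offset=8)) is False  # decoding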