Fix wrong dtype in PagedAttentionWithALiBi bias (#996)

---------

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
commit a62de9ecfd (parent 4042d192f5)
Antoni Baum, 2023-09-09 14:58:35 -07:00, committed by GitHub
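The underlying bug: torch.arange defaults to int64 when no dtype is passed, so the ALiBi bias was built in a dtype that need not match the model's compute dtype. A minimal demonstration of that default:

import torch

# Without an explicit dtype argument, torch.arange yields int64,
# regardless of the dtype the model actually computes in.
bias = torch.arange(16)
assert bias.dtype == torch.int64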

@@ -73,7 +73,12 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(
+        self,
+        input_metadata: InputMetadata,
+        dtype: torch.dtype,
+    ) -> None:
+        del dtype  # Unused.
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
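A hypothetical reduction of the pattern in this hunk (class and attribute names are illustrative, not vLLM's): the base class accepts and immediately discards dtype so that overrides which do need it, like the ALiBi variant below, can share one call signature.

import torch

class Base:
    def set_attn_bias(self, dtype: torch.dtype) -> None:
        del dtype  # Unused here; kept only to match override signatures.

class WithALiBi(Base):
    def set_attn_bias(self, dtype: torch.dtype) -> None:
        self.bias = torch.arange(8, dtype=dtype)  # actually consumes dtype

for layer in (Base(), WithALiBi()):
    layer.set_attn_bias(dtype=torch.float16)  # one uniform call site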
@@ -196,7 +201,7 @@ class PagedAttention(nn.Module):
         if num_prompt_tokens > 0:
             # Prompt run.
             assert input_metadata.num_generation_tokens == 0
-            self.set_attn_bias(input_metadata)
+            self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
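The call site now threads query.dtype through, so the bias dtype follows whatever precision the model is running in. A small sketch of the invariant this establishes (the tensor shapes here are made up):

import torch

query = torch.randn(16, 8, 64, dtype=torch.float16)  # half-precision run

old_bias = torch.arange(16)                     # int64: pre-fix behavior
new_bias = torch.arange(16, dtype=query.dtype)  # float16: matches the query

assert old_bias.dtype != query.dtype
assert new_bias.dtype == query.dtype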
@@ -340,13 +345,14 @@ class PagedAttentionWithALiBi(PagedAttention):
         slopes = torch.tensor(slopes, dtype=torch.float32)
         self.register_buffer("alibi_slopes", slopes, persistent=False)
 
-    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+    def set_attn_bias(self, input_metadata: InputMetadata,
+                      dtype: torch.dtype) -> None:
         if input_metadata.attn_bias:
             # Already set by a previous layer.
             return
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len)
+            bias = torch.arange(prompt_len, dtype=dtype)
             # NOTE(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
@@ -364,6 +370,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
+                dtype=dtype,
             )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
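For context, a self-contained sketch of building a dtype-correct ALiBi bias in the spirit of the patched method; the helper name, output shape, and the slope formula (the standard one for power-of-two head counts) are assumptions for illustration, not vLLM's exact code:

import torch

def make_alibi_bias(prompt_len: int, num_heads: int,
                    dtype: torch.dtype) -> torch.Tensor:
    # Relative positions; dtype must be passed explicitly, otherwise
    # torch.arange would default to int64 (the bug fixed above).
    pos = torch.arange(prompt_len, dtype=dtype)
    bias = pos[None, :] - pos[:, None]            # (prompt_len, prompt_len)
    # ALiBi slopes for a power-of-two head count: 2^(-8(i+1)/num_heads).
    slopes = torch.tensor(
        [2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)],
        dtype=dtype,
    )
    # One bias matrix per head, scaled by that head's slope.
    return bias[None, :, :] * slopes[:, None, None]  # (heads, len, len)

bias = make_alibi_bias(prompt_len=16, num_heads=8, dtype=torch.float16)
assert bias.dtype == torch.float16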