hotfix attn alibi wo head mapping (#496)

Co-authored-by: oliveryuan <oliveryuan@basemind.com>
Authored by Song on 2023-07-19 02:31:48 +08:00, committed by GitHub
parent 453bafb96f
commit bda41c70dd
2 changed files with 3 additions and 0 deletions


@@ -199,6 +199,7 @@ def run_single_query_cached_kv_attention(
         ]
         block_tables.append(block_table)
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
+    head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
 
     scale = float(1.0 / (head_size**0.5))
     output = torch.empty(num_tokens,
@@ -211,6 +212,7 @@ def run_single_query_cached_kv_attention(
         query,
         key_cache,
         value_cache,
+        head_mapping,
         scale,
         block_tables,
         context_lens,
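The head_mapping tensor introduced above tells the paged attention kernel which key/value head each query head reads from. A minimal sketch of how such a mapping can be built, assuming plain multi-head attention (every query head owns its own KV head, as in the torch.arange line added to the test) and, more generally, multi-query/grouped-query attention; build_head_mapping and num_kv_heads are illustrative names, not part of this commit:

import torch

def build_head_mapping(num_heads: int, num_kv_heads: int) -> torch.Tensor:
    # Plain MHA: query head i attends over KV head i, which is exactly the
    # torch.arange(num_heads, ...) mapping added in the test above.
    if num_kv_heads == num_heads:
        return torch.arange(num_heads, dtype=torch.int32, device="cuda")
    # MQA/GQA: several consecutive query heads share one KV head.
    assert num_heads % num_kv_heads == 0
    queries_per_kv = num_heads // num_kv_heads
    return torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        queries_per_kv)

# build_head_mapping(8, 2) -> tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32) on CUDA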


@@ -408,6 +408,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
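The hotfix itself is the single self.head_mapping argument above: the ALiBi variant issues its own call to the single-query cached-KV kernel, and that call had not been updated when the kernel gained a head_mapping parameter, hence the commit title. A hedged sketch of what the fixed call site plausibly looks like; the attention_ops.single_query_cached_kv_attention name and the arguments not visible in the context lines (output, block_size, max_context_len, alibi_slopes) are assumptions, not part of this diff:

# Hypothetical reconstruction of the fixed kernel invocation inside
# PagedAttentionWithALiBi; only the lines shown in the diff are certain.
attention_ops.single_query_cached_kv_attention(
    output,
    query,
    key_cache,
    value_cache,
    self.head_mapping,               # previously missing on the ALiBi path
    self.scale,
    input_metadata.block_tables,
    input_metadata.context_lens,
    block_size,
    input_metadata.max_context_len,
    self.alibi_slopes,
)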