Use FlashAttention for multi_query_kv_attention (#4)
This commit is contained in:
parent
0deacbce6e
commit
3e9f991d6a
@ -4,6 +4,7 @@
|
||||
|
||||
```bash
|
||||
pip install cmake torch transformers
|
||||
pip install flash-attn # This may take up to 10 mins.
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -14,20 +15,7 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
|
||||
def _masked_attention(
|
||||
self,
|
||||
query: torch.Tensor, # [num_queries, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_keys, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_keys, num_heads, head_size]
|
||||
attn_mask: Optional[torch.Tensor] = None, # [num_queries, num_keys]
|
||||
) -> torch.Tensor: # [num_queries, num_heads, head_size]
|
||||
query = query * self.scale
|
||||
attn = torch.einsum('qhd,khd->hqk', query, key)
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
out = torch.einsum('hqk,khd->qhd', attn, value)
|
||||
return out
|
||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
@ -37,21 +25,31 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
prompt_lens: List[int],
|
||||
) -> None:
|
||||
# FIXME(woosuk): Replace the following with a custom op.
|
||||
start_idx = 0
|
||||
if query.dtype == torch.float:
|
||||
raise ValueError('The float data type is not supported by '
|
||||
'FlashAttention. Use the half data type instead.')
|
||||
head_size = query.shape[2]
|
||||
if head_size > 128:
|
||||
raise ValueError('FlashAttention does not support head_size > 128.')
|
||||
|
||||
device = query.device
|
||||
prefix_sum = [0]
|
||||
for prompt_len in prompt_lens:
|
||||
out = output[start_idx:start_idx + prompt_len]
|
||||
q = query[start_idx:start_idx + prompt_len]
|
||||
k = key[start_idx:start_idx + prompt_len]
|
||||
v = value[start_idx:start_idx + prompt_len]
|
||||
prefix_sum.append(prefix_sum[-1] + prompt_len)
|
||||
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
|
||||
max_prompt_len = max(prompt_lens)
|
||||
|
||||
attention_mask = torch.triu(
|
||||
torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
|
||||
attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
|
||||
attention_out = self._masked_attention(q, k, v, attention_mask)
|
||||
out.copy_(attention_out, non_blocking=True)
|
||||
|
||||
start_idx += prompt_len
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
out = self.flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=prefix_sum,
|
||||
max_s=max_prompt_len,
|
||||
causal=True,
|
||||
)[0]
|
||||
num_tokens = prefix_sum[-1]
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
output[:num_tokens].copy_(out, non_blocking=True)
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
@ -61,6 +59,14 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
head_size = value_cache.shape[2]
|
||||
supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(f'head_size ({head_size}) is not supported by '
|
||||
'the single_query_cached_kv_attention kernel. '
|
||||
'Use one of the following head sizes: '
|
||||
f'{supported_head_sizes}.')
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
@ -101,8 +107,9 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
output = output.view(-1, num_heads, head_size)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
self.multi_query_kv_attention(
|
||||
output, query, key, value, input_metadata.prompt_lens)
|
||||
if input_metadata.num_prompts > 0:
|
||||
self.multi_query_kv_attention(
|
||||
output, query, key, value, input_metadata.prompt_lens)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
if cache_event is not None:
|
||||
|
||||
@ -9,10 +9,12 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
|
||||
parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
|
||||
parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
|
||||
parser.add_argument('--block-size', type=int, default=8, help='token block size')
|
||||
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
|
||||
# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
|
||||
parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
|
||||
parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
|
||||
parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
|
||||
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
|
||||
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@ -27,6 +29,7 @@ def main():
|
||||
block_size=args.block_size,
|
||||
num_gpu_blocks=args.num_gpu_blocks,
|
||||
num_cpu_blocks=args.num_cpu_blocks,
|
||||
dtype=args.dtype,
|
||||
)
|
||||
controllers.append(controller)
|
||||
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
import torch
|
||||
|
||||
from cacheflow import attention_ops
|
||||
|
||||
MAX_SEQ_LEN = 4096
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
query: torch.Tensor,
|
||||
@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
|
||||
value_cache = torch.randn(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
|
||||
context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
max_seq_len = max(seq_lens)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
cu_seq_lens = [0]
|
||||
for seq_len in seq_lens:
|
||||
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
query = torch.randn(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
key = torch.rand_like(query)
|
||||
value = torch.rand_like(query)
|
||||
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
flash_attn = FlashAttention(softmax_scale=scale)
|
||||
output = flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=cu_seq_lens,
|
||||
max_s=max_seq_len,
|
||||
causal=True,
|
||||
)[0]
|
||||
|
||||
ref_outputs = []
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_attention() -> None:
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for block_size in [8, 16]:
|
||||
for head_size in [64, 80, 96, 128, 256]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
test_single_query_cached_kv_attention(
|
||||
num_tokens=37,
|
||||
num_heads=3,
|
||||
@ -137,6 +189,17 @@ def test_attention() -> None:
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): FlashAttention does not support FP32.
|
||||
for dtype in [torch.half]:
|
||||
# NOTE(woosuk): FlashAttention does not support head_size > 128.
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
test_multi_query_kv_attention(
|
||||
num_seqs=11,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_attention()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user