Use FlashAttention for multi_query_kv_attention (#4)

This commit is contained in:
Woosuk Kwon 2023-03-01 21:13:08 -08:00 committed by GitHub
parent 0deacbce6e
commit 3e9f991d6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 107 additions and 33 deletions

View File

@ -4,6 +4,7 @@
```bash
pip install cmake torch transformers
pip install flash-attn # This may take up to 10 mins.
pip install -e .
```

View File

@ -1,5 +1,6 @@
from typing import List, Optional
from flash_attn.flash_attention import FlashAttention
import torch
import torch.nn as nn
@ -14,20 +15,7 @@ class OPTCacheFlowAttention(nn.Module):
super().__init__()
self.scale = float(scale)
def _masked_attention(
self,
query: torch.Tensor, # [num_queries, num_heads, head_size]
key: torch.Tensor, # [num_keys, num_heads, head_size]
value: torch.Tensor, # [num_keys, num_heads, head_size]
attn_mask: Optional[torch.Tensor] = None, # [num_queries, num_keys]
) -> torch.Tensor: # [num_queries, num_heads, head_size]
query = query * self.scale
attn = torch.einsum('qhd,khd->hqk', query, key)
if attn_mask is not None:
attn = attn + attn_mask
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, value)
return out
self.flash_attn = FlashAttention(softmax_scale=self.scale)
def multi_query_kv_attention(
self,
@ -37,21 +25,31 @@ class OPTCacheFlowAttention(nn.Module):
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
prompt_lens: List[int],
) -> None:
# FIXME(woosuk): Replace the following with a custom op.
start_idx = 0
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[2]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')
device = query.device
prefix_sum = [0]
for prompt_len in prompt_lens:
out = output[start_idx:start_idx + prompt_len]
q = query[start_idx:start_idx + prompt_len]
k = key[start_idx:start_idx + prompt_len]
v = value[start_idx:start_idx + prompt_len]
prefix_sum.append(prefix_sum[-1] + prompt_len)
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
max_prompt_len = max(prompt_lens)
attention_mask = torch.triu(
torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
attention_out = self._masked_attention(q, k, v, attention_mask)
out.copy_(attention_out, non_blocking=True)
start_idx += prompt_len
# FIXME(woosuk): Unnecessary copy. Optimize this.
qkv = torch.stack([query, key, value], dim=1)
out = self.flash_attn(
qkv,
cu_seqlens=prefix_sum,
max_s=max_prompt_len,
causal=True,
)[0]
num_tokens = prefix_sum[-1]
# FIXME(woosuk): Unnecessary copy. Optimize this.
output[:num_tokens].copy_(out, non_blocking=True)
def single_query_cached_kv_attention(
self,
@ -61,6 +59,14 @@ class OPTCacheFlowAttention(nn.Module):
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
head_size = value_cache.shape[2]
supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
if head_size not in supported_head_sizes:
raise ValueError(f'head_size ({head_size}) is not supported by '
'the single_query_cached_kv_attention kernel. '
'Use one of the following head sizes: '
f'{supported_head_sizes}.')
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
@ -101,8 +107,9 @@ class OPTCacheFlowAttention(nn.Module):
output = output.view(-1, num_heads, head_size)
# Compute the attention op for prompts.
self.multi_query_kv_attention(
output, query, key, value, input_metadata.prompt_lens)
if input_metadata.num_prompts > 0:
self.multi_query_kv_attention(
output, query, key, value, input_metadata.prompt_lens)
# Wait until the cache op is done.
if cache_event is not None:

View File

@ -9,10 +9,12 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
parser.add_argument('--block-size', type=int, default=8, help='token block size')
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
args = parser.parse_args()
@ -27,6 +29,7 @@ def main():
block_size=args.block_size,
num_gpu_blocks=args.num_gpu_blocks,
num_cpu_blocks=args.num_cpu_blocks,
dtype=args.dtype,
)
controllers.append(controller)

View File

@ -1,10 +1,13 @@
import random
from typing import Optional
from flash_attn.flash_attention import FlashAttention
import torch
from cacheflow import attention_ops
MAX_SEQ_LEN = 4096
def ref_masked_attention(
query: torch.Tensor,
@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
value_cache = torch.randn(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> None:
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
max_seq_len = max(seq_lens)
num_tokens = sum(seq_lens)
cu_seq_lens = [0]
for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
query = torch.randn(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
key = torch.rand_like(query)
value = torch.rand_like(query)
qkv = torch.stack([query, key, value], dim=1)
flash_attn = FlashAttention(softmax_scale=scale)
output = flash_attn(
qkv,
cu_seqlens=cu_seq_lens,
max_s=max_seq_len,
causal=True,
)[0]
ref_outputs = []
for i, seq_len in enumerate(seq_lens):
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
ref_output = ref_masked_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention() -> None:
for dtype in [torch.half, torch.float]:
for block_size in [8, 16]:
for head_size in [64, 80, 96, 128, 256]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
test_single_query_cached_kv_attention(
num_tokens=37,
num_heads=3,
@ -137,6 +189,17 @@ def test_attention() -> None:
dtype=dtype,
)
# NOTE(woosuk): FlashAttention does not support FP32.
for dtype in [torch.half]:
# NOTE(woosuk): FlashAttention does not support head_size > 128.
for head_size in [64, 80, 96, 128]:
test_multi_query_kv_attention(
num_seqs=11,
num_heads=3,
head_size=head_size,
dtype=dtype,
)
if __name__ == '__main__':
test_attention()