diff --git a/README.md b/README.md
index c7da1ced..1ae1ff6a 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
 
 ```bash
 pip install cmake torch transformers
+pip install flash-attn # This may take up to 10 mins.
 pip install -e .
 ```
 
diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index 0babcc8b..34edeec0 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -1,5 +1,6 @@
 from typing import List, Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 import torch.nn as nn
 
@@ -14,20 +15,7 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = float(scale)
 
-    def _masked_attention(
-        self,
-        query: torch.Tensor,  # [num_queries, num_heads, head_size]
-        key: torch.Tensor,  # [num_keys, num_heads, head_size]
-        value: torch.Tensor,  # [num_keys, num_heads, head_size]
-        attn_mask: Optional[torch.Tensor] = None,  # [num_queries, num_keys]
-    ) -> torch.Tensor:  # [num_queries, num_heads, head_size]
-        query = query * self.scale
-        attn = torch.einsum('qhd,khd->hqk', query, key)
-        if attn_mask is not None:
-            attn = attn + attn_mask
-        attn = torch.softmax(attn, dim=-1)
-        out = torch.einsum('hqk,khd->qhd', attn, value)
-        return out
+        self.flash_attn = FlashAttention(softmax_scale=self.scale)
 
     def multi_query_kv_attention(
         self,
@@ -37,21 +25,31 @@ class OPTCacheFlowAttention(nn.Module):
         value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
         prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace the following with a custom op.
-        start_idx = 0
+        if query.dtype == torch.float:
+            raise ValueError('The float data type is not supported by '
+                             'FlashAttention. Use the half data type instead.')
+        head_size = query.shape[2]
+        if head_size > 128:
+            raise ValueError('FlashAttention does not support head_size > 128.')
+
+        device = query.device
+        prefix_sum = [0]
         for prompt_len in prompt_lens:
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
+            prefix_sum.append(prefix_sum[-1] + prompt_len)
+        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
+        max_prompt_len = max(prompt_lens)
 
-            attention_mask = torch.triu(
-                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
-            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
-            attention_out = self._masked_attention(q, k, v, attention_mask)
-            out.copy_(attention_out, non_blocking=True)
-
-            start_idx += prompt_len
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        qkv = torch.stack([query, key, value], dim=1)
+        out = self.flash_attn(
+            qkv,
+            cu_seqlens=prefix_sum,
+            max_s=max_prompt_len,
+            causal=True,
+        )[0]
+        num_tokens = prefix_sum[-1]
+        # FIXME(woosuk): Unnecessary copy. Optimize this.
+        output[:num_tokens].copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
@@ -61,6 +59,14 @@ class OPTCacheFlowAttention(nn.Module):
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
+        head_size = value_cache.shape[2]
+        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
+        if head_size not in supported_head_sizes:
+            raise ValueError(f'head_size ({head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{supported_head_sizes}.')
+
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -101,8 +107,9 @@ class OPTCacheFlowAttention(nn.Module):
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        self.multi_query_kv_attention(
-            output, query, key, value, input_metadata.prompt_lens)
+        if input_metadata.num_prompts > 0:
+            self.multi_query_kv_attention(
+                output, query, key, value, input_metadata.prompt_lens)
 
         # Wait until the cache op is done.
         if cache_event is not None:
diff --git a/server.py b/server.py
index a0e12ab1..d70dab01 100644
--- a/server.py
+++ b/server.py
@@ -9,10 +9,12 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='token block size')
+parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
 # TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
 parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
+# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
+parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 
 args = parser.parse_args()
 
@@ -27,6 +29,7 @@ def main():
         block_size=args.block_size,
         num_gpu_blocks=args.num_gpu_blocks,
         num_cpu_blocks=args.num_cpu_blocks,
+        dtype=args.dtype,
     )
     controllers.append(controller)
 
diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py
index 550e2b28..ec493940 100644
--- a/tests/kernels/attention.py
+++ b/tests/kernels/attention.py
@@ -1,10 +1,13 @@
 import random
 from typing import Optional
 
+from flash_attn.flash_attention import FlashAttention
 import torch
 
 from cacheflow import attention_ops
 
+MAX_SEQ_LEN = 4096
+
 
 def ref_masked_attention(
     query: torch.Tensor,
@@ -79,7 +82,7 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
 
-    context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
 
@@ -123,11 +126,60 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
+def test_multi_query_kv_attention(
+    num_seqs: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+) -> None:
+    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    max_seq_len = max(seq_lens)
+    num_tokens = sum(seq_lens)
+
+    cu_seq_lens = [0]
+    for seq_len in seq_lens:
+        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
+    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
+
+    scale = float(1.0 / (head_size ** 0.5))
+    query = torch.randn(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    key = torch.rand_like(query)
+    value = torch.rand_like(query)
+
+    qkv = torch.stack([query, key, value], dim=1)
+    flash_attn = FlashAttention(softmax_scale=scale)
+    output = flash_attn(
+        qkv,
+        cu_seqlens=cu_seq_lens,
+        max_s=max_seq_len,
+        causal=True,
+    )[0]
+
+    ref_outputs = []
+    for i, seq_len in enumerate(seq_lens):
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+
+
 @torch.inference_mode()
 def test_attention() -> None:
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
-            for head_size in [64, 80, 96, 128, 256]:
+            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
@@ -137,6 +189,17 @@ def test_attention() -> None:
                     dtype=dtype,
                 )
 
+    # NOTE(woosuk): FlashAttention does not support FP32.
+    for dtype in [torch.half]:
+        # NOTE(woosuk): FlashAttention does not support head_size > 128.
+        for head_size in [64, 80, 96, 128]:
+            test_multi_query_kv_attention(
+                num_seqs=11,
+                num_heads=3,
+                head_size=head_size,
+                dtype=dtype,
+            )
+
 
 if __name__ == '__main__':
     test_attention()
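Note (not part of the patch): the FlashAttention call above consumes a packed "varlen" layout, i.e. the tokens of all prompts concatenated along the first dimension, with `cu_seqlens` holding the prefix sums of the per-prompt lengths. The sketch below is a minimal pure-PyTorch reference for that computation; `ref_varlen_causal_attention` is a hypothetical helper name, and the body simply mirrors the removed `_masked_attention` loop.

```python
import torch


def ref_varlen_causal_attention(
    query: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key: torch.Tensor,          # [num_tokens, num_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_heads, head_size]
    cu_seq_lens: torch.Tensor,  # [num_seqs + 1], prefix sums of sequence lengths
    scale: float,
) -> torch.Tensor:              # [num_tokens, num_heads, head_size]
    """Unfused reference for packed causal attention over variable-length sequences."""
    outputs = []
    for i in range(cu_seq_lens.numel() - 1):
        start = int(cu_seq_lens[i])
        end = int(cu_seq_lens[i + 1])
        q = query[start:end] * scale
        k = key[start:end]
        v = value[start:end]
        # Attention scores have shape [num_heads, seq_len, seq_len].
        attn = torch.einsum('qhd,khd->hqk', q, k)
        # Mask out future positions (strict upper triangle) for causal attention.
        seq_len = end - start
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device),
            diagonal=1)
        attn = attn.masked_fill(causal_mask, float('-inf'))
        attn = torch.softmax(attn.float(), dim=-1).to(v.dtype)
        outputs.append(torch.einsum('hqk,khd->qhd', attn, v))
    return torch.cat(outputs, dim=0)
```

With half-precision inputs, this reference should agree with the FlashAttention output to roughly the `atol=1e-3` tolerance used in the test above.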