diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
new file mode 100644
index 00000000..c36f06c7
--- /dev/null
+++ b/cacheflow/models/attention.py
@@ -0,0 +1,118 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import xformers.ops as xops
+
+from cacheflow import ops
+from cacheflow.models import InputMetadata
+
+
+class OPTCacheFlowAttention(nn.Module):
+
+    def __init__(self, scale: float) -> None:
+        super().__init__()
+        self.scale = scale
+
+        # Shape-agnostic attention mask.
+        self.attention_mask = xops.LowerTriangularMask()
+
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> None:
+        # xformers expects a batch dimension: [1, seq_len, num_heads, head_size].
+        out = xops.memory_efficient_attention(
+            query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
+            attn_bias=self.attention_mask, scale=self.scale)
+        # FIXME(woosuk): Directly write the attention output.
+        output.copy_(out.squeeze(0), non_blocking=True)
+
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        num_heads = value_cache.shape[1]
+        head_size = value_cache.shape[3]
+        block_size = value_cache.shape[2]
+        block_tables = input_metadata.block_tables
+
+        # FIXME(woosuk): Replace the following with a custom op.
+        for i in range(input_metadata.num_generation_tokens):
+            q = query[i]
+            block_table = block_tables[i]
+            context_len = int(input_metadata.context_lens[i])
+            # Gather the keys for the whole context from the paged cache.
+            keys = []
+            for j in range(context_len):
+                block_number = block_table[j // block_size]
+                block_offset = j % block_size
+                k = key_cache[block_number, :, :, block_offset, :]
+                k = k.reshape(num_heads, head_size)
+                keys.append(k)
+            keys = torch.stack(keys, dim=-1)
+            logits = torch.einsum('hd,hdc->hc', q, keys) * self.scale
+            attention_weights = torch.softmax(logits, dim=-1)
+
+            # Gather the values and compute the weighted sum over the context.
+            values = []
+            for j in range(context_len):
+                block_number = block_table[j // block_size]
+                block_offset = j % block_size
+                v = value_cache[block_number, :, block_offset, :]
+                values.append(v)
+            values = torch.stack(values, dim=-1)
+            out = torch.einsum('hc,hdc->hd', attention_weights, values)
+            output[i].copy_(out, non_blocking=True)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        # Reshape the input tensors.
+        num_heads = value_cache.shape[1]
+        head_size = value_cache.shape[3]
+        query = query.view(-1, num_heads, head_size)
+        key = key.view(-1, num_heads, head_size)
+        value = value.view(-1, num_heads, head_size)
+
+        # Compute the attention op for prompts.
+        output = torch.empty_like(query)
+        start_idx = 0
+        for i in range(input_metadata.num_prompts):
+            prompt_len = input_metadata.prompt_lens[i]
+            out = output[start_idx:start_idx + prompt_len]
+            q = query[start_idx:start_idx + prompt_len]
+            k = key[start_idx:start_idx + prompt_len]
+            v = value[start_idx:start_idx + prompt_len]
+            self.multi_query_kv_attention(out, q, k, v)
+            start_idx += prompt_len
+
+        # Wait until the cache op is done.
+        if cache_event is not None:
+            cache_event.wait()
+
+        # Reshape the keys and values and store them in the cache.
+        ops.reshape_and_cache(
+            key, value, key_cache, value_cache, input_metadata.slot_mapping)
+
+        if input_metadata.num_generation_tokens > 0:
+            # Compute the attention op for generation tokens.
+            self.single_query_cached_kv_attention(
+                output[start_idx:],
+                query[start_idx:],
+                key_cache,
+                value_cache,
+                input_metadata)
+
+        # Reshape the output tensor.
+        return output.view(-1, num_heads * head_size)
diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py
index 3a340317..234ab263 100644
--- a/cacheflow/models/opt.py
+++ b/cacheflow/models/opt.py
@@ -1,9 +1,17 @@
 """1D OPT model compatible with HuggingFace weights."""
+from typing import Dict, List, Optional, Tuple
+
 import torch
 from torch import nn
 from transformers import OPTConfig
 from transformers import PreTrainedModel
 
+from cacheflow.models import InputMetadata
+from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.sample import Sampler
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
 
 class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.scaling = self.head_dim**-0.5
 
+        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        q = self.q_proj(hidden_states) * self.scaling
+        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        q = self.q_proj(hidden_states)
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
-        # TODO
-        attn_output = None
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(
+            q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output = self.out_proj(attn_output)
         return output
 
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
         self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
         self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
             inputs_embeds = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
 
-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states, kv_caches[i], input_metadata, cache_event)
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
-        return self.decoder(input_ids, positions)
+        return self.decoder(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
 
 
 class OPTForCausalLM(OPTPreTrainedModel):
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.model = OPTModel(config)
-        # the lm_head weight is automatically tied to the embed tokens weight
         self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size,
                                  bias=False)
+        self.sampler = Sampler(embedding=self.lm_head.weight)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
-    ) -> torch.Tensor:
-        hidden_states = self.model.decoder(input_ids, positions)
-        logits = self.lm_head(hidden_states).contiguous()
-        return logits
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, Tuple[int, int]]:
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(hidden_states, input_metadata)
+        return next_tokens
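
Note: the per-token loop in single_query_cached_kv_attention boils down to the standalone sketch below. It is not part of the diff; the cache shapes are assumptions inferred from the indexing above (the real layouts are fixed by the cacheflow custom ops, which are not shown), and all sizes are made up for illustration.

    import torch

    # Assumed cache layouts, inferred from the diff's indexing:
    #   key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
    #   value_cache: [num_blocks, num_heads, block_size, head_size]
    num_blocks, num_heads, head_size, block_size, x = 4, 2, 16, 8, 8
    key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
    value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

    # One generation token: its query vector, block table, and context length.
    q = torch.randn(num_heads, head_size)
    block_table = torch.tensor([2, 0, 3, 1])  # logical block -> physical block
    context_len = 19                          # tokens already in the context
    scale = head_size ** -0.5

    keys, values = [], []
    for j in range(context_len):
        block_number = int(block_table[j // block_size])  # which physical block
        block_offset = j % block_size                     # slot within the block
        k = key_cache[block_number, :, :, block_offset, :].reshape(num_heads, head_size)
        v = value_cache[block_number, :, block_offset, :]
        keys.append(k)
        values.append(v)
    keys = torch.stack(keys, dim=-1)      # [num_heads, head_size, context_len]
    values = torch.stack(values, dim=-1)  # [num_heads, head_size, context_len]

    # Scaled dot-product attention for the single query token.
    logits = torch.einsum('hd,hdc->hc', q, keys) * scale
    weights = torch.softmax(logits, dim=-1)
    out = torch.einsum('hc,hdc->hd', weights, values)  # [num_heads, head_size]
    print(out.shape)

As the FIXME in the diff notes, this Python-level gather is a placeholder that is intended to be replaced by a custom op.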