Add unoptimized OPT Attention
This commit is contained in:
parent
b56b6ca0d6
commit
d4bc1a4d24
118
cacheflow/models/attention.py
Normal file
118
cacheflow/models/attention.py
Normal file
@ -0,0 +1,118 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import xformers.ops as xops
|
||||
|
||||
from cacheflow import ops
|
||||
from cacheflow.models import InputMetadata
|
||||
|
||||
|
||||
class OPTCacheFlowAttention(nn.Module):
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
# Shape-agnostic attention mask.
|
||||
self.attention_mask = xops.LowerTriangularMask()
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
) -> None:
|
||||
out = xops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=self.attention_mask, scale=self.scale)
|
||||
# FIXME(woosuk): Directly write the attention output.
|
||||
output.copy_(out, non_blocking=True)
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[3]
|
||||
block_size = value_cache.shape[2]
|
||||
block_tables = input_metadata.block_tables
|
||||
|
||||
# FIXME(woosuk): Replace the following with a custom op.
|
||||
for i in range(input_metadata.num_generation_tokens):
|
||||
q = query[i]
|
||||
block_table = block_tables[i]
|
||||
context_len = int(input_metadata.context_lens[i])
|
||||
keys = []
|
||||
for j in range(context_len):
|
||||
block_number = block_table[j // block_size]
|
||||
block_offset = j % block_size
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.view(num_heads, head_size)
|
||||
keys.append(k)
|
||||
keys = torch.stack(keys, dim=-1)
|
||||
logits = q @ keys
|
||||
attention_weights = torch.softmax(logits, dim=-1)
|
||||
|
||||
values = []
|
||||
for j in range(context_len):
|
||||
block_number = block_table[j // block_size]
|
||||
block_offset = j % block_size
|
||||
v = value_cache[block_number, :, block_offset, :]
|
||||
values.append(v)
|
||||
values = torch.stack(values, dim=-1)
|
||||
out = attention_weights @ values
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Reshape the input tensors.
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[3]
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_heads, head_size)
|
||||
value = value.view(-1, num_heads, head_size)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
output = torch.empty_like(query)
|
||||
start_idx = 0
|
||||
for i in range(input_metadata.num_prompts):
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
out = output[start_idx:start_idx + prompt_len]
|
||||
q = query[start_idx:start_idx + prompt_len]
|
||||
k = key[start_idx:start_idx + prompt_len]
|
||||
v = value[start_idx:start_idx + prompt_len]
|
||||
self.multi_query_kv_attention(out, q, k, v)
|
||||
start_idx += prompt_len
|
||||
|
||||
# Wait until the cache op is done.
|
||||
if cache_event is not None:
|
||||
cache_event.wait()
|
||||
|
||||
# Reshape the keys and values and store them in the cache.
|
||||
ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, input_metadata.slot_mapping)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Compute the attention op for generation tokens.
|
||||
self.single_query_cached_kv_attention(
|
||||
output[start_idx:],
|
||||
query[start_idx:],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, num_heads * head_size)
|
||||
@ -1,9 +1,17 @@
|
||||
"""1D OPT model compatible with HuggingFace weights."""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import OPTConfig
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.models.attention import OPTCacheFlowAttention
|
||||
from cacheflow.models.sample import Sampler
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
|
||||
@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
q = self.q_proj(hidden_states) * self.scaling
|
||||
self.attn = OPTCacheFlowAttention(scale=self.scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
# TODO
|
||||
attn_output = None
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
output = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
|
||||
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event)
|
||||
hidden_states = residual + hidden_states
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
pos_embeds = self.embed_positions(positions)
|
||||
@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
inputs_embeds = self.project_in(inputs_embeds)
|
||||
hidden_states = inputs_embeds + pos_embeds
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states)
|
||||
for i in range(len(self.layers)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, kv_caches[i], input_metadata, cache_event)
|
||||
|
||||
if self.final_layer_norm is not None:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
return self.decoder(input_ids, positions)
|
||||
return self.decoder(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
|
||||
|
||||
class OPTForCausalLM(OPTPreTrainedModel):
|
||||
@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = OPTModel(config)
|
||||
|
||||
# the lm_head weight is automatically tied to the embed tokens weight
|
||||
self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
|
||||
self.sampler = Sampler(embedding=self.lm_head.weight)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model.decoder(input_ids, positions)
|
||||
logits = self.lm_head(hidden_states).contiguous()
|
||||
return logits
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, Tuple[int, int]]:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(hidden_states, input_metadata)
|
||||
return next_tokens
|
||||
|
||||
Loading…
Reference in New Issue
Block a user