Add unoptimized OPT Attention

Woosuk Kwon 2023-02-23 09:31:55 +00:00
parent b56b6ca0d6
commit d4bc1a4d24
2 changed files with 177 additions and 14 deletions


@@ -0,0 +1,118 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()
    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        # query/key/value: [num_prompt_tokens, num_heads, head_size]. xformers
        # expects [batch, seq_len, num_heads, head_size], so add a batch
        # dimension of 1 and strip it from the result.
        out = xops.memory_efficient_attention(
            query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
            attn_bias=self.attention_mask, scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out.squeeze(0), non_blocking=True)
    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i]                                  # [num_heads, head_size]
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            # Gather the cached keys for this sequence, one token at a time.
            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.view(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=-1)              # [num_heads, head_size, context_len]

            # Per-head scaled dot-product attention over the cached context.
            logits = torch.einsum('hd,hdc->hc', q, keys) * self.scale
            attention_weights = torch.softmax(logits, dim=-1)

            # Gather the cached values and apply the attention weights.
            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=-1)          # [num_heads, head_size, context_len]
            out = torch.einsum('hc,hdc->hd', attention_weights, values)
            output[i].copy_(out, non_blocking=True)
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
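For readers new to the paged KV-cache layout this module assumes, here is a small standalone sketch (plain PyTorch, illustrative sizes; the cache shapes are inferred from the indexing in single_query_cached_kv_attention and are not stated explicitly in the commit) of how one generation token's keys are gathered block by block via a block table:

import torch

# Illustrative sizes, not taken from the commit.
num_blocks, num_heads, head_size, block_size, x = 8, 12, 64, 4, 8

# Inferred layouts:
#   key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
#   value_cache: [num_blocks, num_heads, block_size, head_size]
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)

# A block table maps a sequence's logical token positions to physical cache blocks.
block_table = [5, 2, 7]   # three physical blocks hold this sequence's cache
context_len = 10          # 10 cached tokens spread over those blocks

keys = []
for j in range(context_len):
    block_number = block_table[j // block_size]   # which physical block
    block_offset = j % block_size                 # slot within that block
    k = key_cache[block_number, :, :, block_offset, :].reshape(num_heads, head_size)
    keys.append(k)
keys = torch.stack(keys, dim=-1)   # [num_heads, head_size, context_len]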


@@ -1,9 +1,17 @@
"""1D OPT model compatible with HuggingFace weights."""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import OPTConfig
from transformers import PreTrainedModel

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler

KVCache = Tuple[torch.Tensor, torch.Tensor]


class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim**-0.5

        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q = self.q_proj(hidden_states) * self.scaling
        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # TODO
        attn_output = None
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output = self.out_proj(attn_output)
        return output
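The TODO in the hunk above notes that the three projections could be fused into one QKV linear layer. As a rough illustration only (a hypothetical sketch, not code from this commit or from cacheflow), a single linear layer can produce q, k, and v with one GEMM:

import torch
import torch.nn as nn

class FusedQKV(nn.Module):
    """Hypothetical fused QKV projection sketch; not part of this commit."""

    def __init__(self, embed_dim: int, bias: bool = True) -> None:
        super().__init__()
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor):
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)   # split the single output into q, k, v
        return q, k, v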
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for layer in self.layers:
            hidden_states = layer(hidden_states)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions)
        return self.decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)
class OPTForCausalLM(OPTPreTrainedModel):
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)
        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
        self.sampler = Sampler(embedding=self.lm_head.weight)

        # Initialize weights and apply final processing
        self.post_init()
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
    ) -> torch.Tensor:
        hidden_states = self.model.decoder(input_ids, positions)
        logits = self.lm_head(hidden_states).contiguous()
        return logits
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, Tuple[int, int]]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(hidden_states, input_metadata)
        return next_tokens
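To make the new plumbing concrete, here is a hedged sketch (dummy shapes and a helper name of my own; the real cache allocation lives elsewhere in cacheflow and is not part of this diff) of the per-layer objects the new forward signatures expect:

from typing import List, Optional, Tuple
import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]

def make_dummy_kv_caches(num_layers: int, num_blocks: int, num_heads: int,
                         head_size: int, block_size: int, x: int) -> List[KVCache]:
    """Dummy per-layer caches matching the layout the attention module indexes."""
    caches: List[KVCache] = []
    for _ in range(num_layers):
        key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x)
        value_cache = torch.zeros(num_blocks, num_heads, block_size, head_size)
        caches.append((key_cache, value_cache))
    return caches

# One (key_cache, value_cache) pair per decoder layer; OPTDecoder.forward pairs
# kv_caches[i] and cache_events[i] with self.layers[i]. cache_events may be None,
# in which case no layer waits on a cache event.
kv_caches = make_dummy_kv_caches(num_layers=12, num_blocks=128, num_heads=12,
                                 head_size=64, block_size=8, x=8)
cache_events: Optional[List[torch.cuda.Event]] = None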