Add unoptimized OPT Attention

Woosuk Kwon 2023-02-23 09:31:55 +00:00
parent b56b6ca0d6
commit d4bc1a4d24
2 changed files with 177 additions and 14 deletions

View File

@@ -0,0 +1,118 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()
    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        # Add a batch dimension: xformers expects [batch, seq_len, num_heads, head_size].
        out = xops.memory_efficient_attention(
            query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
            attn_bias=self.attention_mask, scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out.squeeze(0), non_blocking=True)
    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i]                                    # [num_heads, head_size]
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.view(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=-1)                # [num_heads, head_size, context_len]
            # The query is not pre-scaled, so apply the scale here.
            logits = (q.unsqueeze(1) @ keys) * self.scale   # [num_heads, 1, context_len]
            attention_weights = torch.softmax(logits, dim=-1)

            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=1)             # [num_heads, context_len, head_size]
            out = attention_weights @ values                # [num_heads, 1, head_size]
            output[i].copy_(out.squeeze(1), non_blocking=True)
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:],
                query[start_idx:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
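
The Python loop above is a reference implementation of attention over a paged KV cache: token position j of a sequence lives in physical block block_table[j // block_size] at offset j % block_size. Below is a minimal, self-contained sketch of that gather on toy tensors; the sizes and the block table are made-up illustrations, and the value-cache layout is the one implied by the indexing above, not something fixed by this commit.

import torch

# Illustrative sizes only; the real values come from the cache allocator.
num_blocks, num_heads, block_size, head_size = 4, 2, 8, 16
# Value cache laid out as [num_blocks, num_heads, block_size, head_size],
# matching value_cache[block_number, :, block_offset, :] above.
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

# Hypothetical block table for one sequence: logical block -> physical block.
block_table = [3, 0, 2]
context_len = 20  # tokens currently cached for this sequence

values = []
for j in range(context_len):
    block_number = block_table[j // block_size]  # physical block holding token j
    block_offset = j % block_size                # slot of token j inside that block
    values.append(value_cache[block_number, :, block_offset, :])
values = torch.stack(values, dim=1)              # [num_heads, context_len, head_size]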

View File

@@ -1,9 +1,17 @@
 """1D OPT model compatible with HuggingFace weights."""
+from typing import Dict, List, Optional, Tuple
+
 import torch
 from torch import nn
 from transformers import OPTConfig
 from transformers import PreTrainedModel
 
+from cacheflow.models import InputMetadata
+from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.sample import Sampler
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
 
 class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.scaling = self.head_dim**-0.5
 
+        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.attn = OPTCacheFlowAttention(scale=self.scaling)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        q = self.q_proj(hidden_states) * self.scaling
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        q = self.q_proj(hidden_states)
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
-        # TODO
-        attn_output = None
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(
+            q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output = self.out_proj(attn_output)
         return output
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
         self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
         self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
             inputs_embeds = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
 
-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states, kv_caches[i], input_metadata, cache_event)
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
-        return self.decoder(input_ids, positions)
+        return self.decoder(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
 
 
 class OPTForCausalLM(OPTPreTrainedModel):
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.model = OPTModel(config)
-        # the lm_head weight is automatically tied to the embed tokens weight
         self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+        self.sampler = Sampler(embedding=self.lm_head.weight)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
-    ) -> torch.Tensor:
-        hidden_states = self.model.decoder(input_ids, positions)
-        logits = self.lm_head(hidden_states).contiguous()
-        return logits
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, Tuple[int, int]]:
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(hidden_states, input_metadata)
+        return next_tokens
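
With these changes, the forward pass threads per-layer KV caches, scheduling metadata, and optional cache events through every decoder layer, and returns sampled tokens instead of raw logits. The snippet below is a rough sketch of how a caller might set up those inputs; the concrete sizes, the split factor used for the 5-D key cache, and the SimpleNamespace stand-in for InputMetadata are illustrative assumptions, not part of this commit.

import torch
from types import SimpleNamespace

# Hypothetical sizes; in cacheflow these come from the model config and the cache manager.
num_layers, num_heads, head_size = 12, 12, 64
num_blocks, block_size, x = 128, 8, 8  # x: assumed split of head_size for the key cache

# One (key_cache, value_cache) pair per layer. The key cache is 5-D and the value
# cache is 4-D, matching how single_query_cached_kv_attention indexes them.
kv_caches = [
    (torch.empty(num_blocks, num_heads, head_size // x, block_size, x),
     torch.empty(num_blocks, num_heads, block_size, head_size))
    for _ in range(num_layers)
]

# Stand-in for cacheflow.models.InputMetadata, carrying only the fields the
# attention code reads: a single 5-token prompt, no generation tokens yet.
input_metadata = SimpleNamespace(
    num_prompts=1,
    prompt_lens=[5],
    num_generation_tokens=0,
    context_lens=torch.tensor([], dtype=torch.int),
    block_tables=torch.tensor([], dtype=torch.int),
    slot_mapping=torch.arange(5),  # assumed: flat cache slot for each new token
)

# With a loaded model (weights and config not shown here):
# next_tokens = model(input_ids, positions, kv_caches, input_metadata, cache_events=None)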