# coding=utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig

from flash_attn_2_cuda import varlen_fwd
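# Note: `varlen_fwd` is the raw entry point of the compiled flash-attn 2 extension
# (`flash_attn_2_cuda`). Its positional argument list has changed across flash-attn
# releases, so the call in `LLAMAAttention.forward` below is written against the
# signature quoted in the comment inside that method. The documented, more
# version-stable alternative is the Python wrapper `flash_attn.flash_attn_varlen_func`.

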
class MLP(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up = nn.Linear(self.hidden_size, self.intermediate_size)
        self.gate = nn.Linear(self.hidden_size, self.intermediate_size)
        self.down = nn.Linear(self.intermediate_size, self.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # SwiGLU feed-forward as in LLaMA: SiLU on the gate path, elementwise product
        # with the up projection, then the down projection.
        hidden_states = self.down(F.silu(self.gate(hidden_states)) * self.up(hidden_states))
        return hidden_states


class LLAMAAttention(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_head = config.num_attention_heads
        self.kv_head = config.num_key_value_heads
        self.head_dim = config.head_dim
        assert (
            self.num_head * self.head_dim == config.hidden_size
        ), "num_head * head_dim must equal hidden_size"
        # The query projection keeps the full width; the key/value projections are sized
        # for the (possibly smaller) number of KV heads so grouped-query attention works.
        self.q = nn.Linear(config.hidden_size, self.num_head * self.head_dim)
        self.k = nn.Linear(config.hidden_size, self.kv_head * self.head_dim)
        self.v = nn.Linear(config.hidden_size, self.kv_head * self.head_dim)
        self.o = nn.Linear(self.num_head * self.head_dim, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor):
        # hidden_states is the unpadded ("packed") tensor of shape (total_tokens, hidden_size);
        # position_ids carries one entry per token and restarts from 0 at every sequence start.
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)
        q = q.view(-1, self.num_head, self.head_dim)
        k = k.view(-1, self.kv_head, self.head_dim)
        v = v.view(-1, self.kv_head, self.head_dim)

        # Process position_ids: because the packed position ids reset to 0 at the start of
        # each sequence, the cumulative sequence lengths needed by varlen_fwd can be
        # recovered directly from them. (Rotary position embeddings, which a full LLaMA
        # attention layer would also apply to q and k here, are sketched separately below.)
        seq_starts = torch.nonzero(position_ids.flatten() == 0, as_tuple=False).flatten()
        cu_seqlens = F.pad(seq_starts.to(torch.int32), (0, 1), value=position_ids.numel())
        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

        # varlen_fwd is the core of the variable-length (unpadded) attention path. The
        # positional arguments follow the mha_varlen_fwd signature quoted below; the kernel
        # returns a tuple of tensors whose first element is the attention output (the exact
        # tuple layout varies between flash-attn versions).
        out = varlen_fwd(
            q,
            k,
            v,
            None,  # out_
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k (self-attention: same lengths as q)
            None,  # seqused_k
            None,  # block_table_
            None,  # alibi_slopes_
            max_seqlen,  # max_seqlen_q
            max_seqlen,  # max_seqlen_k
            0.0,  # p_dropout
            self.head_dim**-0.5,  # softmax_scale
            False,  # zero_tensors
            True,  # is_causal
            -1,  # window_size_left
            -1,  # window_size_right
            False,  # return_softmax
            None,  # gen_
        )[0]
        hidden_states = self.o(out.view(-1, self.num_head * self.head_dim))

        # mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        # const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        # const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        # c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        # const at::Tensor &cu_seqlens_q, // b+1
        # const at::Tensor &cu_seqlens_k, // b+1
        # c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
        # c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
        # c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
        # int max_seqlen_q,
        # const int max_seqlen_k,
        # const float p_dropout,
        # const float softmax_scale,
        # const bool zero_tensors,
        # bool is_causal,
        # int window_size_left,
        # int window_size_right,
        # const bool return_softmax,
        # c10::optional<at::Generator> gen_) {
        return hidden_states


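# The original file leaves "# process positionids" as a placeholder and never applies
# rotary position embeddings. The helper below is a minimal module-level sketch of the
# standard rotate-half RoPE used by LLaMA, written for the packed layout above
# (q, k of shape (total_tokens, num_heads, head_dim), position_ids of shape (total_tokens,)).
# It is not part of the original code; the name `apply_rope` and the default base of
# 10000.0 are assumptions.
def apply_rope(q: torch.Tensor, k: torch.Tensor, position_ids: torch.Tensor, base: float = 10000.0):
    head_dim = q.shape[-1]
    # Inverse frequencies for each pair of channels.
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, device=q.device, dtype=torch.float32) / head_dim)
    )
    freqs = position_ids.flatten().float()[:, None] * inv_freq[None, :]  # (total_tokens, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (total_tokens, head_dim)
    cos = emb.cos()[:, None, :]  # broadcast over the head dimension
    sin = emb.sin()[:, None, :]

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    q_rot = (q.float() * cos + rotate_half(q.float()) * sin).to(q.dtype)
    k_rot = (k.float() * cos + rotate_half(k.float()) * sin).to(k.dtype)
    return q_rot, k_rot

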
class LLAMADecodeLayer(nn.Module):
    def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.idx = idx
        self.attn = LLAMAAttention(config)
        self.mlp = MLP(config)

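    # The original file leaves this layer without a forward pass. The sketch below assumes
    # the standard residual structure of a LLaMA block; the RMSNorm that the reference
    # model applies before the attention and MLP sub-layers is omitted to stay close to
    # the modules defined above.
    def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor):
        hidden_states = hidden_states + self.attn(hidden_states, position_ids)
        hidden_states = hidden_states + self.mlp(hidden_states)
        return hidden_states

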
class LLAMAModel(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_layer = config.num_hidden_layers

        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)]
        )
        # Note: the final RMSNorm that the reference LLaMA model applies before the
        # LM head is omitted in this sketch.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, hidden_states, position_ids):
        # Either token ids or precomputed embeddings can be supplied; input_ids wins.
        if input_ids is not None:
            hidden_states = self.token_embed(input_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        output = self.lm_head(hidden_states)
        return output

    def unpad_input(self, hidden_states, attention_mask):
        # Flatten (batch, seqlen, ...) into a packed (total_tokens, ...) layout and drop
        # padding positions, producing the metadata that variable-length attention expects.
        hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
        # Works for a (b, s) padding mask as well as a broadcastable (b, 1, 1, s) one;
        # valid positions are the ones marked with 1.
        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        # Cumulative sequence lengths of shape (b + 1,), starting at 0.
        cu_seqlens = F.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )
        hidden_states = hidden_states[indices].unsqueeze(0)
        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

    def pad_input(self, hidden_states, indices, batch, seqlen):
        """
        Scatter a packed tensor back into a padded (batch, seqlen, ...) layout.

        :param hidden_states: packed tensor of shape [total_tokens, H], not [B, L, H]
        :param indices: the flat indices returned by unpad_input
        :param batch: the original batch size
        :param seqlen: the padded sequence length the attention mask was built with
            (indices were computed against the full (batch, seqlen) layout, so this must
            match the original padded length rather than max_seqlen_in_batch)
        :return: padded tensor of shape [batch, seqlen, ...]
        """
        output = torch.zeros(
            batch * seqlen,
            *hidden_states.shape[1:],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return rearrange(output, "(b s) ... -> b s ...", b=batch)
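

# The block below is a hedged, minimal usage sketch, not part of the original file.
# It assumes a CUDA device with flash-attn 2 installed and a half-precision dtype
# (required by the kernel); the config values, `cfg`, `packed_ids`, and the other
# local names are illustrative only.
if __name__ == "__main__":
    cfg = LlamaConfig(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=8,
        num_key_value_heads=8,
        head_dim=32,
        vocab_size=1000,
    )
    model = LLAMAModel(cfg).cuda().half()

    input_ids = torch.randint(0, cfg.vocab_size, (2, 16), device="cuda")
    attention_mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
    attention_mask[1, 10:] = 0  # the second sequence is only 10 tokens long

    # Per-token positions that restart at 0 inside every sequence.
    position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)

    # Pack the batch: drop padding tokens and keep the metadata needed to restore it.
    # cu_seqlens / max_seqlen are returned for completeness; this model rebuilds them
    # inside the attention layer from the packed position_ids.
    packed_ids, indices, cu_seqlens, max_seqlen = model.unpad_input(
        input_ids.unsqueeze(-1), attention_mask
    )
    packed_ids = packed_ids.squeeze(0).squeeze(-1)
    packed_pos = position_ids.flatten()[indices]

    logits = model(packed_ids, None, packed_pos)
    padded_logits = model.pad_input(logits, indices, batch=2, seqlen=16)
    print(padded_logits.shape)  # (2, 16, vocab_size)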