# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig

from flash_attn_2_cuda import varlen_fwd


class MLP(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up = nn.Linear(self.hidden_size, self.intermediate_size)
        self.gate = nn.Linear(self.hidden_size, self.intermediate_size)
        self.down = nn.Linear(self.intermediate_size, self.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # LLaMA-style gated MLP: down(silu(gate(x)) * up(x)).
        hidden_states = self.down(F.silu(self.gate(hidden_states)) * self.up(hidden_states))
        return hidden_states


class LLAMAAttention(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_head = config.num_attention_heads
        self.kv_head = config.num_key_value_heads
        self.head_dim = config.head_dim
        assert (
            self.num_head * self.head_dim == config.hidden_size
        ), "num_head * head_dim must equal hidden_size"

        self.q = nn.Linear(config.hidden_size, self.num_head * self.head_dim)
        # k/v project to kv_head * head_dim so grouped-query attention
        # (kv_head < num_head) also works.
        self.k = nn.Linear(config.hidden_size, self.kv_head * self.head_dim)
        self.v = nn.Linear(config.hidden_size, self.kv_head * self.head_dim)
        self.o = nn.Linear(self.num_head * self.head_dim, config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        seqused_k: torch.Tensor = None,
    ):
        # cu_seqlens_* / max_seqlen_* come from unpad_input and describe the
        # packed (unpadded) batch layout that varlen_fwd expects.
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)

        # varlen layout: [total_tokens, num_heads, head_dim]
        q = q.view(-1, self.num_head, self.head_dim)
        k = k.view(-1, self.kv_head, self.head_dim)
        v = v.view(-1, self.kv_head, self.head_dim)

        # TODO: apply rotary position embeddings to q and k using position_ids.

        # varlen_fwd is the core of the forward pass: it runs FlashAttention-2
        # over a packed batch of variable-length sequences without padding.
        # Reference, the underlying C++ entry point in flash-attn:
        # mha_varlen_fwd(at::Tensor &q,                     // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        #     const at::Tensor &k,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        #     const at::Tensor &v,                          // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        #     c10::optional<at::Tensor> &out_,              // total_q x num_heads x head_size
        #     const at::Tensor &cu_seqlens_q,               // b+1
        #     const at::Tensor &cu_seqlens_k,               // b+1
        #     c10::optional<at::Tensor> &seqused_k,         // b. If given, only this many elements of each batch element's keys are used.
        #     c10::optional<at::Tensor> &block_table_,      // batch_size x max_num_blocks_per_seq
        #     c10::optional<at::Tensor> &alibi_slopes_,     // num_heads or b x num_heads
        #     int max_seqlen_q,
        #     const int max_seqlen_k,
        #     const float p_dropout,
        #     const float softmax_scale,
        #     const bool zero_tensors,
        #     bool is_causal,
        #     int window_size_left,
        #     int window_size_right,
        #     const bool return_softmax,
        #     c10::optional<at::Generator> gen_)
        # It returns a list of tensors; the attention output is the first one.
        attn_output = varlen_fwd(
            q,
            k,
            v,
            None,  # out_
            cu_seqlens_q,  # cu_seqlens_q
            cu_seqlens_k,  # cu_seqlens_k
            seqused_k,  # seqused_k
            None,  # block_table
            None,  # alibi_slopes
            max_seqlen_q,  # int
            max_seqlen_k,  # int
            0.0,  # p_dropout
            self.head_dim ** -0.5,  # softmax_scale
            False,  # zero_tensors
            True,  # is_causal
            -1,  # window_size_left
            -1,  # window_size_right
            False,  # return_softmax
            None,  # gen_
        )[0]

        # Merge heads and apply the output projection.
        hidden_states = self.o(attn_output.reshape(-1, self.num_head * self.head_dim))
        return hidden_states


class LLAMADecodeLayer(nn.Module):
    # Still a stub: the attention / MLP sub-modules and the forward pass
    # have not been implemented yet.
    def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)


class LLAMAModel(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_layer = config.num_hidden_layers
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)]
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, hidden_states, position_ids):
        if input_ids is not None:
            hidden_states = self.token_embed(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        output = self.lm_head(hidden_states)
        return output

    def unpad_input(self, hidden_states, attention_mask):
        # Flatten [B, S, ...] to [B*S, ...] and keep only the non-padding tokens.
        hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
        # Assumes a [B, 1, 1, S] padding mask where valid tokens are marked with 1.
        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        # cu_seqlens has shape [B + 1]: cumulative sequence lengths with a
        # leading 0, which is the layout varlen_fwd expects.
        cu_seqlens = F.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )
        hidden_states = hidden_states[indices].unsqueeze(0)
        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

    def pad_input(self, hidden_states, indices, batch, seqlen):
        """
        Scatter the packed tokens back into a padded [B, S, ...] tensor.

        :param hidden_states: packed tokens, shape [L, H] (not [B, S, H])
        :param indices: the indices returned by unpad_input
        :param batch: batch size
        :param seqlen: the max_seqlen_in_batch returned by unpad_input
        :return: padded tensor of shape [batch, seqlen, ...]
        """
        output = torch.zeros(
            batch * seqlen,
            *hidden_states.shape[1:],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return rearrange(output, "(b s) ... -> b s ...", b=batch)
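

if __name__ == "__main__":
    # Illustrative smoke test for the unpad/pad round trip, added as a sketch
    # and not part of the original module. It only exercises unpad_input /
    # pad_input on CPU; the varlen_fwd attention path needs a CUDA build of
    # flash-attn and is not touched here (though importing this file at all
    # already assumes flash_attn_2_cuda is installed). The config values below
    # are arbitrary toy numbers, not real LLaMA hyperparameters.
    config = LlamaConfig(
        hidden_size=64,
        intermediate_size=128,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_hidden_layers=2,
        vocab_size=128,
    )
    model = LLAMAModel(config)

    batch, seqlen = 2, 5
    hidden_states = torch.randn(batch, seqlen, config.hidden_size)
    # [B, 1, 1, S] padding mask: sequence 0 has 3 valid tokens, sequence 1 has 5.
    attention_mask = torch.tensor(
        [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
    ).view(batch, 1, 1, seqlen)

    packed, indices, cu_seqlens, max_seqlen = model.unpad_input(
        hidden_states, attention_mask
    )
    print(packed.shape)  # torch.Size([1, 8, 64]) -> 3 + 5 valid tokens, packed
    print(cu_seqlens)    # tensor([0, 3, 8], dtype=torch.int32)
    print(max_seqlen)    # 5

    # pad_input expects the packed tokens as [L, H], so drop the leading dim.
    restored = model.pad_input(packed.squeeze(0), indices, batch, seqlen)
    print(restored.shape)  # torch.Size([2, 5, 64])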