torch_ext/fi/models/llama.py

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig

from flash_attn_2_cuda import varlen_fwd  # compiled flash-attn 2.x extension


class MLP(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # LLaMA's feed-forward projections carry no bias.
        self.up = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.gate = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # SwiGLU: down(silu(gate(x)) * up(x))
        hidden_states = self.down(F.silu(self.gate(hidden_states)) * self.up(hidden_states))
        return hidden_states
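

# A hedged sketch of the "process position_ids" step referenced in
# LLAMAAttention.forward below: LLaMA-style rotary position embeddings in the
# rotate-half formulation. The helper name and the fixed base are assumptions,
# not part of the original module.
def apply_rotary_pos_emb(
    x: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    position_ids: torch.Tensor,  # [total_tokens]
    base: float = 10000.0,
) -> torch.Tensor:
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, device=x.device, dtype=torch.float32) / head_dim)
    )
    freqs = position_ids.to(torch.float32)[:, None] * inv_freq[None, :]  # [T, head_dim/2]
    cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)[:, None, :]  # [T, 1, head_dim]
    sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)[:, None, :]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    rotated = torch.cat([-x2, x1], dim=-1)
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)

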
class LLAMAAttention(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_head = config.num_attention_heads
        self.kv_head = config.num_key_value_heads
        self.head_dim = config.head_dim
        assert (
            self.num_head * self.head_dim == config.hidden_size
        ), "num_head * head_dim must equal hidden_size"
        # Query projects to hidden_size; key/value project to kv_head * head_dim
        # so grouped-query attention (kv_head < num_head) is handled correctly.
        # LLaMA's attention projections carry no bias.
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k = nn.Linear(config.hidden_size, self.kv_head * self.head_dim, bias=False)
        self.v = nn.Linear(config.hidden_size, self.kv_head * self.head_dim, bias=False)
        self.o = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
    ):
        # cu_seqlens / max_seqlen describe the packed (unpadded) layout and are
        # produced by LLAMAModel.unpad_input below.
        in_shape = hidden_states.shape
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)
        q = q.view(-1, self.num_head, self.head_dim)
        k = k.view(-1, self.kv_head, self.head_dim)
        v = v.view(-1, self.kv_head, self.head_dim)
        # Process position_ids: apply rotary embeddings to q and k
        # (see the hedged apply_rotary_pos_emb sketch above).
        q = apply_rotary_pos_emb(q, position_ids)
        k = apply_rotary_pos_emb(k, position_ids)
        # varlen_fwd is the core of the packed attention path; the full C++
        # signature is reproduced in the comment block below.
        out = varlen_fwd(
            q,
            k,
            v,
            None,  # out
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k (self-attention: same lengths as q)
            None,  # seqused_k
            None,  # block_table
            None,  # alibi_slopes
            max_seqlen,  # max_seqlen_q
            max_seqlen,  # max_seqlen_k
            0.0,  # dropout
            self.head_dim ** -0.5,  # softmax_scale
            False,  # zero_tensors
            True,  # is_causal
            -1,  # window_size_left
            -1,  # window_size_right
            False,  # return_softmax
            None,  # gen
        )[0]  # the extension returns several tensors; the first is the attention output
        # mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        #                const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        #                const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        #                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        #                const at::Tensor &cu_seqlens_q, // b+1
        #                const at::Tensor &cu_seqlens_k, // b+1
        #                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
        #                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
        #                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
        #                int max_seqlen_q,
        #                const int max_seqlen_k,
        #                const float p_dropout,
        #                const float softmax_scale,
        #                const bool zero_tensors,
        #                bool is_causal,
        #                int window_size_left,
        #                int window_size_right,
        #                const bool return_softmax,
        #                c10::optional<at::Generator> gen_)
        # Merge the heads back and apply the output projection.
        out = out.reshape(in_shape)
        return self.o(out)


class LLAMADecodeLayer(nn.Module):
    def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
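        # A hedged sketch of a standard LLaMA pre-norm block: RMSNorm -> attention
        # -> residual, then RMSNorm -> MLP -> residual. The exact layer layout is
        # an assumption; nn.RMSNorm requires PyTorch >= 2.4 (substitute your own
        # RMSNorm implementation otherwise).
        self.idx = idx
        self.input_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = LLAMAAttention(config)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
    ):
        residual = hidden_states
        hidden_states = self.attn(
            self.input_norm(hidden_states), position_ids, cu_seqlens, max_seqlen
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.mlp(self.post_attn_norm(hidden_states))
        return residual + hidden_states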


class LLAMAModel(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_layer = config.num_hidden_layers
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)]
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids, hidden_states, position_ids, cu_seqlens, max_seqlen):
        # cu_seqlens / max_seqlen come from unpad_input below and are threaded
        # through every layer so the flash-attn varlen kernel can be called.
        if input_ids is not None:
            hidden_states = self.token_embed(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids, cu_seqlens, max_seqlen)
        output = self.lm_head(hidden_states)
        return output

    def unpad_input(self, hidden_states, attention_mask):
        hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
        # attention_mask is [b, 1, 1, s]; entries equal to 1 mark valid tokens.
        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )
        hidden_states = hidden_states[indices].unsqueeze(0)
        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

    def pad_input(self, hidden_states, indices, batch, seqlen):
        """
        :param hidden_states: shape is [L, H], not [B, L, H]
        :param indices: the indices returned by unpad_input
        :param batch: batch size
        :param seqlen: the max_seqlen_in_batch returned by unpad_input
        :return: hidden_states scattered back to [B, S, ...]
        """
        output = torch.zeros(
            batch * seqlen,
            *hidden_states.shape[1:],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return rearrange(output, "(b s) ... -> b s ...", b=batch)
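

# A hedged end-to-end usage sketch (not part of the original module): pad a
# batch, unpad it into the packed (total_tokens, hidden) layout that varlen_fwd
# expects, run the model, then scatter the logits back to [B, S, ...]. The
# config values are placeholders; the flash-attn kernel requires fp16/bf16
# tensors on a CUDA device.
if __name__ == "__main__":
    config = LlamaConfig(
        vocab_size=32000,
        hidden_size=512,
        intermediate_size=1376,
        num_hidden_layers=2,
        num_attention_heads=8,
        num_key_value_heads=8,
        head_dim=64,
    )
    model = LLAMAModel(config).cuda().half()

    batch, seqlen = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch, seqlen), device="cuda")
    attention_mask = torch.ones(batch, 1, 1, seqlen, device="cuda")
    attention_mask[1, ..., 12:] = 0  # pretend the second sequence is shorter

    hidden_states = model.token_embed(input_ids)
    hidden_states, indices, cu_seqlens, max_seqlen = model.unpad_input(
        hidden_states, attention_mask
    )
    # Per-token positions for the packed layout: 0..len-1 within each sequence.
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    position_ids = torch.cat(
        [torch.arange(int(s), device="cuda") for s in seqlens]
    )

    logits = model(None, hidden_states, position_ids, cu_seqlens, max_seqlen)
    logits = model.pad_input(logits.squeeze(0), indices, batch, seqlen)
    print(logits.shape)  # expected: [batch, seqlen, vocab_size]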