add cuda kernels

parent 9a7904d5d6
commit a6d79b07b5
model.py (42 changed lines)
@@ -6,28 +6,12 @@ import torch.nn.init as init
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 import src.distributed.process_group_manager as pgm
-from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
+from src.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32
 init_method = init.xavier_normal_
 
-class LlamaRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
 def apply_rotary_pos_emb(x, cos, sin):
     #TODO: Maybe do class RotaryEmbedding(nn.Module) later
     batch_size, num_head, seq_length, head_dim = x.size()
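Note: the body of apply_rotary_pos_emb lies outside this hunk. For reference, a minimal sketch of the usual "rotate half" rotary embedding it is assumed to implement (rope_reference and rotate_half are hypothetical names; the base 500000.0 is taken from the commented-out FlashRotaryEmbedding line below):

import torch

def rotate_half(t):
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = t.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(x, cos, sin):
    # x: [batch_size, num_head, seq_length, head_dim]; cos/sin broadcast over batch and heads.
    return x * cos + rotate_half(x) * sin

# Illustrative sizes only: 2 sequences, 4 heads, 16 positions, head_dim 8.
head_dim, seq_len = 8, 16
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)   # [seq_len, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                        # [seq_len, head_dim]
q = torch.randn(2, 4, seq_len, head_dim)
print(rope_reference(q, emb.cos(), emb.sin()).shape)           # torch.Size([2, 4, 16, 8])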
@@ -66,14 +50,7 @@ class Attention(nn.Module):
         self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
-        # self.q_proj = ColumnParallelLinear(config.hidden_size, self.num_heads*self.head_dim, bias=False, gather_output=False, init_method=init_method) # why the init method is x? Xavier is better?
-        # self.k_proj = ColumnParallelLinear(config.hidden_size, self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method)
-        # self.v_proj = ColumnParallelLinear(config.hidden_size, self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method)
-        # if os.getenv('FLASH_ROPE', '1') == '1':
-        #     self.flash_rope = FlashRotaryEmbedding(dim=self.head_dim, interleaved=False, base=500000.0)
-
         self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-        # self.out_proj = RowParallelLinear(self.num_heads * self.head_dim, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method)
         self.layer_idx = layer_idx
 
         ## TODO support mask
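The projections above size q by num_heads but k/v by num_key_values, i.e. grouped-query attention. A short sketch of the shape bookkeeping, with illustrative sizes (hidden_size=512, 8 query heads, 2 KV heads) that are assumptions, not the repo's config:

import torch
import torch.nn as nn

hidden_size, num_heads, num_key_values = 512, 8, 2           # illustrative only
head_dim = hidden_size // num_heads                          # 64

k_proj = nn.Linear(hidden_size, num_key_values * head_dim, bias=False)   # 512 -> 128
x = torch.randn(2, 16, hidden_size)                          # [batch, seq, hidden]
k = k_proj(x).view(2, 16, num_key_values, head_dim).transpose(1, 2)      # [2, 2, 16, 64]
# Each KV head is shared by num_heads // num_key_values = 4 query heads:
k = k.repeat_interleave(num_heads // num_key_values, dim=1)  # [2, 8, 16, 64]
print(k.shape)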
@@ -83,7 +60,7 @@ class Attention(nn.Module):
         q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim]
         k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim]
         v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim]
-        if os.getenv('FLASH_ROPE', '0') != '1':
+        if os.getenv('FLASH_ATTEN', '1') != '1':
             q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim).transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim]
             k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim]
             v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim]
@@ -101,8 +78,9 @@ class Attention(nn.Module):
         k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim]
         v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim]
 
-        if os.getenv('ATTENTION', 'SDPA') == 'SDPA':
+        if os.getenv('FLASH_ATTEN', '1') != '1':
             causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1.
+            # Pytorch scaled dot product attention
             out = F.scaled_dot_product_attention(q, k, v, is_causal=causal) # [batch_size, num_heads, seq_length, head_dim]
             out = out.transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
         else:
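The else: branch (the flash-attention path) is outside this hunk. As a hedged sketch of the layout difference it implies: flash_attn_func takes q/k/v as [batch, seq, heads, head_dim] in fp16/bf16 on GPU, which is why the FLASH_ATTEN=1 path skips the transpose to the heads-first layout that F.scaled_dot_product_attention expects. Illustrative only, guarded on CUDA availability:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)                             # [batch, heads, seq, head_dim] for SDPA
out_sdpa = F.scaled_dot_product_attention(q, q, q, is_causal=True)
print(out_sdpa.shape)                                     # torch.Size([2, 8, 16, 64])

if torch.cuda.is_available():
    from flash_attn.flash_attn_interface import flash_attn_func
    qf = q.transpose(1, 2).contiguous().to('cuda', torch.bfloat16)  # [batch, seq, heads, head_dim]
    out_flash = flash_attn_func(qf, qf, qf, causal=True)  # output keeps [batch, seq, heads, head_dim]
    print(out_flash.shape)                                # torch.Size([2, 16, 8, 64])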
@@ -119,10 +97,7 @@ class MLP(nn.Module):
         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-        # self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
-        # self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
-        # self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method)
 
     def forward(self, x):
         #TODO: dont do single line operations as it is harder to debug
         return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
@@ -131,7 +106,7 @@ class DecoderLayer(nn.Module):
     # RMSNorm -> Attention -> Residual -> RMSNorm -> MLP -> Residual
     def __init__(self, config, layer_idx):
         super().__init__()
-        RMSNorm = LlamaRMSNorm
+        RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.attention = Attention(config, layer_idx = layer_idx)
@@ -167,11 +142,10 @@ class Llama(nn.Module):
 
         # modules
         self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
-        # self.embedding = VocabParallelEmbedding(self.vocab_size, self.hidden_size, init_method=init_method)
        self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)])
         self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
-        # self.final_proj = ColumnParallelLinear(self.hidden_size, self.vocab_size, bias=False, gather_output=True, init_method=init_method) # we can also not gather the output. TODO: add vocab_parallel_cross_entropy
-        self.final_norm = LlamaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+        RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
+        self.final_norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
 
     def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None):
         batch_size, seq_length = input_ids.size()
src/nn/layer_norm.py (new file, 43 lines)

@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+from flash_attn.ops.triton.layer_norm import layer_norm_fn
+
+class TritonRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.register_parameter("bias", None)
+
+    def forward(
+        self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
+    ):
+        return layer_norm_fn(
+            hidden_states,
+            self.weight,
+            None,
+            residual=residual,
+            eps=self.eps,
+            dropout_p=dropout_p,
+            prenorm=prenorm,
+            residual_in_fp32=residual_in_fp32,
+            is_rms_norm=True,
+            return_dropout_mask=return_dropout_mask,
+        )
+
+class LlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-5):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
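A quick way to sanity-check the new TritonRMSNorm against the eager LlamaRMSNorm it replaces when FLASH_ATTEN=1 (a sketch; it assumes a CUDA device with flash-attn installed, since layer_norm_fn is a Triton kernel, and uses an illustrative hidden size):

import torch
from src.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm

hidden_size = 512                                           # illustrative
x = torch.randn(2, 16, hidden_size, device='cuda', dtype=torch.bfloat16)

eager = LlamaRMSNorm(hidden_size).to('cuda', torch.bfloat16)
triton = TritonRMSNorm(hidden_size).to('cuda', torch.bfloat16)

# The fused kernel should agree with the eager norm up to bf16 rounding.
print(torch.allclose(eager(x), triton(x), atol=1e-2, rtol=1e-2))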
train.py (6 changed lines)

@@ -177,6 +177,7 @@ if __name__ == "__main__":
 
     os.environ["OMP_NUM_THREADS"] = "1"
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    os.environ["FLASH_ATTEN"] = "1" # Use operations from flash attention repo to accelerate the training. Model dtpe should be torch.float16!
 
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -185,8 +186,9 @@ if __name__ == "__main__":
 
     # SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
     ## hyperparameters
-    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
+    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 32, 4, 3e-4, 100000, int(10e8), 42
     grad_acc = 16
+    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
 
     assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
 
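The added dtype fallback matters because the flash-attention and Triton kernels only support half precision on GPU. A hedged guard one could place next to the FLASH_ATTEN toggle (illustrative, not part of this commit):

import os
import torch

dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
if os.environ.get("FLASH_ATTEN", "0") == "1":
    # flash-attn and the Triton RMSNorm kernel need fp16/bf16 tensors on CUDA.
    assert dtype in (torch.float16, torch.bfloat16), f"FLASH_ATTEN=1 needs fp16/bf16, got {dtype}"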
@@ -246,7 +248,7 @@ if __name__ == "__main__":
     if pgm.process_group_manager.dp_world_size > 1:
         model = DataParallel(model)
 
-    model.to(device)
+    model.to(dtype).to(device)
     model.train()
 
     data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc = grad_acc,num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
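For reference, model.to(dtype).to(device) casts parameters and buffers and then moves them; a single .to call reaches the same end state in one pass. A minimal sketch with a stand-in module:

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

model = nn.Linear(8, 8)                   # stand-in for the Llama model
model = model.to(device, dtype=dtype)     # same result as .to(dtype).to(device)
print(next(model.parameters()).device, next(model.parameters()).dtype)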