add cuda kernels

This commit is contained in:
zzhhjjj 2024-10-22 22:38:29 +00:00
parent 9a7904d5d6
commit a6d79b07b5
3 changed files with 55 additions and 36 deletions

View File

@ -6,28 +6,12 @@ import torch.nn.init as init
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.layers.rotary import apply_rotary_emb
import src.distributed.process_group_manager as pgm
from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
from src.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32
init_method = init.xavier_normal_
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def apply_rotary_pos_emb(x, cos, sin):
#TODO: Maybe do class RotaryEmbedding(nn.Module) later
batch_size, num_head, seq_length, head_dim = x.size()
@ -66,14 +50,7 @@ class Attention(nn.Module):
self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False)
# self.q_proj = ColumnParallelLinear(config.hidden_size, self.num_heads*self.head_dim, bias=False, gather_output=False, init_method=init_method) # why the init method is x? Xavier is better?
# self.k_proj = ColumnParallelLinear(config.hidden_size, self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method)
# self.v_proj = ColumnParallelLinear(config.hidden_size, self.num_key_values*self.head_dim, bias=False, gather_output=False, init_method=init_method)
# if os.getenv('FLASH_ROPE', '1') == '1':
# self.flash_rope = FlashRotaryEmbedding(dim=self.head_dim, interleaved=False, base=500000.0)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# self.out_proj = RowParallelLinear(self.num_heads * self.head_dim, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method)
self.layer_idx = layer_idx
## TODO support mask
@ -83,7 +60,7 @@ class Attention(nn.Module):
q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim]
k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim]
v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim]
if os.getenv('FLASH_ROPE', '0') != '1':
if os.getenv('FLASH_ATTEN', '1') != '1':
q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim).transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim]
k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim]
v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim]
@ -101,8 +78,9 @@ class Attention(nn.Module):
k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim]
v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim]
if os.getenv('ATTENTION', 'SDPA') == 'SDPA':
if os.getenv('FLASH_ATTEN', '1') != '1':
causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1.
# Pytorch scaled dot product attention
out = F.scaled_dot_product_attention(q, k, v, is_causal=causal) # [batch_size, num_heads, seq_length, head_dim]
out = out.transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
else:
@ -119,10 +97,7 @@ class MLP(nn.Module):
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
# self.up_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
# self.gate_proj = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=False, gather_output=False, init_method=init_method)
# self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False, input_is_parallel=True, init_method=init_method)
def forward(self, x):
#TODO: dont do single line operations as it is harder to debug
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
@ -131,7 +106,7 @@ class DecoderLayer(nn.Module):
# RMSNorm -> Attention -> Residual -> RMSNorm -> MLP -> Residual
def __init__(self, config, layer_idx):
super().__init__()
RMSNorm = LlamaRMSNorm
RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention = Attention(config, layer_idx = layer_idx)
@ -167,11 +142,10 @@ class Llama(nn.Module):
# modules
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
# self.embedding = VocabParallelEmbedding(self.vocab_size, self.hidden_size, init_method=init_method)
self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)])
self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
# self.final_proj = ColumnParallelLinear(self.hidden_size, self.vocab_size, bias=False, gather_output=True, init_method=init_method) # we can also not gather the output. TODO: add vocab_parallel_cross_entropy
self.final_norm = LlamaRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
self.final_norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None):
batch_size, seq_length = input_ids.size()

43
src/nn/layer_norm.py Normal file
View File

@ -0,0 +1,43 @@
import torch
from torch import nn
from flash_attn.ops.triton.layer_norm import layer_norm_fn
class TritonRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
self.register_parameter("bias", None)
def forward(
self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
return layer_norm_fn(
hidden_states,
self.weight,
None,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=True,
return_dropout_mask=return_dropout_mask,
)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

View File

@ -177,6 +177,7 @@ if __name__ == "__main__":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["FLASH_ATTEN"] = "1" # Use operations from flash attention repo to accelerate the training. Model dtpe should be torch.float16!
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
@ -185,8 +186,9 @@ if __name__ == "__main__":
# SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
## hyperparameters
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 32, 4, 3e-4, 100000, int(10e8), 42
grad_acc = 16
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@ -246,7 +248,7 @@ if __name__ == "__main__":
if pgm.process_group_manager.dp_world_size > 1:
model = DataParallel(model)
model.to(device)
model.to(dtype).to(device)
model.train()
data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc = grad_acc,num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)