some fixes

ferdinand.mom 2024-10-29 14:08:08 +00:00
parent b7f3e253be
commit 46af5b0425
4 changed files with 3 additions and 80 deletions

View File

@@ -21,7 +21,7 @@ def get_cos_sin(seq_length, head_dim, base=500000.0):
assert head_dim%2==0
# Results on CUDA and CPU differ even with the same formula; to match the transformers implementation, the frequencies (theta) should be computed on CPU
theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim))
dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32
dtype = torch.bfloat16 if os.getenv('DTYPE', 'bfloat16') == 'bfloat16' else torch.float32
device = torch.device('cuda') if os.getenv('DEVICE', 'cuda') == 'cuda' else torch.device('cpu')
position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1]
# To match the transformers implementation, m * theta should be computed on the GPU
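For orientation, here is a minimal sketch of how get_cos_sin typically finishes building the rotary tables under these settings; the freqs/emb assembly at the end is an assumption based on standard RoPE and is not part of this hunk.

import os
import torch

def get_cos_sin(seq_length, head_dim, base=500000.0):
    assert head_dim % 2 == 0
    # theta is computed on CPU so the values match the transformers implementation
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim))
    dtype = torch.bfloat16 if os.getenv('DTYPE', 'bfloat16') == 'bfloat16' else torch.float32
    device = torch.device('cuda') if os.getenv('DEVICE', 'cuda') == 'cuda' else torch.device('cpu')
    position = torch.arange(seq_length).to(device).unsqueeze(1).float()  # [seq_length, 1]
    # m * theta is computed on the target device, then duplicated to cover the full head_dim
    freqs = position * theta.to(device).unsqueeze(0)  # [seq_length, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)           # [seq_length, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)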

View File

@@ -5,84 +5,8 @@ import torch.nn.functional as F
from torch import distributed as dist
from typing import Any, Optional, Tuple
from src.distributed.distributed_primtives import ContextComms
# from model import Attention
import src.distributed.process_group_manager as pgm
# class ContextParallel(nn.Module):
# def __init__(self, model, config):
# super().__init__()
# self.model = model
# for name, module in model.named_modules():
# if isinstance(module, Attention) and not isinstance(module, RingAttention):
# parent_name, child_name = name.rsplit('.', 1)
# parent_module = model.get_submodule(parent_name)
# setattr(parent_module, child_name, RingAttention(module))
# del module
# def __getattr__(self, name):
# try:
# return super().__getattr__(name)
# except AttributeError:
# return getattr(self.model, name)
# class RingAttention(nn.Module):
# def __init__(self, original_mha):
# super().__init__()
# self.hidden_size = original_mha.hidden_size
# self.num_heads = original_mha.num_heads
# self.head_dim = self.hidden_size // self.num_heads
# self.num_key_value_heads = original_mha.num_key_values
# self.num_key_value_groups = self.num_heads // self.num_key_value_heads
# self.is_causal = original_mha.is_causal
# # Copy the weights from the original Attention
# self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
# self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
# self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
# self.q_proj.weight.data.copy_(original_mha.q_proj.weight.data)
# self.k_proj.weight.data.copy_(original_mha.k_proj.weight.data)
# self.v_proj.weight.data.copy_(original_mha.v_proj.weight.data)
# self.out_proj.weight.data.copy_(original_mha.out_proj.weight.data)
# self.rotary = original_mha.rotary
# def forward(self, input_ids, position_ids):
# batch_size, seq_len, _ = input_ids.shape
# q = self.q_proj(input_ids)
# k = self.k_proj(input_ids)
# v = self.v_proj(input_ids)
# q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# v = v.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# if self.rotary is not None:
# cos, sin = self.rotary(v, position_ids)
# q, k = self.rotary.apply_rotary_pos_emb(q, k, cos, sin)
# k = self._repeat_kv(k, self.num_key_value_groups)
# v = self._repeat_kv(v, self.num_key_value_groups)
# sm_scale = 1.0 / (q.size(-1) ** 0.5)
# output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
# output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
# output = self.out_proj(output)
# return output
# def _repeat_kv(self, x, n_rep):
# batch, num_key_value_heads, seq_len, head_dim = x.shape
# if n_rep == 1:
# return x
# x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
# return x.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
def ring_attention(q, k, v, sm_scale, is_causal):
    # thin functional wrapper around the ring-attention autograd function
    return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
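A hedged usage sketch of the function that remains after this cleanup: the [batch, num_heads, local_seq_len, head_dim] layout and the causal flag follow the commented-out wrapper above, while the concrete shapes below are made up for illustration.

import torch
# each context-parallel rank holds its own slice of the sequence (hypothetical shapes)
batch, num_heads, local_seq, head_dim = 1, 8, 128, 64
q = torch.randn(batch, num_heads, local_seq, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 1.0 / (head_dim ** 0.5)
out = ring_attention(q, k, v, sm_scale, is_causal=True)  # same shape as q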

View File

@@ -97,7 +97,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
input_tensor = None
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
else:
input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.dtype)
input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)
for _ in range(num_warmup_microbatches): # Cooldown backward passes
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
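The fix above is substantive: torch.dtype is the dtype class itself, not a dtype value, so any receive buffer allocated from it fails. A quick illustration of the failure mode (the torch.empty call is my example, not the project's allocation code):

import torch
shape = (1, 1024, 4096)                        # hypothetical microbatch activation shape
ok = torch.empty(shape, dtype=torch.bfloat16)  # works: a concrete dtype instance
# torch.empty(shape, dtype=torch.dtype)        # raises TypeError: torch.dtype is the type object, not an instance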

View File

@@ -79,7 +79,6 @@ if __name__ == "__main__":
os.environ["DEVICE"] = "cuda" if not args.use_cpu else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not args.use_cpu else torch.float32 # if GPU is not available or not supported, use torch.float32
os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
local_rank = int(os.environ["LOCAL_RANK"])
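For reference, the assert removed above encodes "FLASH_ATTEN=1 implies bfloat16"; if the guard were ever wanted back, a more direct form of the same check could look like this (my rewrite, not part of the commit; dtype comes from the surrounding script):

import os
import torch
if os.getenv("FLASH_ATTEN") == "1":
    assert dtype == torch.bfloat16, "Kernel operations require dtype=torch.bfloat16"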
@@ -186,7 +185,7 @@ if __name__ == "__main__":
#TODO: Add activation checkpointing
#TODO: add gradient accumulation
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
while trained_tokens < MAX_TOKENS:
#TODO: Add epoch support
# data_loader.set_epoch(step)
step_start_time = time.time()
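The loop condition change above drops the "MAX_TOKENS is None" escape hatch, so a concrete token budget must now be set. A heavily hedged skeleton of the loop as it reads after the change; everything except MAX_TOKENS, trained_tokens, and step_start_time is a placeholder of mine.

import time
MAX_TOKENS = 1_000_000        # hypothetical budget; None is no longer handled
trained_tokens, step = 0, 0
while trained_tokens < MAX_TOKENS:
    step_start_time = time.time()
    tokens_this_step = 2048   # placeholder for the real per-step token count
    trained_tokens += tokens_this_step
    step += 1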