From 46af5b042509a83e609cf998ed41dd3e00521dfa Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Tue, 29 Oct 2024 14:08:08 +0000
Subject: [PATCH] some fixes

---
 model.py                          |  2 +-
 src/parallel/context_parallel.py  | 76 -------------------------------
 src/parallel/pipeline_parallel.py |  2 +-
 train.py                          |  3 +-
 4 files changed, 3 insertions(+), 80 deletions(-)

diff --git a/model.py b/model.py
index f362cf7..83417ac 100644
--- a/model.py
+++ b/model.py
@@ -21,7 +21,7 @@ def get_cos_sin(seq_length, head_dim, base=500000.0):
     assert head_dim%2==0
     # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU
     theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim))
-    dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32
+    dtype = torch.bfloat16 if os.getenv('DTYPE', 'bfloat16') == 'bfloat16' else torch.float32
     device = torch.device('cuda') if os.getenv('DEVICE', 'cuda') == 'cuda' else torch.device('cpu')
     position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1]
     # To match transformers implementation. m * theta should be computed on GPU
diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py
index 80d2447..fa3fc63 100644
--- a/src/parallel/context_parallel.py
+++ b/src/parallel/context_parallel.py
@@ -5,84 +5,8 @@ import torch.nn.functional as F
 from torch import distributed as dist
 from typing import Any, Optional, Tuple
 from src.distributed.distributed_primtives import ContextComms
-# from model import Attention
 import src.distributed.process_group_manager as pgm
 
-# class ContextParallel(nn.Module):
-#     def __init__(self, model, config):
-#         super().__init__()
-
-#         self.model = model
-
-#         for name, module in model.named_modules():
-#             if isinstance(module, Attention) and not isinstance(module, RingAttention):
-#                 parent_name, child_name = name.rsplit('.', 1)
-#                 parent_module = model.get_submodule(parent_name)
-#                 setattr(parent_module, child_name, RingAttention(module))
-#                 del module
-
-#     def __getattr__(self, name):
-#         try:
-#             return super().__getattr__(name)
-#         except AttributeError:
-#             return getattr(self.model, name)
-
-# class RingAttention(nn.Module):
-#     def __init__(self, original_mha):
-#         super().__init__()
-
-#         self.hidden_size = original_mha.hidden_size
-#         self.num_heads = original_mha.num_heads
-#         self.head_dim = self.hidden_size // self.num_heads
-#         self.num_key_value_heads = original_mha.num_key_values
-#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-#         self.is_causal = original_mha.is_causal
-
-#         # Copy the weights from the original Attention
-#         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-#         self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-
-#         self.q_proj.weight.data.copy_(original_mha.q_proj.weight.data)
-#         self.k_proj.weight.data.copy_(original_mha.k_proj.weight.data)
-#         self.v_proj.weight.data.copy_(original_mha.v_proj.weight.data)
-#         self.out_proj.weight.data.copy_(original_mha.out_proj.weight.data)
-
-#         self.rotary = original_mha.rotary
-
-#     def forward(self, input_ids, position_ids):
-#         batch_size, seq_len, _ = input_ids.shape
-
-#         q = self.q_proj(input_ids)
-#         k = self.k_proj(input_ids)
-#         v = self.v_proj(input_ids)
-
-#         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-#         k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-#         v = v.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-#         if self.rotary is not None:
-#             cos, sin = self.rotary(v, position_ids)
-#             q, k = self.rotary.apply_rotary_pos_emb(q, k, cos, sin)
-
-#         k = self._repeat_kv(k, self.num_key_value_groups)
-#         v = self._repeat_kv(v, self.num_key_value_groups)
-
-#         sm_scale = 1.0 / (q.size(-1) ** 0.5)
-#         output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
-
-#         output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
-#         output = self.out_proj(output)
-#         return output
-
-#     def _repeat_kv(self, x, n_rep):
-#         batch, num_key_value_heads, seq_len, head_dim = x.shape
-#         if n_rep == 1:
-#             return x
-#         x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
-#         return x.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
-
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
 
diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py
index 260a8e7..b4a1064 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/src/parallel/pipeline_parallel.py
@@ -97,7 +97,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
             input_tensor = None
             pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
         else:
-            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.dtype)
+            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)
 
     for _ in range(num_warmup_microbatches): # Cooldown backward passes
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
diff --git a/train.py b/train.py
index a940738..db5f701 100644
--- a/train.py
+++ b/train.py
@@ -79,7 +79,6 @@ if __name__ == "__main__":
 
     os.environ["DEVICE"] = "cuda" if not args.use_cpu else "cpu"
     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not args.use_cpu else torch.float32 # if GPU is not available or not supported, use torch.float32
-    os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
     assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
 
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -186,7 +185,7 @@ if __name__ == "__main__":
 
     #TODO: Add activation checkpointing
     #TODO: add gradient accumulation
-    while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
+    while trained_tokens < MAX_TOKENS:
         #TODO: Add epoch support
         # data_loader.set_epoch(step)
         step_start_time = time.time()