some fixes
parent b7f3e253be
commit 46af5b0425

model.py (2 changed lines)
@@ -21,7 +21,7 @@ def get_cos_sin(seq_length, head_dim, base=500000.0):
     assert head_dim % 2 == 0
     # CUDA and CPU give different results even with the same formula; to match the transformers implementation, the frequencies should be computed on CPU.
     theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim))
-    dtype = torch.bfloat16 if os.getenv('DATA_TYPE', 'bfloat16') == 'bfloat16' else torch.float32
+    dtype = torch.bfloat16 if os.getenv('DTYPE', 'bfloat16') == 'bfloat16' else torch.float32
     device = torch.device('cuda') if os.getenv('DEVICE', 'cuda') == 'cuda' else torch.device('cpu')
     position = torch.arange(seq_length).to(device).unsqueeze(1).float()  # [seq_length, 1]
     # To match the transformers implementation, m * theta should be computed on GPU.
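For reference, a minimal sketch of what get_cos_sin plausibly looks like after this change. Only the lines shown in the hunk are taken from the repository; the outer product, the concatenation and the final cast are assumptions based on the usual RoPE recipe.

import os
import torch

def get_cos_sin(seq_length, head_dim, base=500000.0):
    assert head_dim % 2 == 0
    # Frequencies on CPU: CUDA and CPU give slightly different results for the same formula.
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim))
    dtype = torch.bfloat16 if os.getenv('DTYPE', 'bfloat16') == 'bfloat16' else torch.float32
    device = torch.device('cuda') if os.getenv('DEVICE', 'cuda') == 'cuda' else torch.device('cpu')
    position = torch.arange(seq_length).to(device).unsqueeze(1).float()  # [seq_length, 1]
    # m * theta on the target device; duplicated to cover the full head_dim (assumed ending).
    freqs = position * theta.to(device).unsqueeze(0)   # [seq_length, head_dim // 2]
    emb = torch.cat([freqs, freqs], dim=-1)            # [seq_length, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)
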
@@ -5,84 +5,8 @@ import torch.nn.functional as F
 from torch import distributed as dist
 from typing import Any, Optional, Tuple
 from src.distributed.distributed_primtives import ContextComms
 # from model import Attention
 import src.distributed.process_group_manager as pgm

-# class ContextParallel(nn.Module):
-#     def __init__(self, model, config):
-#         super().__init__()
-
-#         self.model = model
-
-#         for name, module in model.named_modules():
-#             if isinstance(module, Attention) and not isinstance(module, RingAttention):
-#                 parent_name, child_name = name.rsplit('.', 1)
-#                 parent_module = model.get_submodule(parent_name)
-#                 setattr(parent_module, child_name, RingAttention(module))
-#                 del module
-
-#     def __getattr__(self, name):
-#         try:
-#             return super().__getattr__(name)
-#         except AttributeError:
-#             return getattr(self.model, name)
-
-# class RingAttention(nn.Module):
-#     def __init__(self, original_mha):
-#         super().__init__()
-
-#         self.hidden_size = original_mha.hidden_size
-#         self.num_heads = original_mha.num_heads
-#         self.head_dim = self.hidden_size // self.num_heads
-#         self.num_key_value_heads = original_mha.num_key_values
-#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-#         self.is_causal = original_mha.is_causal
-
-#         # Copy the weights from the original Attention
-#         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-#         self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-
-#         self.q_proj.weight.data.copy_(original_mha.q_proj.weight.data)
-#         self.k_proj.weight.data.copy_(original_mha.k_proj.weight.data)
-#         self.v_proj.weight.data.copy_(original_mha.v_proj.weight.data)
-#         self.out_proj.weight.data.copy_(original_mha.out_proj.weight.data)
-
-#         self.rotary = original_mha.rotary
-
-#     def forward(self, input_ids, position_ids):
-#         batch_size, seq_len, _ = input_ids.shape
-
-#         q = self.q_proj(input_ids)
-#         k = self.k_proj(input_ids)
-#         v = self.v_proj(input_ids)
-
-#         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-#         k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-#         v = v.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-#         if self.rotary is not None:
-#             cos, sin = self.rotary(v, position_ids)
-#             q, k = self.rotary.apply_rotary_pos_emb(q, k, cos, sin)
-
-#         k = self._repeat_kv(k, self.num_key_value_groups)
-#         v = self._repeat_kv(v, self.num_key_value_groups)
-
-#         sm_scale = 1.0 / (q.size(-1) ** 0.5)
-#         output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
-
-#         output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
-#         output = self.out_proj(output)
-#         return output
-
-#     def _repeat_kv(self, x, n_rep):
-#         batch, num_key_value_heads, seq_len, head_dim = x.shape
-#         if n_rep == 1:
-#             return x
-#         x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
-#         return x.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
-
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
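For context, a rough sketch of how a GQA attention forward could route through ring_attention, mirroring the commented-out RingAttention code removed above. The helper and the shapes are illustrative, not the repository's actual module; ring_attention is the function kept at the end of this hunk.

import torch

def repeat_kv(x, n_rep):
    # [batch, num_kv_heads, seq, head_dim] -> [batch, num_kv_heads * n_rep, seq, head_dim]
    batch, num_kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

def attention_forward(q, k, v, num_key_value_groups, is_causal=True):
    # q: [batch, num_heads, local_seq, head_dim]; k, v: [batch, num_kv_heads, local_seq, head_dim],
    # where local_seq is this context-parallel rank's shard of the sequence.
    k = repeat_kv(k, num_key_value_groups)
    v = repeat_kv(v, num_key_value_groups)
    sm_scale = 1.0 / (q.size(-1) ** 0.5)
    # ring_attention (kept above) exchanges K/V blocks between ranks via RingAttentionFunc.
    return ring_attention(q, k, v, sm_scale, is_causal)
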
@@ -97,7 +97,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
             input_tensor = None
             pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
         else:
-            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.dtype)
+            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)

     for _ in range(num_warmup_microbatches): # Cooldown backward passes
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
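The fix in this hunk replaces torch.dtype, which is the dtype class itself, with the dtype argument of the surrounding function. A quick illustration of why the old value cannot be used to allocate a receive buffer (the shape below is made up):

import torch

recv_shape = (1, 1024, 4096)                         # hypothetical activation shape
buf = torch.empty(recv_shape, dtype=torch.bfloat16)  # a concrete dtype works
# torch.empty(recv_shape, dtype=torch.dtype)         # raises TypeError: torch.dtype is a class, not a dtype value
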
train.py (3 changed lines)
@@ -79,7 +79,6 @@ if __name__ == "__main__":
     os.environ["DEVICE"] = "cuda" if not args.use_cpu else "cpu"

     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not args.use_cpu else torch.float32 # if no GPU is available or bf16 is not supported, fall back to torch.float32
     os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
     assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"

     local_rank = int(os.environ["LOCAL_RANK"])
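The assert above encodes "if FLASH_ATTEN is enabled, dtype must be bfloat16". An equivalent and arguably clearer way to write the same condition (a pure boolean rewrite, same behaviour):

import os
import torch

flash_enabled = os.getenv("FLASH_ATTEN") == "1"
# (A and B) or (not B) is logically equivalent to: B implies A
assert (not flash_enabled) or dtype == torch.bfloat16, "Kernel operations requires dtype=torch.bfloat16"
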
@@ -186,7 +185,7 @@ if __name__ == "__main__":
     #TODO: Add activation checkpointing
     #TODO: add gradient accumulation

-    while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
+    while trained_tokens < MAX_TOKENS:
         #TODO: Add epoch support
         # data_loader.set_epoch(step)
         step_start_time = time.time()
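The while condition above drives training by a token budget. A hypothetical sketch of the surrounding loop; train_step and tokens_per_step are stand-ins for whatever the script actually uses and are not taken from the diff:

import time

trained_tokens, step = 0, 0
while trained_tokens < MAX_TOKENS:                 # MAX_TOKENS is the token budget for the run
    step_start_time = time.time()
    loss = train_step(model, data_loader, device)  # assumed training-step helper
    trained_tokens += tokens_per_step              # e.g. global batch size * sequence length
    step += 1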