add context parallel for training
parent 1e229cae88 · commit ffea3d2ad1

@@ -1,7 +1,7 @@
import os
from typing import List, Optional
import distributed.process_group_manager as pgm
import torch, torch.distributed as dist
import distributed.process_group_manager as pgm

STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

@@ -44,6 +44,53 @@ def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, devi
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return recv_tensor

class ContextComms:
    def __init__(self, msg: str = ""):
        global STEP
        global VERBOSE
        self._pending_operations: List[dist.P2POp] = []
        self._active_requests = None
        self.rank = pgm.process_group_manager.cp_rank
        self.world_size = pgm.process_group_manager.cp_world_size
        self.send_rank = pgm.process_group_manager.cp_send_rank
        self.recv_rank = pgm.process_group_manager.cp_recv_rank
        if VERBOSE: print(f"RingComm ({msg}) | initialized | RANK:{self.rank} | "f"WORLD_SIZE:{self.world_size} | SEND_RANK:{self.send_rank} | "f"RECV_RANK:{self.recv_rank}", flush=True)

    def send_recv(self, tensor_to_send, recv_tensor=None):
        if recv_tensor is None:
            result_tensor = torch.zeros_like(tensor_to_send)
        else:
            result_tensor = recv_tensor

        send_operation = dist.P2POp(dist.isend, tensor_to_send, self.send_rank, group=pgm.process_group_manager.cp_group)
        recv_operation = dist.P2POp(dist.irecv, result_tensor, self.recv_rank, group=pgm.process_group_manager.cp_group)

        self._pending_operations.extend([send_operation, recv_operation])

        if VERBOSE:
            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:sending | TO:{self.send_rank} | TENSOR:{tensor_to_send}", flush=True)
            print(f"RingComm | send_recv | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:receiving | FROM:{self.recv_rank} | TENSOR:{result_tensor}", flush=True)
        return result_tensor

    def commit(self):
        if self._active_requests is not None: raise RuntimeError("Commit called twice")
        self._active_requests = dist.batch_isend_irecv(self._pending_operations)
        if VERBOSE: print(f"RingComm | commit | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:committed | NUM_OPS:{len(self._pending_operations) // 2}", flush=True)

    def wait(self):
        if self._active_requests is None: raise RuntimeError("Wait called before commit")
        for i, request in enumerate(self._active_requests):
            request.wait()
            if VERBOSE:
                operation_type = "send" if i % 2 == 0 else "receive"
                peer_rank = self.send_rank if operation_type == "send" else self.recv_rank
                print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:completed_{operation_type} | "f"{'FROM' if operation_type == 'receive' else 'TO'}:{peer_rank}", flush=True)
        torch.cuda.synchronize()
        self._active_requests = None
        self._pending_operations = []
        if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)

def all_reduce_loss_across_dp_ranks(loss, device):
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # Reduce the loss across all workers so that every rank has the updated loss value.

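A minimal usage sketch of ContextComms (not part of this commit; it assumes a process group with cp_world_size >= 2 has already been initialized): each rank enqueues a send to its next neighbor and a receive from its previous neighbor, launches both with commit(), and blocks on wait() before reading the received buffer.

    # Hedged sketch, not part of the commit: one ring exchange with ContextComms.
    comm = ContextComms("example")
    local_block = torch.randn(2, 4, device="cuda")   # this rank's shard
    recv_block = comm.send_recv(local_block)          # enqueue isend/irecv pair
    comm.commit()                                     # launch the batched P2P ops
    # ...overlap computation on local_block here...
    comm.wait()                                       # recv_block is now valid
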
@@ -42,6 +42,8 @@ class ProcessGroupManager:
        self.cp_first_rank = self.cp_group_ids[0]
        self.cp_last_rank = self.cp_group_ids[-1]
        self.cp_world_size = dist.get_world_size(group=self.cp_group)
        self.cp_send_rank = self.cp_group_ids[(self.cp_rank + 1) % self.cp_size]
        self.cp_recv_rank = self.cp_group_ids[(self.cp_rank - 1) % self.cp_size]

        # Pipeline parallelism
        self.pp_first_rank = self.pp_group_ids[0]

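For illustration (not part of the commit): with cp_size = 4 and cp_group_ids = [0, 1, 2, 3], the send/receive neighbors defined above form a ring.

    # Hedged sketch: neighbors implied by the formulas above, assuming cp_size = 4.
    cp_group_ids, cp_size = [0, 1, 2, 3], 4
    for cp_rank in range(cp_size):
        send_rank = cp_group_ids[(cp_rank + 1) % cp_size]
        recv_rank = cp_group_ids[(cp_rank - 1) % cp_size]
        print(f"rank {cp_rank}: sends to {send_rank}, receives from {recv_rank}")
    # rank 0 sends to 1 and receives from 3; rank 3 sends to 0 and receives from 2
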
model.py · 14 changed lines
@@ -70,9 +70,9 @@ class RotaryEmbedding(nn.Module):
        k_embed = (k * cos) + (self._rotate_half(k) * sin)
        return q_embed, k_embed

class MultiHeadAttention(nn.Module):
class Attention(nn.Module):
    def __init__(self, config, is_causal):
        super(MultiHeadAttention, self).__init__()
        super(Attention, self).__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
@@ -90,10 +90,10 @@ class MultiHeadAttention(nn.Module):
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta)

    def forward(self, x, position_ids):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
    def forward(self, input_ids, position_ids):
        q = self.q_proj(input_ids)
        k = self.k_proj(input_ids)
        v = self.v_proj(input_ids)

        batch, seq_len, _ = q.shape

@@ -130,7 +130,7 @@ class DecoderLayer(nn.Module):
    def __init__(self, config, is_causal):
        super(DecoderLayer, self).__init__()

        self.attention = MultiHeadAttention(config, is_causal)
        self.attention = Attention(config, is_causal)
        self.norm_attn = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = MLP(config)
        self.norm_mlp = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -1,15 +1,258 @@
import torch.distributed as dist
# Inspired by https://github.com/zhuzilin/ring-flash-attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from typing import Any, Optional, Tuple
from distributed.distributed_primtives import ContextComms
from model import Attention
import distributed.process_group_manager as pgm
import lovely_tensors as lt; lt.monkey_patch()

from utils import print

class ContextParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config

        for name, module in self.model.named_modules():
            if isinstance(module, Attention) and not isinstance(module, RingAttention):
                parent_name, child_name = name.rsplit('.', 1)
                parent_module = self.model.get_submodule(parent_name)
                setattr(parent_module, child_name, RingAttention(module))
                del module

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)

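The wrapper above uses the usual named_modules / get_submodule replacement pattern to swap every Attention for a RingAttention. A hedged standalone sketch of that pattern with a toy module (names here are illustrative, not from the repo):

    # Hedged sketch of the module-swap pattern used by ContextParallel.
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = nn.Linear(4, 4)

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = Block()

    model = Toy()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            parent_name, child_name = name.rsplit('.', 1)
            parent = model.get_submodule(parent_name)
            setattr(parent, child_name, nn.Identity())  # stand-in for RingAttention(module)
    print(model)  # block.attn is now an Identity
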
class RingAttention(nn.Module):
    def __init__(self, original_mha):
        super().__init__()

        self.hidden_size = original_mha.hidden_size
        self.num_heads = original_mha.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = original_mha.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = original_mha.is_causal

        # Copy the weights from the original Attention
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.q_proj.weight.data.copy_(original_mha.q_proj.weight.data)
        self.k_proj.weight.data.copy_(original_mha.k_proj.weight.data)
        self.v_proj.weight.data.copy_(original_mha.v_proj.weight.data)
        self.o_proj.weight.data.copy_(original_mha.o_proj.weight.data)

        self.rotary = original_mha.rotary

    def forward(self, input_ids, position_ids):
        batch_size, seq_len, _ = input_ids.shape

        q = self.q_proj(input_ids)
        k = self.k_proj(input_ids)
        v = self.v_proj(input_ids)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.rotary is not None:
            cos, sin = self.rotary(v, position_ids)
            q, k = self.rotary.apply_rotary_pos_emb(q, k, cos, sin)

        k = self._repeat_kv(k, self.num_key_value_groups)
        v = self._repeat_kv(v, self.num_key_value_groups)

        # Apply ring attention
        sm_scale = 1.0 / (q.size(-1) ** 0.5)
        output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(output)
        return output

    def _repeat_kv(self, x, n_rep):
        batch, num_key_value_heads, seq_len, head_dim = x.shape
        if n_rep == 1:
            return x
        x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
        return x.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

class RingAttentionFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, is_causal):
        comm = ContextComms("comm")
        # NOTE: Find a better way to save these tensors without cloning
        k_og = k.clone()
        v_og = v.clone()
        out, lse = None, None
        next_k, next_v = None, None

        for step in range(comm.world_size):
            if step + 1 != comm.world_size:
                next_k = comm.send_recv(k)
                next_v = comm.send_recv(v)
                comm.commit()

            if not is_causal or step <= comm.rank:
                block_out, block_lse = ring_attention_forward(
                    q, k, v, sm_scale, is_causal and step == 0
                )
                out, lse = update_out_and_lse(out, lse, block_out, block_lse)

            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k
                v = next_v

        out = out.to(q.dtype)
        ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1))
        ctx.sm_scale = sm_scale
        ctx.is_causal = is_causal
        return out

    @staticmethod
    def backward(ctx, dout, *args):

        q, k, v, out, softmax_lse = ctx.saved_tensors
        sm_scale = ctx.sm_scale
        is_causal = ctx.is_causal

        kv_comm = ContextComms("kv_comm")
        d_kv_comm = ContextComms("d_kv_comm")
        dq, dk, dv = None, None, None
        next_dk, next_dv = None, None

        block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
        block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

        next_dk, next_dv = None, None
        next_k, next_v = None, None

        for step in range(kv_comm.world_size):
            if step + 1 != kv_comm.world_size:
                next_k = kv_comm.send_recv(k)
                next_v = kv_comm.send_recv(v)
                kv_comm.commit()

            if step <= kv_comm.rank or not is_causal:
                bwd_causal = is_causal and step == 0

                block_dq_buffer, block_dk_buffer, block_dv_buffer = ring_attention_backward(
                    dout, q, k, v, out, softmax_lse, sm_scale, bwd_causal
                )

                if dq is None:
                    dq = block_dq_buffer.to(torch.float32)
                    dk = block_dk_buffer.to(torch.float32)
                    dv = block_dv_buffer.to(torch.float32)
                else:
                    dq += block_dq_buffer
                    d_kv_comm.wait()
                    dk = block_dk_buffer + next_dk
                    dv = block_dv_buffer + next_dv
            elif step != 0:
                d_kv_comm.wait()
                dk = next_dk
                dv = next_dv

            if step + 1 != kv_comm.world_size:
                kv_comm.wait()
                k = next_k
                v = next_v

            next_dk = d_kv_comm.send_recv(dk)
            next_dv = d_kv_comm.send_recv(dv)
            d_kv_comm.commit()

        d_kv_comm.wait()

        return dq, next_dk, next_dv, None, None

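For intuition (not part of the commit): since every rank sends its KV shard to cp_send_rank and receives from cp_recv_rank on each step, at step s a rank holds the KV block that originated on rank (rank - s) mod world_size; with is_causal, the `step <= rank` guard skips blocks that lie entirely in that rank's future.

    # Hedged illustration of the ring schedule, assuming a CP world size of 4.
    world_size = 4
    for rank in range(world_size):
        kv_origin = [(rank - step) % world_size for step in range(world_size)]
        causal_steps = [step for step in range(world_size) if step <= rank]
        print(f"rank {rank}: KV block origin per step {kv_origin}, steps computed when causal {causal_steps}")
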
def ring_attention_forward(q, k, v, sm_scale, is_causal):
    batch_size, nheads, seqlen, d = q.shape
    S = torch.matmul(q, k.transpose(-2, -1)) * sm_scale

    if is_causal:
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=q.device, dtype=torch.bool), diagonal=1)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, nheads, seqlen, seqlen)
        S.masked_fill_(causal_mask, float('-inf'))

    # Online softmax
    S_max = torch.max(S, dim=-1, keepdim=True)[0]
    exp_S = torch.exp(S - S_max)
    exp_sum = torch.sum(exp_S, dim=-1, keepdim=True)
    log_sum_exp = torch.log(exp_sum) + S_max
    P = exp_S / exp_sum
    O = torch.matmul(P, v)
    return O, log_sum_exp.squeeze(-1)

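As a hedged, single-process sanity check (not part of the commit), one local block through ring_attention_forward should match a reference softmax attention:

    # Hedged check: one local block vs. torch's reference attention.
    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 16, 8)                   # (batch, heads, seq, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    sm_scale = 1.0 / (q.size(-1) ** 0.5)

    out, lse = ring_attention_forward(q, k, v, sm_scale, is_causal=True)
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(out, ref, atol=1e-5)
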
def ring_attention_backward(dO, Q, K, V, O, softmax_lse, sm_scale, is_causal):
    batch_size, nheads, seqlen, d = Q.shape

    # Recreate S and P from log_sum_exp
    S = torch.matmul(Q, K.transpose(-2, -1)) * sm_scale
    if is_causal:
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=Q.device, dtype=torch.bool), diagonal=1)
        S = S.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float('-inf'))

    P = torch.exp(S - softmax_lse.unsqueeze(-1))
    # Step 1: Compute dV
    dV = torch.matmul(P.transpose(-2, -1), dO)
    # Step 2: Compute dP
    dP = torch.matmul(dO, V.transpose(-2, -1))
    # Step 3: Compute D
    D = torch.sum(dO * O, dim=-1, keepdim=True)
    # Step 4: Compute dS
    dS = P * (dP - D)
    # Apply causal mask to dS if is_causal is True
    if is_causal:
        dS = dS.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), 0)
    # Step 5: Compute dQ
    dQ = torch.matmul(dS, K) * sm_scale
    # Step 6: Compute dK
    dK = torch.matmul(dS.transpose(-2, -1), Q) * sm_scale
    return dQ, dK, dV

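A hedged gradient check for the manual backward above (not part of the commit), comparing against autograd on a single non-causal block in float64:

    # Hedged check: manual softmax-attention gradients vs. autograd.
    import torch

    q = torch.randn(1, 2, 8, 4, dtype=torch.float64, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    sm_scale = 1.0 / (q.size(-1) ** 0.5)

    out, lse = ring_attention_forward(q, k, v, sm_scale, is_causal=False)
    dout = torch.randn_like(out)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out, (q, k, v), dout)

    dq, dk, dv = ring_attention_backward(dout, q, k, v, out, lse, sm_scale, is_causal=False)
    assert torch.allclose(dq, dq_ref) and torch.allclose(dk, dk_ref) and torch.allclose(dv, dv_ref)
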
def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_: Optional[Any] = None
) -> Tuple[torch.Tensor, torch.Tensor]:

    def _update(current_out, current_lse):
        # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
        # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
        # For additional context and discussion, please refer to:
        # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
        current_out = current_out - F.sigmoid(block_lse - current_lse) * (current_out - block_out)
        current_lse = current_lse - F.logsigmoid(current_lse - block_lse)
        return current_out, current_lse

    block_out = block_out.to(torch.float32)
    block_lse = block_lse.unsqueeze(dim=-1)

    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        return block_out, block_lse

    if slice_ is not None:
        out[slice_], lse[slice_] = _update(out[slice_], lse[slice_])
    else:
        out, lse = _update(out, lse)

    return out, lse

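The sigmoid/logsigmoid form in _update is an algebraic rewrite of the commented-out naive merge; a hedged numeric check (not part of the commit):

    # Hedged check: the stable merge equals the naive log-sum-exp merge of two blocks.
    import torch
    import torch.nn.functional as F

    out, lse = torch.randn(8, 1).double(), torch.randn(8, 1).double()
    block_out, block_lse = torch.randn(8, 1).double(), torch.randn(8, 1).double()

    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))                       # naive merge
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

    merged_out = out - torch.sigmoid(block_lse - lse) * (out - block_out)         # stable form
    merged_lse = lse - F.logsigmoid(lse - block_lse)

    assert torch.allclose(new_out, merged_out) and torch.allclose(new_lse, merged_lse)
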
train.py · 37 changed lines
@@ -11,7 +11,7 @@ import argparse

import distributed.process_group_manager as pgm
from distributed.distributed_primtives import all_reduce_gradients_across_dp_cp_ranks
from utils import set_all_seed, print, display_4D_parallelism_grid
from utils import set_all_seed, print
from distributed.process_group_manager import setup_process_group_manager
from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from parallel.data_parallel import DataParallel
@@ -26,6 +26,8 @@ class MicroBatchDataLoader(DataLoader):
        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size

        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)
        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
@@ -42,8 +44,21 @@ class MicroBatchDataLoader(DataLoader):
    def collate_batch(self, batch_data):
        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
        batch_size, seq_len = batch_input_ids.shape
        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
        start_idx = pgm.process_group_manager.cp_rank * self.seq_length_per_gpu
        end_idx = start_idx + self.seq_length_per_gpu
        input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
        target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
        position_index = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
        local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
        attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "position_index": position_index,
            "attn_mask": attn_mask,
            "hidden_states": None
        }

def train_step(model, data_loader, device):
    total_loss = 0.0

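To illustrate the collate_batch change above (not part of the commit), each context-parallel rank receives a contiguous slice of the sequence plus the corresponding next-token targets; assuming seq_length = 8 and cp_world_size = 2:

    # Hedged illustration of the per-rank sequence sharding.
    seq_length, cp_world_size = 8, 2
    seq_length_per_gpu = seq_length // cp_world_size            # 4 tokens per rank
    for cp_rank in range(cp_world_size):
        start_idx = cp_rank * seq_length_per_gpu
        end_idx = start_idx + seq_length_per_gpu
        print(f"cp_rank {cp_rank}: input_ids[:, {start_idx}:{end_idx}], "
              f"target_ids[:, {start_idx + 1}:{end_idx + 1}]")
    # cp_rank 0: input_ids[:, 0:4], target_ids[:, 1:5]
    # cp_rank 1: input_ids[:, 4:8], target_ids[:, 5:9]
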
@@ -55,11 +70,11 @@ def train_step(model, data_loader, device):
        position_ids = batch["position_index"].to(device)
        target_ids = batch["target_ids"].to(device)

        outputs = model(input_ids=input_ids, position_ids=position_ids)
        logits = outputs.logits
        batch_size, seq_len = input_ids.shape

        # Use your suggested cross_entropy calculation
        loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')
        outputs = model(input_ids=input_ids, position_ids=position_ids)

        loss = F.cross_entropy(outputs.view(batch_size * seq_len, -1), target_ids.view(-1), reduction="mean")

        loss.backward()

@@ -91,6 +106,8 @@ if __name__ == "__main__":

    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42

    assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"

    backend = "gloo" if args.use_cpu else "nccl"

    if backend == "nccl":
@@ -140,6 +157,9 @@ if __name__ == "__main__":

    model.load_state_dict(torch.load("smollm.pth"))

    # if pgm.process_group_manager.tp_world_size > 1:
    #     model = TensorParallel(model, config).to(device)

    if pgm.process_group_manager.cp_size > 1:
        model = ContextParallel(model, config).to(device)

@@ -149,13 +169,10 @@ if __name__ == "__main__":
    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallel(model, config).to(device)

    # if pgm.process_group_manager.tp_world_size > 1:
    #     model = TensorParallel(model, config).to(device)

    model.train()

    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    trained_tokens, step = 0, 0
