enhance parallel context to handle 3D
parent c36d415b47
commit bce75fd508
@@ -9,18 +9,18 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
     global STEP
     global VERBOSE
     if operation == 'recv_forward':
-        if pc.parallel_context.is_pipeline_first_stage: return None
+        if pc.parallel_context.pp_is_first_stage: return None
         tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
         src = pc.parallel_context.pp_prev_rank
     elif operation == 'send_forward':
-        if pc.parallel_context.is_pipeline_last_stage: return
+        if pc.parallel_context.pp_is_last_stage: return
         dest = pc.parallel_context.pp_next_rank
     elif operation == 'recv_backward':
-        if pc.parallel_context.is_pipeline_last_stage: return None
+        if pc.parallel_context.pp_is_last_stage: return None
         tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
         src = pc.parallel_context.pp_next_rank
     elif operation == 'send_backward':
-        if pc.parallel_context.is_pipeline_first_stage: return
+        if pc.parallel_context.pp_is_first_stage: return
         dest = pc.parallel_context.pp_prev_rank
     is_send = operation.startswith('send')
     peer_rank = dest if is_send else src
@@ -35,7 +35,7 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device):
    global STEP
    global VERBOSE
    is_fwd = (operation == 'send_fwd_recv_bwd')
-    if (is_fwd and pc.parallel_context.is_pipeline_last_stage) or (not is_fwd and pc.parallel_context.is_pipeline_first_stage): return None
+    if (is_fwd and pc.parallel_context.pp_is_last_stage) or (not is_fwd and pc.parallel_context.pp_is_first_stage): return None
    peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
generate.py
@@ -1,4 +1,4 @@
-#VERBOSE=0 torchrun --nproc_per_node 3 generate.py
+#VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3
 import os
 import argparse
 import torch, torch.distributed as dist
@@ -20,12 +20,12 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:

     # Preallocate memory for output logits.
     logits = None
-    if pc.parallel_context.is_pipeline_last_stage:
+    if pc.parallel_context.pp_is_last_stage:
         logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device)

     recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32)

-    batch["hidden_states"] = None if pc.parallel_context.is_pipeline_first_stage else recv_buffer
+    batch["hidden_states"] = None if pc.parallel_context.pp_is_first_stage else recv_buffer

     output_tensor = model.forward(batch, device)

@@ -33,7 +33,7 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:
     communicate(operation="send_forward", tensor=output_tensor)

     # Copy logits.
-    if pc.parallel_context.is_pipeline_last_stage:
+    if pc.parallel_context.pp_is_last_stage:
         logits = output_tensor

     dist.barrier()
@@ -42,16 +42,16 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--pp_size", type=int, default=1)
     parser.add_argument("--max_tokens", type=int, default=32)
     args = parser.parse_args()

-    #TODO: support only PP
     local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(local_rank)
     device = torch.device("cuda", local_rank)
-    setup_parallel_context(local_rank, world_size)
+    setup_parallel_context(tp_size=1, pp_size=args.pp_size, dp_size=1)
     set_all_seed(seed=42)
     model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)

@@ -60,8 +60,8 @@ if __name__ == "__main__":
     # Tokenize the input
     prompts = [
         "My name is",
-        "How old are you ?",
-        "What is your favorite color?",
+        # "How old are you ?",
+        # "What is your favorite color?",
     ]

     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct")
@@ -88,7 +88,7 @@ if __name__ == "__main__":
         logits = run_one_inference_step(model, batch_prompts, device)

         # Sample new token
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
             assert logits is not None
             next_token = torch.argmax(logits[:, -1], dim=-1)
             tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
@@ -101,7 +101,7 @@ if __name__ == "__main__":
         dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank)

     # Get only the new generated tokens
-    if pc.parallel_context.is_pipeline_last_stage:
+    if pc.parallel_context.pp_is_last_stage:
         for i, prompt in enumerate(prompts):
             tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
             outputs = tokenizer.decode(tokenized_outputs)
@@ -1,16 +1,113 @@
+import os
+import torch
 import torch.distributed as dist

 class ParallelContext:
-    def __init__(self, pp_rank, pp_world_size):
-        self.pp_rank, self.pp_world_size = pp_rank, pp_world_size
-        self.pp_group = dist.new_group(list(range(self.pp_world_size)))
-        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size
-        self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size
-        self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1
-        #TODO: refactor to handle TP and DP
-        self.pp_last_rank = self.pp_world_size - 1
-        self.is_pipeline_first_stage = self.pp_rank == 0
+    def __init__(self, tp_size, pp_size, dp_size):
+        self.global_rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size))
+
+        self.tp_size = tp_size
+        self.pp_size = pp_size
+        self.dp_size = dp_size
+        assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
+
+        self.grid = torch.arange(self.world_size).view(self.pp_size, self.dp_size, self.tp_size).permute(2, 0, 1)
+        # Find the position of the current process in the grid
+        self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
+
+        # Process group creation
+        self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
+        self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
+        self.dp_group_ids = self.grid[self.tp_rank, self.pp_rank, :].tolist()
+        self.tp_pp_group_ids = self.grid[..., self.dp_rank].tolist()
+
+        self.tp_group = dist.new_group(self.tp_group_ids)
+        self.pp_group = dist.new_group(self.pp_group_ids)
+        self.dp_group = dist.new_group(self.dp_group_ids)
+        self.tp_pp_group = dist.new_subgroups_by_enumeration(self.tp_pp_group_ids)[0]
+
+        # Tensor parallelism
+        self.tp_first_rank = self.tp_group_ids[0]
+        self.tp_last_rank = self.tp_group_ids[-1]
+        self.tp_is_first_stage = self.tp_rank == 0
+        self.tp_is_last_stage = self.tp_rank == self.tp_size - 1
+        self.tp_world_size = dist.get_world_size(group=self.tp_group)
+
+        # Pipeline parallelism
+        self.pp_first_rank = self.pp_group_ids[0]
+        self.pp_last_rank = self.pp_group_ids[-1]
+        self.pp_is_first_stage = self.pp_rank == 0
+        self.pp_is_last_stage = self.pp_rank == self.pp_size - 1
+        self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.tp_rank, self.pp_rank + 1, self.dp_rank].item())
+        self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.tp_rank, self.pp_rank - 1, self.dp_rank].item())
+        self.pp_world_size = dist.get_world_size(group=self.pp_group)
+
+        # Data parallelism
+        self.dp_first_rank = self.dp_group_ids[0]
+        self.dp_last_rank = self.dp_group_ids[-1]
+        self.dp_is_first_stage = self.dp_rank == 0
+        self.dp_is_last_stage = self.dp_rank == self.dp_size - 1
+        self.dp_world_size = dist.get_world_size(group=self.dp_group)
+
+        # Tensor parallelism and pipeline parallelism
+        self.tp_pp_world_size = dist.get_world_size(group=self.tp_pp_group)
+
+    def __str__(self):
+        return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"
+
+    def display_parallelism_grid(self):
+        def _create_box(content):
+            return f" {content:^3} "
+
+        def _create_row(row):
+            return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
+
+        def _create_border(width):
+            return "+" + "-" * (width - 2) + "+"
+
+        def _create_pp_line(width, pp_size):
+            box_width = (width - pp_size + 1) // pp_size
+            return " ".join("PP".center(box_width) for _ in range(pp_size))
+
+        output = []
+        sample_row = _create_row(self.grid[0, :, 0])
+        row_width = len(sample_row)
+        border = _create_border(row_width)
+
+        output.append(f"=== Global Parallelism Configuration ===")
+        output.append(f"DP Size: {self.dp_size}, PP Size: {self.pp_size}, TP Size: {self.grid.shape[0]}")
+        output.append("")  # Top spacing
+
+        for dp in range(self.dp_size):
+            output.append(f"DP {dp}:")
+            output.append(f"{'':>8}{border}")
+
+            for tp in range(self.grid.shape[0]):
+                if tp == 0:
+                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
+                else:
+                    output.append(f"{'':8}{border}")
+                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")

+            output.append(f"{'':8}{border}")
+            if self.pp_size > 1:
+                output.append(f"{'':>7}{_create_pp_line(row_width, self.pp_size)}")
+
+            output.append("")  # Spacing between DP blocks
+
+        output.append("")  # Bottom spacing
+
+        output.append(f"=== Local Parallelism Configuration ===")
+        output.append(self.__str__())
+        output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in self.dp_group_ids]}")
+        output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in self.pp_group_ids]}")
+        output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in self.tp_group_ids]}")
+        output.append(f"TP-PP Group IDs: {[['g{:02d}'.format(id) for id in subgroup] for subgroup in self.tp_pp_group_ids]}")
+
+        print("\n".join(output))

-def setup_parallel_context(local_rank, world_size):
+def setup_parallel_context(tp_size, pp_size, dp_size):
     global parallel_context
-    parallel_context = ParallelContext(pp_rank=local_rank, pp_world_size=world_size)
+    parallel_context = ParallelContext(tp_size, pp_size, dp_size)
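Note: the following is a minimal standalone sketch (not part of this commit) of how the new rank grid maps a global rank to its (tp, pp, dp) coordinates and process-group peers. The sizes tp=2, pp=2, dp=2 over 8 ranks are illustrative values chosen here, not defaults from the repo; the grid construction mirrors the ParallelContext.__init__ code above and runs without torch.distributed.

import torch

tp_size, pp_size, dp_size = 2, 2, 2
world_size = tp_size * pp_size * dp_size

# Same construction as ParallelContext.__init__: ranks are laid out as
# (pp, dp, tp) and permuted so the grid is indexed as grid[tp, pp, dp].
grid = torch.arange(world_size).view(pp_size, dp_size, tp_size).permute(2, 0, 1)

for global_rank in range(world_size):
    # Locate this rank in the grid, as the commit does with nonzero().
    tp_rank, pp_rank, dp_rank = (grid == global_rank).nonzero().flatten().tolist()
    # Each process group is a slice of the grid along one axis.
    tp_group = grid[:, pp_rank, dp_rank].tolist()
    pp_group = grid[tp_rank, :, dp_rank].tolist()
    dp_group = grid[tp_rank, pp_rank, :].tolist()
    print(f"rank {global_rank}: tp={tp_rank} pp={pp_rank} dp={dp_rank} "
          f"tp_group={tp_group} pp_group={pp_group} dp_group={dp_group}")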
@@ -9,10 +9,10 @@ class PipelineParallel(nn.Module):
         self.config = AutoConfig.from_pretrained(model_name)
         base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config)
         layer_distribution = self.distribute_layers(self.config.num_hidden_layers)
-        self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.is_pipeline_first_stage else nn.Identity()
+        self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity()
         self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution})
-        self.norm = base_model.model.norm if pc.parallel_context.is_pipeline_last_stage else nn.Identity()
-        self.lm_head = base_model.lm_head if pc.parallel_context.is_pipeline_last_stage else nn.Identity()
+        self.norm = base_model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity()
+        self.lm_head = base_model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity()
         del base_model

     def distribute_layers(self, num_layers):
@@ -44,7 +44,7 @@ def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
         communicate(operation='send_forward', tensor=output_tensor)
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             logging_loss += output_tensor.item()
         input_tensors.append(input_tensor)
@@ -67,7 +67,7 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
         batch = next(iter(data_loader))
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             nonlocal logging_loss
             logging_loss += output_tensor.item()
train.py
@@ -1,10 +1,11 @@
-#VERBOSE=0 torchrun --nproc_per_node 3 train.py
+#VERBOSE=0 torchrun --nproc_per_node 3 train.py --pp_size 3
 import os
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, DistributedSampler
 from datasets import load_dataset
 from transformers import AutoTokenizer
+import argparse

 import parallel_context as pc
 from utils import set_all_seed
@@ -30,14 +31,26 @@ class MicroBatchDataLoader(DataLoader):
         return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}

 if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--tp_size", type=int, default=1)
+    parser.add_argument("--pp_size", type=int, default=1)
+    parser.add_argument("--dp_size", type=int, default=1)
+
+    args = parser.parse_args()
+
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])

     SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800

     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(local_rank)
     device = torch.device("cuda", local_rank)
-    setup_parallel_context(local_rank, world_size)
+    setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
+
+    if pc.parallel_context.global_rank == 0:
+        pc.parallel_context.display_parallelism_grid()
+
     set_all_seed(seed=42)
     model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@@ -52,5 +65,5 @@ if __name__ == "__main__":
         optimizer.step()
         trained_tokens += tokens_per_step
         step += 1
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
             print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
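Usage note: with the new flags, the launch must satisfy the assertion world_size == tp_size * pp_size * dp_size in ParallelContext. The 3-GPU commands below come from the updated file comments; the 8-GPU line is an illustrative 2x2x2 configuration added here for clarity, not one taken from the commit.

#VERBOSE=0 torchrun --nproc_per_node 3 train.py --pp_size 3
#VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3
#VERBOSE=0 torchrun --nproc_per_node 8 train.py --tp_size 2 --pp_size 2 --dp_size 2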