picotron/parallel_context.py

import torch.distributed as dist


class ParallelContext:
    """Holds pipeline-parallel (PP) topology info for the current process."""

    def __init__(self, pp_rank, pp_world_size):
        self.pp_rank, self.pp_world_size = pp_rank, pp_world_size
        # Process group spanning all pipeline stages.
        self.pp_group = dist.new_group(list(range(self.pp_world_size)))
        # Neighbouring stages; None at the pipeline boundaries.
        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size
        self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size
        self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1
        # TODO: refactor to handle TP and DP
        self.pp_last_rank = self.pp_world_size - 1
        self.is_pipeline_first_stage = self.pp_rank == 0


def setup_parallel_context(local_rank, world_size):
    """Create the module-level ParallelContext used by the rest of the code."""
    global parallel_context
    parallel_context = ParallelContext(pp_rank=local_rank, pp_world_size=world_size)
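# --- Illustrative usage sketch (not part of the original file) ---
# A minimal single-node example, assuming launch via torchrun (which sets the
# RANK / WORLD_SIZE / LOCAL_RANK / MASTER_ADDR / MASTER_PORT environment
# variables); the "gloo" backend is a placeholder choice.
if __name__ == "__main__":
    import os

    # The default process group must exist before dist.new_group() is called.
    dist.init_process_group(backend="gloo")
    # On a single node, LOCAL_RANK equals the global rank, matching the
    # signature above.
    setup_parallel_context(
        local_rank=int(os.environ["LOCAL_RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )
    print(
        f"rank={parallel_context.pp_rank} "
        f"prev={parallel_context.pp_prev_rank} next={parallel_context.pp_next_rank}"
    )
    dist.destroy_process_group()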