```python
import torch.distributed as dist


class ParallelContext:
    def __init__(self, pp_rank, pp_world_size):
        self.pp_rank, self.pp_world_size = pp_rank, pp_world_size
        self.pp_group = dist.new_group(list(range(self.pp_world_size)))
        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size
        self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size
        self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1
        # TODO: refactor to handle TP and DP
        self.pp_last_rank = self.pp_world_size - 1
        self.is_pipeline_first_stage = self.pp_rank == 0


def setup_parallel_context(local_rank, world_size):
    global parallel_context
    parallel_context = ParallelContext(pp_rank=local_rank, pp_world_size=world_size)
```
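Since `dist.new_group()` requires the default process group to already exist, `setup_parallel_context` has to be called after `dist.init_process_group()` on every rank. Below is a minimal usage sketch, not part of the original source: it assumes single-node execution under `torchrun` (so `LOCAL_RANK` equals the global rank) and uses the Gloo backend so it runs on CPU-only machines; the `main` function and environment-variable reads are illustrative.

```python
# Hypothetical launch script (assumption, not from the original source).
# Run with, e.g.:  torchrun --nproc_per_node=4 launch.py
import os

import torch.distributed as dist


def main():
    # torchrun sets LOCAL_RANK / WORLD_SIZE (and RANK, which env:// init reads).
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Default init_method is env://; "gloo" works without GPUs, "nccl" is typical on GPUs.
    dist.init_process_group(backend="gloo")

    # Build the global ParallelContext; every rank must call this so that
    # dist.new_group() is entered collectively.
    setup_parallel_context(local_rank, world_size)

    print(f"rank {parallel_context.pp_rank}/{parallel_context.pp_world_size}: "
          f"prev={parallel_context.pp_prev_rank}, next={parallel_context.pp_next_rank}, "
          f"first={parallel_context.is_pipeline_first_stage}, "
          f"last={parallel_context.is_pipeline_last_stage}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

With four processes, rank 0 reports `prev=None, next=1` and rank 3 reports `prev=2, next=None`, matching the boundary handling of `pp_prev_rank` and `pp_next_rank` above.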