picotron/data_parallel.py

import torch.distributed as dist
import torch.nn as nn

import process_group_manager as pgm


class DataParallel(nn.Module):
    """Naive data parallelism: each rank holds a full replica of the model and
    gradients are averaged across the data-parallel group after backward."""

    def __init__(self, model, config):
        # TODO: Add ZeRO-1
        # TODO: Interleave all_reduce with the backward pass
        super().__init__()
        self.model = model
        self.dp_world_size = pgm.process_group_manager.dp_world_size
        self.dp_rank = pgm.process_group_manager.dp_rank

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        # Delegate to the wrapped model's custom backward.
        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)

    def all_reduce_gradients(self):
        # Average each parameter's gradient across all data-parallel ranks so
        # every replica applies the same optimizer update.
        for param in self.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
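

# Usage sketch (not part of the original file; the surrounding training-loop
# names such as loss_fn, optimizer, input_ids, and targets are assumptions).
# It assumes torch.distributed and the process group manager have already been
# initialized elsewhere in the project:
#
#     model = DataParallel(model, config)   # wrap the local replica
#     output = model(input_ids)             # forward on this rank's shard of the batch
#     loss = loss_fn(output, targets)
#     loss.backward()                       # computes local gradients only
#     model.all_reduce_gradients()          # average gradients across dp ranks
#     optimizer.step()                      # identical update on every replica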