picotron/parallel/data_parallel.py
2024-10-15 13:06:17 +00:00

11 lines
327 B
Python

import torch.distributed as dist
import torch.nn as nn
import distributed.process_group_manager as pgm
from parallel.base_parallel import BaseParallel
class DataParallel(BaseParallel):
    """Data-parallel wrapper around a model.

    Currently a thin shell over ``BaseParallel``: construction simply
    forwards the model and config to the base class. Gradient-sync
    features are still pending (see TODOs below).
    """

    def __init__(self, model, config):
        # TODO(review): implement ZeRO stage 1 optimizer-state sharding.
        # TODO(review): overlap/interleave all_reduce with backward pass.
        super().__init__(model, config)