2024-09-25 20:36:22 +08:00
|
|
|
import torch.distributed as dist
|
|
|
|
|
import torch.nn as nn
|
2024-10-10 23:12:14 +08:00
|
|
|
import distributed.process_group_manager as pgm
|
2024-09-25 20:36:22 +08:00
|
|
|
|
2024-10-15 21:06:17 +08:00
|
|
|
from parallel.base_parallel import BaseParallel
|
|
|
|
|
|
|
|
|
|
class DataParallel(BaseParallel):
|
2024-09-25 21:17:05 +08:00
|
|
|
def __init__(self, model, config):
|
2024-09-25 20:36:22 +08:00
|
|
|
# TODO: Add ZeRO-1 (ZeRO stage 1) support
|
|
|
|
|
# TODO: Interleave all_reduce (presumably with the backward pass, to overlap communication with compute -- confirm)
|
2024-10-15 21:06:17 +08:00
|
|
|
super().__init__(model, config)
|