# coding=utf-8
import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp


class DataParallel(object):
    """Skeleton of a data-parallel training wrapper.

    `rank` identifies this process among `world_size` processes in the
    group, and `dp_num` is the data-parallel degree.
    """

    def __init__(self, dp_num, world_size: int, rank: int, module=None):
        self.world_size = world_size
        self.rank = rank
        self.dp_num = dp_num
        self.module = module  # keep a handle to the wrapped model

    def forward_pipeline(self):
        # Placeholder: run the forward pass for this replica.
        pass

    def backward_pipeline(self):
        # Placeholder: run the backward pass and accumulate local gradients.
        pass

    def train(self):
        # Placeholder: drive the training loop (forward, backward, sync, step).
        pass

    def sync_grad(self, param):
        # Placeholder: synchronize `param.grad` across data-parallel ranks.
        pass
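

# A minimal sketch, assuming the default process group has already been
# initialized via torch.distributed (e.g. with the "nccl" or "gloo" backend),
# of how `sync_grad` could average gradients across data-parallel ranks with
# an all-reduce. This is illustrative only, not this module's implementation;
# the subclass name `AllReduceDataParallel` is hypothetical.
import torch.distributed as dist


class AllReduceDataParallel(DataParallel):
    def sync_grad(self, param):
        # Sum the local gradient over all ranks, then divide by the world
        # size so every replica holds the averaged gradient.
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= self.world_size


# Usage sketch: after a backward pass, average the gradients of every
# parameter of the wrapped model, e.g.
#     for p in dp.module.parameters():
#         dp.sync_grad(p)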