25 lines
463 B
Python
25 lines
463 B
Python
|
|
# coding=utf-8
|
||
|
|
import os
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.multiprocessing as mp
|
||
|
|
|
||
|
|
|
||
|
|
class DataParallel:
    """Skeleton of a data-parallel training wrapper.

    Only the constructor does real work: it records the parallel layout
    arguments and the wrapped module on the instance. Every other method is
    an unimplemented placeholder that currently does nothing and returns
    ``None``.

    Args:
        dp_num: Stored as ``self.dp_num``. Presumably the data-parallel
            degree (number of replicas) -- TODO confirm against callers.
        world_size: Stored as ``self.world_size``. Presumably the total
            number of participating processes -- TODO confirm.
        rank: Stored as ``self.rank``. NOTE(review): presumably expected to
            lie in ``range(world_size)``; it is not validated here.
        module: Stored as ``self.module``; ``None`` by default. Presumably
            an ``nn.Module`` to be replicated -- TODO confirm.
    """

    def __init__(self, dp_num, world_size: int, rank: int, module=None):
        self.world_size = world_size
        self.rank = rank
        self.dp_num = dp_num
        # Keep a reference to the wrapped module. Without this the argument
        # would be silently discarded and the pipeline methods would have
        # nothing to operate on.
        self.module = module

    def forward_pipeline(self):
        """Placeholder for the forward pipeline; currently a no-op."""
        # TODO: implement.

    def backward_pipeline(self):
        """Placeholder for the backward pipeline; currently a no-op."""
        # TODO: implement.

    def train(self):
        """Placeholder for the training loop; currently a no-op."""
        # TODO: implement.

    def sync_grad(self, param):
        """Placeholder for gradient synchronization; currently a no-op.

        Args:
            param: Presumably a parameter whose gradient should be
                synchronized across ranks -- TODO confirm. Currently unused.
        """
        # TODO: implement.
|