25 lines
463 B
Python
25 lines
463 B
Python
# coding=utf-8
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.multiprocessing as mp
|
|
|
|
|
|
class DataParallel(object):
    """Hold the data-parallel configuration for one worker process.

    At present this class only records its constructor arguments. The
    pipeline, training and gradient-synchronisation hooks are unimplemented
    placeholders: they do nothing and return ``None``.

    Args:
        dp_num: Stored as ``self.dp_num``. Presumably the data-parallel
            degree (number of model replicas) -- TODO confirm with callers.
        world_size: Stored as ``self.world_size``. Presumably the total
            number of participating processes.
        rank: Stored as ``self.rank``. Presumably this process's index in
            ``[0, world_size)`` -- not validated here.
        module: Optional model to be wrapped, stored as ``self.module``.
            Presumably a ``torch.nn.Module`` -- TODO confirm.
    """

    def __init__(self, dp_num, world_size: int, rank: int, module=None):
        self.world_size = world_size
        self.rank = rank
        self.dp_num = dp_num
        # Keep the wrapped model so the training hooks can reach it;
        # accepting the argument without storing it would silently lose it.
        self.module = module

    def forward_pipeline(self):
        """Placeholder for the forward pass. Not implemented; returns None."""
        # TODO: implement. Deliberately a no-op for now so that existing
        # callers invoking it keep working.
        return None

    def backward_pipeline(self):
        """Placeholder for the backward pass. Not implemented; returns None."""
        # TODO: implement. Deliberately a no-op for now.
        return None

    def train(self):
        """Placeholder for the training loop. Not implemented; returns None."""
        # TODO: implement. Deliberately a no-op for now.
        return None

    def sync_grad(self, param):
        """Placeholder for gradient synchronisation. Not implemented.

        Args:
            param: Currently ignored. Presumably a parameter/tensor whose
                gradient should be synchronised across ranks -- TODO confirm.

        Returns:
            None.
        """
        # TODO: implement (e.g. a collective reduction across ranks).
        return None
|