torch_ext/training/dp.py
# coding=utf-8
import os

import torch
import torch.nn as nn
import torch.multiprocessing as mp


class DataParallel(object):
    """Skeleton of a data-parallel training wrapper.

    The pipeline, training, and gradient-synchronization methods below are
    placeholders to be filled in.
    """

    def __init__(self, dp_num: int, world_size: int, rank: int, module=None):
        self.world_size = world_size
        self.rank = rank
        self.dp_num = dp_num
        # Keep a handle to the wrapped module so later methods can use it.
        self.module = module

    def forward_pipeline(self):
        pass

    def backward_pipeline(self):
        pass

    def train(self):
        pass

    def sync_grad(self, param):
        pass
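
The gradient-synchronization hook above is left unimplemented. As an illustration only, here is a minimal sketch of what sync_grad could look like, averaging a parameter's gradient across the data-parallel group with torch.distributed.all_reduce. The dp_group handle and the assumption that the process group has already been initialized (e.g. via init_process_group) are not part of this file.

# Hypothetical sketch, not part of dp.py.
import torch
import torch.distributed as dist


def sync_grad(param: torch.nn.Parameter, dp_group=None) -> None:
    """Average ``param.grad`` across all ranks of the data-parallel group."""
    if param.grad is None:
        return
    # Sum the gradients from every rank in the group, then divide by the
    # group size to obtain the average.
    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
    param.grad /= dist.get_world_size(group=dp_group)

In a wrapper like the class above, such a routine would typically be called for each parameter after the backward pass and before the optimizer step.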