picotron/parallel/base_parallel.py
2024-10-15 13:06:17 +00:00

19 lines
525 B
Python

import torch.nn as nn
class BaseParallel(nn.Module):
    """Thin ``nn.Module`` wrapper that delegates to an inner model.

    Attribute lookups that ``nn.Module`` cannot resolve on the wrapper itself
    are forwarded to the wrapped model. This lets the wrapper be used largely
    in place of the model it wraps.

    Args:
        model: The object being wrapped, stored as ``self.model``.
        config: Stored unchanged on ``self.config``. NOTE(review): presumably
            consumed by parallelism subclasses; confirm against callers.
    """

    def __init__(self, model, config):
        super().__init__()
        # If ``model`` is an nn.Module, nn.Module.__setattr__ registers it as
        # a submodule (kept in ``_modules``, not in the instance __dict__).
        self.model = model
        self.config = config

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper first, then on the wrapped model.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped model has
                ``name``, or if ``model`` itself has not been set yet.
        """
        try:
            # nn.Module keeps parameters, buffers and submodules (including
            # ``self.model``) outside __dict__; its __getattr__ finds them.
            return super().__getattr__(name)
        except AttributeError:
            # Guard: if ``model`` is missing (accessed before __init__ assigns
            # it, or an instance created via __new__), ``self.model`` below
            # would re-enter __getattr__ forever and end in RecursionError
            # instead of a proper AttributeError.
            if name == "model":
                raise
            return getattr(self.model, name)

    def forward(self, *args, **kwargs):
        """Call the wrapped model with all arguments passed through unchanged."""
        return self.model(*args, **kwargs)

    def backward(self, *args, **kwargs):
        """Delegate to ``self.model.backward`` with arguments passed through.

        NOTE(review): ``torch.nn.Module`` does not define ``backward`` itself,
        so this only works if the wrapped model provides one; otherwise it
        raises AttributeError.
        """
        return self.model.backward(*args, **kwargs)