19 lines
525 B
Python
19 lines
525 B
Python
import torch.nn as nn
|
|
|
|
class BaseParallel(nn.Module):
|
|
def __init__(self, model, config):
|
|
super().__init__()
|
|
self.model = model
|
|
self.config = config
|
|
|
|
def __getattr__(self, name):
|
|
try:
|
|
return super().__getattr__(name)
|
|
except AttributeError:
|
|
return getattr(self.model, name)
|
|
|
|
def forward(self, *args, **kwargs):
|
|
return self.model(*args, **kwargs)
|
|
|
|
def backward(self, *args, **kwargs):
|
|
return self.model.backward(*args, **kwargs) |