picotron/parallel/context_parallel.py

15 lines
481 B
Python
Raw Normal View History

2024-09-25 20:36:22 +08:00
import torch.distributed as dist
import torch.nn as nn
2024-10-10 23:12:14 +08:00
import distributed.process_group_manager as pgm
2024-09-25 20:36:22 +08:00
2024-10-10 23:12:14 +08:00
class ContextParallel(nn.Module):
    """Pass-through wrapper around ``model``.

    The wrapper keeps a reference to ``model`` and forwards every
    ``forward`` and ``backward`` call to it unchanged; it adds no logic
    of its own.
    """

    def __init__(self, model, config):
        """Store ``model`` as the wrapped object.

        Args:
            model: Object to delegate to. It is called directly by
                ``forward`` and must expose a ``backward`` method for
                ``backward`` to work.
            config: Accepted for interface compatibility; this class does
                not read it.
        """
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its result."""
        wrapped = self.model
        return wrapped(*args, **kwargs)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        """Delegate to ``self.model.backward`` with the same three arguments.

        Returns:
            Whatever the wrapped model's ``backward`` returns.
        """
        delegate = self.model.backward
        return delegate(input_tensor, output_tensor, output_tensor_grad)