From 2b2781a374a3aa4c4a9bc1cb7e6655864240f825 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Fri, 18 Oct 2024 15:30:01 +0000
Subject: [PATCH] made Tensor Parallel API compliant

---
 src/parallel/context_parallel.py                |  4 +++-
 src/parallel/tensor_parallel/tensor_parallel.py | 11 ++++++++++-
 train.py                                        |  2 +-
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py
index 8ef42bb..a6f79a8 100644
--- a/src/parallel/context_parallel.py
+++ b/src/parallel/context_parallel.py
@@ -10,7 +10,9 @@ import src.distributed.process_group_manager as pgm
 
 class ContextParallel(nn.Module):
     def __init__(self, model, config):
-        super().__init__(model, config)
+        super().__init__()
+
+        self.model = model
 
         for name, module in model.named_modules():
             if isinstance(module, Attention) and not isinstance(module, RingAttention):
diff --git a/src/parallel/tensor_parallel/tensor_parallel.py b/src/parallel/tensor_parallel/tensor_parallel.py
index e95d83d..b59abaa 100644
--- a/src/parallel/tensor_parallel/tensor_parallel.py
+++ b/src/parallel/tensor_parallel/tensor_parallel.py
@@ -4,6 +4,10 @@ import torch.nn as nn
 
 class TensorParallel(nn.Module):
     def __init__(self, model, init_method = init.xavier_normal_):
+        super().__init__()
+
+        self.model = model
+
         module_linear_name_stype_mapping_list = [
             ("attention", "q_proj", "column"),
             ("attention", "k_proj", "column"),
@@ -47,4 +51,9 @@ class TensorParallel(nn.Module):
                 init_method=self.init_method
             )
             setattr(module, linear_proj_name, new_linear_layer)
-        
\ No newline at end of file
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
\ No newline at end of file
diff --git a/train.py b/train.py
index 24d15cd..2570f59 100644
--- a/train.py
+++ b/train.py
@@ -234,7 +234,7 @@ if __name__ == "__main__":
         )
 
     if pgm.process_group_manager.tp_world_size > 1:
-        TensorParallel(model)
+        model = TensorParallel(model)
 
     # if pgm.process_group_manager.cp_size > 1: #TODO: do at the very end when we have fix convergence issue
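
Note (illustration only, not part of the patch): the new __getattr__ override is what lets the TensorParallel wrapper stand in for the original model after model = TensorParallel(model) in train.py. nn.Module's own __getattr__ only searches _parameters, _buffers and _modules, so a plain attribute of the wrapped model would otherwise raise AttributeError on the wrapper; the override falls back to the wrapped model. A minimal self-contained sketch of this delegation pattern, using hypothetical names (DummyModel, Wrapper) rather than the repository's classes:

    import torch.nn as nn

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(8, 8)   # registered as a submodule
            self.vocab_size = 100           # plain attribute, not tracked by nn.Module

    class Wrapper(nn.Module):
        def __init__(self, model):
            super().__init__()      # must run before assigning any submodules
            self.model = model      # wrapped model becomes a submodule of the wrapper

        def __getattr__(self, name):
            # nn.Module.__getattr__ only looks in _parameters/_buffers/_modules,
            # so fall back to the wrapped model for anything else.
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.model, name)

    wrapped = Wrapper(DummyModel())
    print(wrapped.model.q_proj)   # resolved by nn.Module's own lookup
    print(wrapped.vocab_size)     # delegated to the wrapped model -> 100

The same pattern explains why ContextParallel now calls super().__init__() with no arguments and stores self.model: assigning the wrapped model as a submodule requires nn.Module's machinery to be initialized first.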