made Tensor Parallel API compliant

This commit is contained in:
ferdinand.mom 2024-10-18 15:30:01 +00:00
parent abd1edf9f9
commit 2b2781a374
3 changed files with 14 additions and 3 deletions

View File

@@ -10,7 +10,9 @@ import src.distributed.process_group_manager as pgm
class ContextParallel(nn.Module):
def __init__(self, model, config):
super().__init__(model, config)
super().__init__()
self.model = model
for name, module in model.named_modules():
if isinstance(module, Attention) and not isinstance(module, RingAttention):

View File

@@ -4,6 +4,10 @@ import torch.nn as nn
class TensorParallel(nn.Module):
def __init__(self, model, init_method = init.xavier_normal_):
super().__init__()
self.model = model
module_linear_name_stype_mapping_list = [
("attention", "q_proj", "column"),
("attention", "k_proj", "column"),
@@ -47,4 +51,9 @@ class TensorParallel(nn.Module):
init_method=self.init_method
)
setattr(module, linear_proj_name, new_linear_layer)
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.model, name)

View File

@@ -234,7 +234,7 @@ if __name__ == "__main__":
)
if pgm.process_group_manager.tp_world_size > 1:
TensorParallel(model)
model = TensorParallel(model)
# if pgm.process_group_manager.cp_size > 1:
#TODO: do at the very end when we have fix convergence issue