diff --git a/src/parallel/tensor_parallel/tensor_parallel.py b/src/parallel/tensor_parallel/tensor_parallel.py
index b59abaa..3e2c59b 100644
--- a/src/parallel/tensor_parallel/tensor_parallel.py
+++ b/src/parallel/tensor_parallel/tensor_parallel.py
@@ -2,12 +2,10 @@ from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParalle
 import torch.nn.init as init
 import torch.nn as nn
 
-class TensorParallel(nn.Module):
+class TensorParallel():
     def __init__(self, model, init_method = init.xavier_normal_):
         super().__init__()
 
-        self.model = model
-
         module_linear_name_stype_mapping_list = [
             ("attention", "q_proj", "column"),
             ("attention", "k_proj", "column"),
@@ -50,10 +48,4 @@ class TensorParallel(nn.Module):
                 embedding_dim=linear_layer.embedding_dim,
                 init_method=self.init_method
             )
-            setattr(module, linear_proj_name, new_linear_layer)
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.model, name)
\ No newline at end of file
+            setattr(module, linear_proj_name, new_linear_layer)
\ No newline at end of file
diff --git a/train.py b/train.py
index 2570f59..c40e115 100644
--- a/train.py
+++ b/train.py
@@ -234,7 +234,7 @@ if __name__ == "__main__":
     )
 
     if pgm.process_group_manager.tp_world_size > 1:
-        model = TensorParallel(model)
+        TensorParallel(model)
 
     # if pgm.process_group_manager.cp_size > 1: #TODO: do at the very end when we have fix convergence issue
 
@@ -249,7 +249,7 @@ if __name__ == "__main__":
     model.to(device)
     model.train()
 
-    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
+    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc=grad_acc, num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
 
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
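The key change in this patch is that `TensorParallel` stops acting as an `nn.Module` wrapper (no more `self.model` or `__getattr__` forwarding) and instead mutates the model passed to it, swapping its projection layers in place via `setattr`. That is why `train.py` now calls `TensorParallel(model)` without reassigning `model`. Below is a minimal sketch of that in-place replacement pattern, under stated assumptions: `TensorParallelSketch` and `ToyColumnParallelLinear` are hypothetical stand-ins rather than the repo's real `ColumnParallelLinear`/`RowParallelLinear`, and only the column-style case is shown.

```python
import torch.nn as nn

class ToyColumnParallelLinear(nn.Linear):
    """Hypothetical stand-in for a column-parallel layer: same shapes, no actual sharding."""
    pass

class TensorParallelSketch:
    """Plain class, not an nn.Module: it only edits `model` in place and keeps no self.model."""
    def __init__(self, model, proj_names=("q_proj", "k_proj", "v_proj")):
        for module in model.modules():
            for name in proj_names:
                child = getattr(module, name, None)
                if isinstance(child, nn.Linear):
                    # Swap the dense projection for its (toy) column-parallel counterpart.
                    sharded = ToyColumnParallelLinear(
                        child.in_features, child.out_features, bias=child.bias is not None
                    )
                    setattr(module, name, sharded)

# Usage mirrors the train.py change: call TensorParallelSketch(model) with no
# reassignment, since the original module tree is modified directly.
```

The real class additionally covers row-parallel projections and an embedding replacement, as the diff's `module_linear_name_stype_mapping_list` and the `embedding_dim` branch suggest; the sketch only illustrates why dropping the wrapper makes the `model = TensorParallel(model)` reassignment unnecessary.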