revert some change

Author: zzhhjjj  2024-10-22 19:50:23 +00:00
parent 9d53e9afa6
commit 9a7904d5d6
2 changed files with 4 additions and 12 deletions

@@ -2,12 +2,10 @@ from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 import torch.nn.init as init
-import torch.nn as nn
-class TensorParallel(nn.Module):
+class TensorParallel():
     def __init__(self, model, init_method = init.xavier_normal_):
-        super().__init__()
         self.model = model
         module_linear_name_stype_mapping_list = [
             ("attention", "q_proj", "column"),
             ("attention", "k_proj", "column"),
@@ -50,10 +48,4 @@ class TensorParallel(nn.Module):
                     embedding_dim=linear_layer.embedding_dim,
                     init_method=self.init_method
                 )
-                setattr(module, linear_proj_name, new_linear_layer)
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.model, name)
+                setattr(module, linear_proj_name, new_linear_layer)
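
What the revert amounts to: TensorParallel goes back to being a plain class whose constructor swaps the selected projection layers on the wrapped model in place, so callers keep training `model` directly and the wrapper no longer needs nn.Module inheritance or the __getattr__ delegation deleted above. A minimal, self-contained sketch of that in-place replacement pattern, assuming a toy model and a stand-in for the repo's ColumnParallelLinear (StubColumnParallelLinear, ToyAttention and the matching logic below are illustrative, not the project's actual code):

# Sketch only: a plain class that mutates the wrapped model in place.
import torch.nn as nn
import torch.nn.init as init

class StubColumnParallelLinear(nn.Linear):
    """Stand-in for the repo's ColumnParallelLinear; only the swap mechanics matter here."""
    pass

class TensorParallelSketch:
    def __init__(self, model, init_method=init.xavier_normal_):
        self.model = model
        self.init_method = init_method
        # Hypothetical mapping of (module suffix, linear attribute, split style).
        mapping = [("attention", "q_proj", "column"), ("attention", "k_proj", "column")]
        for module_suffix, proj_name, _style in mapping:
            for name, module in model.named_modules():
                if name.endswith(module_suffix) and hasattr(module, proj_name):
                    old = getattr(module, proj_name)
                    new = StubColumnParallelLinear(old.in_features, old.out_features,
                                                   bias=old.bias is not None)
                    self.init_method(new.weight)
                    # In-place swap: the caller's `model` now holds the replaced layer.
                    setattr(module, proj_name, new)

# Usage mirroring the training script after the revert: no reassignment of `model`.
class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)

toy = nn.Sequential()
toy.add_module("attention", ToyAttention())
TensorParallelSketch(toy)  # constructor side effect mutates `toy` in place
assert isinstance(toy.attention.q_proj, StubColumnParallelLinear)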

@@ -234,7 +234,7 @@ if __name__ == "__main__":
     )
     if pgm.process_group_manager.tp_world_size > 1:
-        model = TensorParallel(model)
+        TensorParallel(model)
     # if pgm.process_group_manager.cp_size > 1:
     #TODO: do at the very end when we have fix convergence issue
@@ -249,7 +249,7 @@ if __name__ == "__main__":
     model.to(device)
     model.train()
-    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
+    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc = grad_acc,num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
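
For contrast with the call site above: the design this commit reverts away from made TensorParallel an nn.Module wrapper, which is why the script previously had to rebind `model = TensorParallel(model)` and the class carried the now-removed __getattr__ fallback. A minimal sketch of that wrapper-plus-delegation idea (TensorParallelWrapper and the toy modules are illustrative, not the repo's code):

# Sketch only: the wrapper pattern removed by this commit.
import torch
import torch.nn as nn

class TensorParallelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves parameters, buffers and submodules (incl. self.model).
            return super().__getattr__(name)
        except AttributeError:
            # Anything else is forwarded to the wrapped model, e.g. custom attributes.
            return getattr(self.model, name)

inner = nn.Linear(4, 2)
inner.custom_flag = True
model = TensorParallelWrapper(inner)   # rebinding, as in the pre-revert training script
out = model(torch.randn(1, 4))         # shape (1, 2): calls go through the wrapper
print(out.shape, model.custom_flag)    # delegated attribute lookup -> True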