revert some change

commit 9a7904d5d6
parent 9d53e9afa6
@@ -2,12 +2,10 @@ from src.parallel.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 import torch.nn.init as init
-import torch.nn as nn
 
-class TensorParallel(nn.Module):
+class TensorParallel():
     def __init__(self, model, init_method = init.xavier_normal_):
-        super().__init__()
 
         self.model = model
 
         module_linear_name_stype_mapping_list = [
            ("attention", "q_proj", "column"),
            ("attention", "k_proj", "column"),
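For orientation: the mapping list pairs a submodule name with a projection name and a parallelism style, and the constructor later swaps each matched projection via setattr (visible in the next hunk). The sketch below only illustrates that traversal pattern; it assumes a LLaMA-style model.model.layers layout and uses plain nn.Linear as a stand-in for the repository's parallel linear classes, and none of the helper names come from the codebase.

import torch.nn as nn

def replace_projections(model, mapping):
    """Hypothetical sketch: walk decoder layers and swap named projections in place."""
    for layer in model.model.layers:                   # assumed LLaMA-style layout
        for module_name, proj_name, style in mapping:  # e.g. ("attention", "q_proj", "column")
            module = getattr(layer, module_name)
            old = getattr(module, proj_name)
            # In the real code the replacement would be a column- or row-parallel
            # linear chosen by `style`; nn.Linear keeps this sketch self-contained.
            new = nn.Linear(old.in_features, old.out_features, bias=old.bias is not None)
            setattr(module, proj_name, new)            # same setattr pattern as the diff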
@@ -50,10 +48,4 @@ class TensorParallel(nn.Module):
                    embedding_dim=linear_layer.embedding_dim,
                    init_method=self.init_method
                )
-                setattr(module, linear_proj_name, new_linear_layer)
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.model, name)
+                setattr(module, linear_proj_name, new_linear_layer)
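The removed __getattr__ override is the standard trick for letting a wrapper nn.Module forward unknown attribute lookups (config, generate, and so on) to the wrapped model; once TensorParallel is a plain class that only patches the model in place, nothing is accessed through the wrapper anymore and the delegation can go. A minimal, generic sketch of that pattern, not the repository's exact class:

import torch.nn as nn

class Wrapper(nn.Module):
    """Illustration of the delegation pattern removed by this commit."""
    def __init__(self, model):
        super().__init__()
        self.model = model                    # registered as a submodule

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves registered parameters, buffers and submodules
            return super().__getattr__(name)
        except AttributeError:
            # everything else falls through to the wrapped model
            return getattr(self.model, name)

wrapped = Wrapper(nn.Linear(4, 4))
print(wrapped.out_features)                   # resolved on the inner nn.Linear via __getattr__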
train.py (4 changed lines)
@@ -234,7 +234,7 @@ if __name__ == "__main__":
     )
 
     if pgm.process_group_manager.tp_world_size > 1:
-        model = TensorParallel(model)
+        TensorParallel(model)
 
     # if pgm.process_group_manager.cp_size > 1:
     #TODO: do at the very end when we have fix convergence issue
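The call-site edit mirrors the class edit in the first file: a plain TensorParallel only rewrites the model's submodules in place, so its return value is not needed, whereas the nn.Module version had to replace the reference so that the later model.to(device), model.train() and AdamW(model.parameters()) operate on the wrapper. A toy demonstration of why in-place patching needs no assignment (all names here are illustrative, not from the repository):

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class InPlaceRewriter:                        # plain class, like the reverted TensorParallel
    def __init__(self, model):
        self.model = model
        # mutate the caller's model directly; no wrapper object needs to be kept
        setattr(model, "proj", nn.Linear(4, 4, bias=False))

model = ToyModel()
InPlaceRewriter(model)                        # no `model = ...` needed
assert model.proj.bias is None                # the original reference already sees the swap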
@@ -249,7 +249,7 @@ if __name__ == "__main__":
     model.to(device)
     model.train()
 
-    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
+    data_loader = MicroBatchDataLoader(global_batch_size=GLOBAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc=grad_acc, num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
     optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
 
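The extra grad_acc argument suggests the data loader now receives the gradient-accumulation factor. A common way such a factor relates to the other batch quantities, stated purely as an assumption about the usual convention and not taken from this repository:

# Assumed convention: one optimizer step consumes
#   GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * grad_acc * dp_world_size  samples.
GLOBAL_BATCH_SIZE = 512          # illustrative numbers only
MICRO_BATCH_SIZE = 4
dp_world_size = 8

grad_acc = GLOBAL_BATCH_SIZE // (MICRO_BATCH_SIZE * dp_world_size)   # -> 16
assert MICRO_BATCH_SIZE * grad_acc * dp_world_size == GLOBAL_BATCH_SIZE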