match tp+pp loss
commit 51b5683dd3
parent 63307c79a1
@@ -44,8 +44,8 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
         pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)

-        # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
+        # calculate loss on the last stage
+        if pgm.process_group_manager.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
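Both hunks make the same change: the loss branch is now entered by every rank of the last pipeline stage rather than only by TP rank 0, which is presumably what makes the loss logged under combined TP+PP match. A minimal standalone sketch of the gating change follows; the ProcessGroupManager class below is a hypothetical stand-in for the real pgm.process_group_manager attributes, and the rank values are illustrative.

# Standalone sketch of the gating change; ProcessGroupManager is a
# hypothetical stand-in for the real pgm.process_group_manager object.
from dataclasses import dataclass

@dataclass
class ProcessGroupManager:
    pp_is_last_stage: bool  # True on the final pipeline-parallel stage
    global_rank: int        # this process's global rank
    tp_first_rank: int      # global rank of TP rank 0 in this TP group

pgm = ProcessGroupManager(pp_is_last_stage=True, global_rank=3, tp_first_rank=2)

# Before: only TP rank 0 of the last PP stage entered the loss branch.
old_gate = pgm.pp_is_last_stage and pgm.global_rank == pgm.tp_first_rank
# After: every TP rank in the last PP stage computes and logs the loss.
new_gate = pgm.pp_is_last_stage

print(old_gate, new_gate)  # -> False True for this non-first TP rank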
@@ -70,8 +70,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         batch = next(data_loader)
         batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
         output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
-        # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
+
+        # calculate loss on the last stage
+        if pgm.process_group_manager.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             nonlocal logging_loss
             logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
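The unchanged loss line in both hunks is worth unpacking: F.cross_entropy expects class logits in dimension 1, i.e. a (N, C, ...) layout, so the model's (batch, seq, vocab) output is transposed to (batch, vocab, seq) before being scored against the (batch, seq) target ids. A toy-sized sketch of the same call, with illustrative sizes rather than the real config:

import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 8, 32  # toy sizes, not the real model config
logits = torch.randn(batch, seq, vocab)
target_ids = torch.randint(0, vocab, (batch, seq))

# cross_entropy wants logits as (N, C, ...), hence the transpose.
loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')

# Per-microbatch contribution to the logged loss, as in the train step
# (num_local_micro_batches here is a stand-in value).
num_local_micro_batches = 4
logging_loss = loss.item() / num_local_micro_batches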