refactor

parent 439e23fdba
commit a727b986cb
@@ -18,7 +18,7 @@ class MicroBatchDataLoader(DataLoader):
         self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
 
         self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
-        self.dataset = load_dataset(dataset_name, split=split)
+        self.dataset = load_dataset(dataset_name, split=split, name=subset_name)
 
         if pgm.process_group_manager.global_rank == 0:
             print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
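Side note on the data-loader hunk above: in the Hugging Face `datasets` library, `load_dataset` accepts an optional configuration name via its `name` argument, so forwarding `subset_name` lets a config file pick a specific subset of a dataset. A minimal illustrative sketch (the dataset and subset names here are examples chosen by the editor, not values used by the repo):

from datasets import load_dataset

# "wikitext" ships several configurations; `name` selects one of them.
# For datasets with a default configuration, `name` can be omitted (None).
dataset = load_dataset("wikitext", name="wikitext-2-raw-v1", split="train")
print(len(dataset), dataset[0].keys())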
@@ -16,8 +16,27 @@ def split_tensor_along_last_dim(tensor, num_partitions):
     last_dim_size = tensor.size()[last_dim] // num_partitions
     return torch.split(tensor, last_dim_size, dim=last_dim)
 
+class CopyToModelParallelRegion(torch.autograd.Function):
+    """
+    Copy in forward pass, all-reduce in backward pass.
+    This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
+    """
+    @staticmethod
+    def forward(ctx, x):
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if pgm.process_group_manager.tp_world_size == 1:
+            return grad_output
+        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return grad_output
+
 class ReduceFromModelParallelRegion(torch.autograd.Function):
-    """All-reduce in forward pass, identity in backward pass."""
+    """
+    All-reduce in forward pass, identity in backward pass.
+    This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
+    """
     @staticmethod
     def forward(ctx, x):
         if pgm.process_group_manager.tp_world_size == 1:
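As a reading aid for the two autograd functions above (an editor's sketch, not code from this commit): in the Megatron-LM scheme, `f` and `g` are conjugate operators placed around a column-parallel and then a row-parallel linear layer. Assuming per-rank weight shards `column_weight` and `row_weight` (hypothetical names) already live on each tensor-parallel rank, an MLP block composes them roughly like this:

import torch
import torch.nn.functional as F

def tensor_parallel_mlp(x, column_weight, row_weight):
    # f (CopyToModelParallelRegion): identity in forward, all-reduce of dL/dx in backward
    x = CopyToModelParallelRegion.apply(x)
    # column-parallel linear: each rank computes its own slice of the hidden dimension
    h = F.gelu(F.linear(x, column_weight))        # [b, s, hidden / tp_world_size]
    # row-parallel linear: each rank produces a partial sum over the full output
    y_partial = F.linear(h, row_weight)           # [b, s, d_model], partial result
    # g (ReduceFromModelParallelRegion): all-reduce partials in forward, identity in backward
    return ReduceFromModelParallelRegion.apply(y_partial)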
@@ -52,27 +71,6 @@ class GatherFromModelParallelRegion(torch.autograd.Function):
         chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size)
         return chunks[pgm.process_group_manager.tp_rank].contiguous()
 
-class CopyToModelParallelRegion(torch.autograd.Function):
-    """Copy in forward pass, all-reduce in backward pass."""
-    @staticmethod
-    def forward(ctx, x):
-        return x
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        if pgm.process_group_manager.tp_world_size == 1:
-            return grad_output
-        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return grad_output
-
-def linear_with_all_reduce(x, weight, bias):
-    input_parallel = CopyToModelParallelRegion.apply(x)
-    output = F.linear(input_parallel, weight, bias)  # XW_i^T + b, output is Y_i
-    return output
-
-def linear_with_async_all_reduce(x, weight, bias):
-    return LinearWithAsyncAllReduce.apply(x, weight, bias)
-
 class LinearWithAsyncAllReduce(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, weight, bias):
@@ -100,4 +98,12 @@ class LinearWithAsyncAllReduce(torch.autograd.Function):
         grad_weight = grad_output.t() @ input_  # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
         grad_bias = grad_output.sum(0) if ctx.use_bias else None
         input_gradient_all_reduce_handle.wait()
         return grad_input, grad_weight, grad_bias
+
+def linear_with_all_reduce(x, weight, bias):
+    input_parallel = CopyToModelParallelRegion.apply(x)
+    output = F.linear(input_parallel, weight, bias)  # XW_i^T + b, output is Y_i
+    return output
+
+def linear_with_async_all_reduce(x, weight, bias):
+    return LinearWithAsyncAllReduce.apply(x, weight, bias)
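For readers skimming the hunk above, the point of LinearWithAsyncAllReduce is that the backward pass launches the all-reduce of the input gradient with async_op=True and overlaps it with the weight and bias gradient GEMMs, only calling wait() at the end. A simplified standalone sketch of that pattern (it mirrors, but does not reproduce, the file's exact code; `tp_group` is assumed to be an already-initialized tensor-parallel process group):

import torch
import torch.distributed as dist

def async_all_reduce_backward(grad_output, input_, weight, tp_group, use_bias=True):
    # 1) the gradient wrt the input needs a sum across TP ranks; start it asynchronously
    grad_input = grad_output @ weight                      # [b*s, in_size]
    handle = dist.all_reduce(grad_input, op=dist.ReduceOp.SUM,
                             group=tp_group, async_op=True)
    # 2) while the all-reduce is in flight, do the independent local GEMMs
    grad_weight = grad_output.t() @ input_                 # [out_size, in_size]
    grad_bias = grad_output.sum(0) if use_bias else None
    # 3) block only once there is nothing left to overlap
    handle.wait()
    return grad_input, grad_weight, grad_bias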
@@ -52,7 +52,6 @@ def get_num_params(model):
     For DP: Parameters are replicated, so only count once
-
    Note:
     LayerNorm: Split across TP ranks for sequence parallelism
     FSDP: Parameters are sharded across data parallel ranks
     """
     tp_world_size = pgm.process_group_manager.tp_world_size
@@ -86,4 +85,11 @@ def assert_no_meta_tensors(model):
         if buffer.device == torch.device("meta"):
             meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}")
 
     assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors)
+
+def average_loss_across_dp_cp_ranks(loss, device):
+    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
+    if pgm.process_group_manager.pp_is_last_stage:
+        dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
+        reduced_loss /= pgm.process_group_manager.cp_dp_world_size
+    return reduced_loss.item()
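A tiny standalone demonstration of the averaging pattern behind average_loss_across_dp_cp_ranks (an editor's sketch, not from the repo): each rank contributes its local loss, the SUM all-reduce is divided by the group size, and every participating rank ends up with the same mean. In the function above only the last pipeline stage reduces, since earlier stages hold no loss. The sketch can be launched with e.g. `torchrun --nproc_per_node 2 demo.py` and runs on CPU with the gloo backend:

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")          # torchrun provides RANK/WORLD_SIZE
    rank, world_size = dist.get_rank(), dist.get_world_size()
    local_loss = torch.tensor([2.0 + rank])          # rank 0 -> 2.0, rank 1 -> 3.0, ...
    dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
    local_loss /= world_size                         # mean across ranks (2.5 for 2 ranks)
    print(f"rank {rank}: averaged loss = {local_loss.item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()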
train.py (22 changed lines)
@@ -1,10 +1,6 @@
"""Training script for LLaMA model.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/llama2_7b_benchmark.json
"""
import os
import inspect
@@ -20,7 +16,7 @@ from transformers import AutoConfig
 from picotron.context_parallel.context_parallel import apply_context_parallel
 from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
 import picotron.process_group_manager as pgm
-from picotron.utils import set_all_seed, print, to_readable_format, get_mfu, get_num_params
+from picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params
 from picotron.checkpoint import CheckpointManager
 from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
 from picotron.data import MicroBatchDataLoader
@@ -111,8 +107,9 @@ if __name__ == "__main__":
         device=device,
         num_workers=config["dataset"]["num_workers"],
         num_proc=config["dataset"]["num_proc"],
-        num_samples=config["training"]["num_samples"],
-        subset_name=config["dataset"]["subset_name"],
+        num_samples=config["training"].get("num_samples", None),
+        subset_name=config["dataset"].get("subset_name", None),
+        split=config["dataset"].get("split", "train")
     )
 
     dist.barrier()
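The switch from direct indexing to dict.get in the hunk above makes num_samples, subset_name, and split optional: a config file that omits them no longer raises KeyError and instead falls back to a default. A quick illustration with a hypothetical config dict:

config = {"dataset": {"num_workers": 2, "num_proc": 4}, "training": {}}

# Direct indexing raises for a missing key:
#   config["training"]["num_samples"]  ->  KeyError: 'num_samples'
# .get() falls back to the supplied default instead:
num_samples = config["training"].get("num_samples", None)    # None
subset_name = config["dataset"].get("subset_name", None)     # None
split = config["dataset"].get("split", "train")              # "train"
print(num_samples, subset_name, split)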
@@ -208,13 +205,6 @@ if __name__ == "__main__":
         step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, config["checkpoint"]["load_path"])
 
     dist.barrier()
 
-    def _all_reduce_loss_across_dp_cp_ranks(loss, device):
-        reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-        if pgm.process_group_manager.pp_is_last_stage:
-            dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group)
-            reduced_loss /= pgm.process_group_manager.cp_dp_world_size
-        return reduced_loss.item()
-
     while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]:
         step_start_time = time.time()
@@ -230,7 +220,7 @@ if __name__ == "__main__":
         else:
             loss = train_step(model, data_loader, device)
 
-        loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)
+        loss = average_loss_across_dp_cp_ranks(loss, device)
 
         optimizer.step()
         trained_tokens += tokens_per_step