Merge remote-tracking branch 'origin/main' into loading_big_model

ferdinand.mom 2024-12-17 05:30:26 +00:00
commit 43f39ff9ec
6 changed files with 396 additions and 25 deletions

View File

@ -10,8 +10,7 @@ from picotron.utils import print
import picotron.process_group_manager as pgm
class MicroBatchDataLoader(DataLoader):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.grad_acc_steps = grad_acc_steps
@ -50,7 +49,7 @@ class MicroBatchDataLoader(DataLoader):
self.tokenized_dataset,
batch_size=micro_batch_size,
collate_fn=self.collate_batch,
pin_memory=True,
pin_memory=pin_memory,
num_workers=num_workers,
sampler=self.sampler,
shuffle=False
@ -109,14 +108,11 @@ class MicroBatchDataLoader(DataLoader):
input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
return {
"input_ids": input_ids,
"target_ids": target_ids,
"position_ids": position_ids,
"attn_mask": attn_mask,
"hidden_states": None
}
@ -131,6 +127,12 @@ class MicroBatchDataLoader(DataLoader):
try:
batch = next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration
# Reinitialize the sampler and iterator
self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
self._iterator = super().__iter__()
try:
batch = next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration
return batch
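
The rewritten __next__ above keeps the loader iterating past the end of the dataset by bumping the sampler epoch and rebuilding the iterator. A minimal standalone sketch of the same cycling pattern on a plain DataLoader (the CyclingLoader name and the toy integer dataset are illustrative, not picotron code):

from torch.utils.data import DataLoader

class CyclingLoader(DataLoader):
    """Toy loader that restarts from the beginning instead of raising StopIteration."""

    def __init__(self, dataset, batch_size):
        super().__init__(dataset, batch_size=batch_size, shuffle=False)
        self._iterator = None

    def __iter__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        return self

    def __next__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        try:
            return next(self._iterator)
        except StopIteration:
            # Dataset exhausted: rebuild the iterator and keep going (a new "epoch").
            # With a DistributedSampler you would also call sampler.set_epoch(...) here.
            self._iterator = super().__iter__()
            return next(self._iterator)

loader = CyclingLoader(list(range(6)), batch_size=4)
for _ in range(4):  # 4 batches of up to 4 samples > 6 samples total, so the loader wraps around
    print(next(loader))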

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import picotron.process_group_manager as pgm
from picotron.tensor_parallel.tp_communications import Reduce, Gather, Copy, split_tensor_along_last_dim
from picotron.tensor_parallel.tp_communications import Reduce, Gather, linear_with_all_reduce, linear_with_async_all_reduce
def apply_tensor_parallel(model):
@ -51,10 +51,26 @@ def apply_tensor_parallel(model):
return model
class ColumnParallelLinear(nn.Module):
class ColumnParallelLinear(torch.nn.Module):
"""Column Parallel Linear layer
Y = XW + b, where the weight matrix W is parallelized along its second dimension: W = [W_1, ..., W_p].
The forward method returns Y_i = XW_i + b_i; each Y_i is the shard of Y along its second (output) dimension.
Arguments:
in_features: first dimension of weight matrix W.
out_features: second dimension of weight matrix W.
bias: If true, add bias
gather_output: If true, gather the output from all the partitions. This is used for the last linear layer.
async_all_reduce: If true, the backward pass overlaps the all-reduce of the input gradient with the weight/bias gradient computation.
"""
def __init__(self, in_features: int, out_features: int, bias: bool, gather_output: bool = False):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
gather_output: bool = False,
async_all_reduce: bool = True,
) -> None:
super(ColumnParallelLinear, self).__init__()
self.tp_world_size = pgm.process_group_manager.tp_world_size
@ -65,7 +81,8 @@ class ColumnParallelLinear(nn.Module):
assert out_features % self.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
self.output_size_per_partition = out_features // self.tp_world_size
self.gather_output = gather_output
self.async_all_reduce = async_all_reduce
# Allocate space for the weight and bias
# Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
if bias:
@ -95,17 +112,34 @@ class ColumnParallelLinear(nn.Module):
# Split the master weight into chunks of self.output_size_per_partition rows and keep this rank's chunk
weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
self.weight.data = weight_list[self.tp_rank].contiguous()
def forward(self, input):
input_parallel = Copy.apply(input)
# XW_i^T + b, output is Y_i
output = F.linear(input_parallel, self.weight, self.bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.async_all_reduce:
output = linear_with_async_all_reduce(input, self.weight, self.bias)
else:
output = linear_with_all_reduce(input, self.weight, self.bias)
if self.gather_output:
output = Gather.apply(output)
return output
class RowParallelLinear(nn.Module):
"""Linear layer with row parallelism.
Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
            -     -
           |  W_1  |
           |   .   |
       W = |   .   |      X = [X_1, ..., X_p]
           |   .   |
           |  W_p  |
            -     -
We assume that X is already parallelized along its last dimension. This is the case after ColumnParallelLinear.
The forward method computes the partial product X_i W_i on each rank, sums the partials across ranks (all-reduce) to obtain Y = sum(X_i * W_i), and adds the bias b once.
Arguments:
in_features: first dimension of matrix W.
out_features: second dimension of matrix W.
bias: If true, add bias
"""
def __init__(self, in_features: int, out_features: int, bias: bool):
super(RowParallelLinear, self).__init__()
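
The two docstrings above describe how the weight matrix is sharded in each scheme. A self-contained sketch of the same math on a single process, with the all-reduce replaced by an explicit sum (no picotron or torch.distributed involved; all names here are illustrative):

import torch

torch.manual_seed(0)
tp = 2                                  # pretend tensor-parallel world size
X = torch.randn(3, 8)                   # (tokens, in_features)
W = torch.randn(8, 16)                  # (in_features, out_features)
b = torch.randn(16)
Y_ref = X @ W + b                       # reference: full linear layer

# Column parallelism: W = [W_1, ..., W_p] along the output dimension.
# Each rank computes Y_i = X W_i + b_i; concatenating the shards recovers Y.
W_cols, b_cols = W.chunk(tp, dim=1), b.chunk(tp, dim=0)
Y_col = torch.cat([X @ W_i + b_i for W_i, b_i in zip(W_cols, b_cols)], dim=1)
torch.testing.assert_close(Y_col, Y_ref)

# Row parallelism: W is split along the input dimension, X along its last dimension.
# Each rank computes a partial product X_i W_i; summing the partials (the all-reduce)
# and adding the bias once recovers Y.
W_rows, X_cols = W.chunk(tp, dim=0), X.chunk(tp, dim=1)
Y_row = sum(X_i @ W_i for X_i, W_i in zip(X_cols, W_rows)) + b
torch.testing.assert_close(Y_row, Y_ref)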

View File

@ -1,6 +1,13 @@
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm
import torch.nn.functional as F
from typing import Tuple
def merge_first_two_dims(grad_output: torch.Tensor, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Merge the first two dimensions of tensors."""
return grad_output.contiguous().view(-1, *grad_output.shape[2:]), input_.contiguous().view(-1, *input_.shape[2:])
def split_tensor_along_last_dim(tensor, num_partitions):
"""Split a tensor along its last dimension into num_partitions chunks."""
@ -45,7 +52,7 @@ class Gather(torch.autograd.Function):
chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size)
return chunks[pgm.process_group_manager.tp_rank].contiguous()
class Copy(torch.autograd.Function):
class Identity(torch.autograd.Function):
"""Identity in forward pass, all-reduce in backward pass."""
@staticmethod
def forward(ctx, input):
@ -56,4 +63,41 @@ class Copy(torch.autograd.Function):
if pgm.process_group_manager.tp_world_size == 1:
return grad_output
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
return grad_output
return grad_output
def linear_with_all_reduce(input, weight, bias):
input_parallel = Identity.apply(input)
output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
return output
def linear_with_async_all_reduce(input, weight, bias):
return LinearWithAsyncAllReduce.apply(input, weight, bias)
class LinearWithAsyncAllReduce(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, weight, bias):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
output = input_ @ weight.t() + bias if bias is not None else input_ @ weight.t()
return output
@staticmethod
def backward(ctx, grad_output):
"""
The key difference with "linear_with_all_reduce" is that the all-reduce of the input gradient is launched before
the gradients of the weights and bias are computed, instead of after, so the communication can overlap with that computation.
This is only applicable to ColumnParallelLinear.
Before: grad_output -> grad_input, grad_weight, grad_bias -> all-reduce grad_input
Now: grad_output -> grad_input -> async all-reduce grad_input -> grad_weight, grad_bias -> wait for the all-reduce
"""
input_, weight = ctx.saved_tensors
grad_input = grad_output @ weight # (b, s, out_size) @ (out_size, input_size) = (b, s, input_size)
# all-reduce input gradient.
input_gradient_all_reduce_handle = dist.all_reduce(grad_input, group=pgm.process_group_manager.tp_group, async_op=True)
# merge first two dims to allow matrix multiplication
grad_output, input_ = merge_first_two_dims(grad_output, input_) # grad_output, input_: (b, s, out_size), (b, s, input_size) -> (b*s, out_size), (b*s, input_size)
grad_weight = grad_output.t() @ input_ # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
grad_bias = grad_output.sum(0) if ctx.use_bias else None
input_gradient_all_reduce_handle.wait()
return grad_input, grad_weight, grad_bias
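
The backward above computes the input, weight, and bias gradients by hand so that the all-reduce of the input gradient can be launched before the weight/bias gradient matmuls. A single-process sketch checking that this manual gradient math matches autograd (the all-reduce itself is omitted since no process group is involved; names are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, in_size, out_size = 2, 4, 8, 16
x = torch.randn(b, s, in_size, requires_grad=True)
weight = torch.randn(out_size, in_size, requires_grad=True)
bias = torch.randn(out_size, requires_grad=True)

# Reference gradients via autograd
out = F.linear(x, weight, bias)                 # x @ weight.T + bias
grad_output = torch.ones_like(out)
out.backward(grad_output)

# Manual gradients, mirroring LinearWithAsyncAllReduce.backward
grad_input = grad_output @ weight               # (b, s, in_size); this is the tensor the TP group would all-reduce
go2d = grad_output.reshape(-1, out_size)        # merge the first two dims, as merge_first_two_dims does
x2d = x.detach().reshape(-1, in_size)
grad_weight = go2d.t() @ x2d                    # (out_size, in_size)
grad_bias = go2d.sum(0)                         # (out_size,)

torch.testing.assert_close(grad_input, x.grad)
torch.testing.assert_close(grad_weight, weight.grad)
torch.testing.assert_close(grad_bias, bias.grad)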

tests/test_dataloader.py (new file, 213 lines)
View File

@ -0,0 +1,213 @@
"""
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 test_dataloader.py
"""
from picotron.data import MicroBatchDataLoader
import torch.distributed as dist
import os
import datetime
from picotron.process_group_manager import setup_process_group_manager
import torch
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
from functools import partial
from datasets import Features, Sequence, Value, load_dataset
from transformers import AutoTokenizer
import picotron.process_group_manager as pgm
# Reference dataloader: same as MicroBatchDataLoader but without the context-parallelism split
class DummyDataLoader(DataLoader):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.grad_acc_steps = grad_acc_steps
self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.dataset = load_dataset(dataset_name, split=split)
if num_samples:
self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
# Tokenize and chunk the dataset
self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)
self.sampler = DistributedSampler(
self.tokenized_dataset,
num_replicas=pgm.process_group_manager.dp_world_size,
rank=pgm.process_group_manager.dp_rank,
shuffle=False
)
super().__init__(
self.tokenized_dataset,
batch_size=micro_batch_size,
collate_fn=self.collate_batch,
pin_memory=True,
num_workers=num_workers,
sampler=self.sampler,
shuffle=False
)
@staticmethod
def tokenizer_group_text(examples, tokenizer, sequence_length):
"""Tokenize a list of texts and group them in chunks of sequence_length + 1"""
tokenized_text_batch = tokenizer.batch_encode_plus(
examples,
return_attention_mask=False,
return_token_type_ids=False,
return_tensors='np'
)
concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
total_length = len(concatenated_tokens['input_ids'])
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
result = {
'input_ids': [
concatenated_tokens['input_ids'][i : i + sequence_length + 1]
for i in range(0, total_length - sequence_length, sequence_length)
]
}
return result
def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
"""Tokenize the dataset and group texts in chunks of sequence_length + 1"""
# Create a partial function with fixed arguments
tokenizer_func = partial(
self.tokenizer_group_text,
tokenizer=self.tokenizer,
sequence_length=sequence_length
)
tokenized_dataset = dataset.map(
tokenizer_func,
input_columns=text_column_name,
remove_columns=dataset.column_names,
features=Features({
"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)
}),
batched=True,
num_proc=num_proc,
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {sequence_length+1}",
)
return tokenized_dataset
def collate_batch(self, batch):
batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
batch_size = batch_input_ids.size(0)
input_ids = batch_input_ids[:, :self.seq_length].contiguous()
target_ids = batch_input_ids[:, 1:self.seq_length+1].contiguous()
position_ids = torch.arange(0, self.seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
return {
"input_ids": input_ids,
"target_ids": target_ids,
"position_ids": position_ids,
"hidden_states": None
}
def __iter__(self):
if self._iterator is None:
self._iterator = super().__iter__()
return self
def __next__(self):
if self._iterator is None:
self._iterator = super().__iter__()
try:
batch = next(self._iterator)
except StopIteration:
# Reinitialize the sampler and iterator
self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
self._iterator = super().__iter__()
try:
batch = next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration
return batch
# test that the tokens are split correctly under context parallelism
# TODO: test zigzag behavior
def test_cp_behavior(TP_SIZE, CP_SIZE, PP_SIZE, DP_SIZE, SEQ_LEN=8):
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
backend = "nccl"
assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
data_loader = MicroBatchDataLoader(
micro_batch_size=2,
seq_length=SEQ_LEN,
dataset_name="roneneldan/TinyStories",
tokenizer_name="HuggingFaceTB/SmolLM-135M",
grad_acc_steps=1,
num_workers=1,
num_proc=1,
num_samples=10,
pin_memory=False
)
ref_data_loader = DummyDataLoader(
micro_batch_size=2,
seq_length=SEQ_LEN,
dataset_name="roneneldan/TinyStories",
tokenizer_name="HuggingFaceTB/SmolLM-135M",
grad_acc_steps=1,
num_workers=1,
num_proc=1,
num_samples=10,
pin_memory=False
)
for i in range(1):
ref_batch = next(ref_data_loader)
batch = next(data_loader)
split_size = ref_batch["input_ids"].shape[1] // pgm.process_group_manager.cp_world_size
start_idx = split_size * global_rank
end_idx = start_idx + split_size
assert torch.equal(ref_batch["input_ids"][:,start_idx:end_idx], batch["input_ids"]), "input_ids are not equal"
# test the infinite loop behavior
def test_infinite_loop():
local_rank = 0
global_rank = 0
world_size = 1
backend = "nccl"
dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=1, cp_size=1, pp_size=1, dp_size=1)
data_loader = MicroBatchDataLoader(
micro_batch_size=2,
seq_length=256,
dataset_name="roneneldan/TinyStories",
tokenizer_name="HuggingFaceTB/SmolLM-135M",
grad_acc_steps=1,
num_workers=1,
num_proc=1,
num_samples=2,
)
s = set()
for i in range(10):
batch = next(data_loader)
# Convert the nested list to a tuple of tuples
batch_tuple = tuple(tuple(x) for x in batch["input_ids"].tolist())
if batch_tuple in s:
# a batch repeated, so the loader wrapped around the small dataset as expected
return
s.add(batch_tuple)
assert False, "dataloader was exhausted without ever repeating a batch"
if __name__ == "__main__":
# test_infinite_loop()
test_cp_behavior(TP_SIZE=1, CP_SIZE=2, PP_SIZE=1, DP_SIZE=1, SEQ_LEN=8)
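
For reference, the contiguous context-parallel split that test_cp_behavior asserts, shown on a toy tensor with plain slicing (no distributed setup; zigzag splitting, noted in the TODO above, is not covered):

import torch

seq_len, cp_size = 8, 2
full = torch.arange(seq_len).unsqueeze(0)                 # one "sequence" of token positions, shape (1, 8)
split = seq_len // cp_size
shards = [full[:, r * split:(r + 1) * split] for r in range(cp_size)]
print(shards[0])  # rank 0 gets tensor([[0, 1, 2, 3]])
print(shards[1])  # rank 1 gets tensor([[4, 5, 6, 7]])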

View File

@ -0,0 +1,75 @@
"""
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
"""
from picotron.process_group_manager import setup_process_group_manager
from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from picotron.utils import set_all_seed
import torch
import os
import torch.distributed as dist
import datetime
import picotron.process_group_manager as pgm
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device("cuda", local_rank)
dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)
set_all_seed(42)
batch_size, seq_len = 2, 4
input_size, output_size = 8, 16
bias = True # linear layer with/without bias
async_all_reduce = False # async all-reduce or not for column parallel linear layer
# Initialize input tensor
tensor_shape = (batch_size, seq_len, input_size)
tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)
# Initialize column/row parallel layers
column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)
# copy weight and bias from reference linear layer to column/row parallel layers
column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
if bias:
column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)
### forward pass ###
output_reference = linear_layer(tensor)
output_column_parallel = column_parallel_linear(column_parallel_tensor)
output_row_parallel = row_parallel_linear(row_parallel_tensor)
# check forward output consistency
assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference"
torch.testing.assert_close(output_reference, output_row_parallel) # not strictly equal. precision issue
### backward pass ###
output_reference.backward(torch.ones_like(output_reference))
output_column_parallel.backward(torch.ones_like(output_column_parallel))
output_row_parallel.backward(torch.ones_like(output_row_parallel))
# check backward weight gradient, bias gradient, and input gradient consistency
# column parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
if bias:
torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)
# row parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
if bias:
torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)
print(f"Rank {dist.get_rank()}: All tests passed")

View File

@ -1,7 +1,7 @@
"""Training script for LLaMA model.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
@ -125,7 +125,7 @@ if __name__ == "__main__":
if is_wandb_rank and config["logging"]["use_wandb"]:
wandb.init(
project="picotron",
name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
config={
"tensor_parallel_size": pgm.process_group_manager.tp_world_size,
"context_parallel_size": pgm.process_group_manager.cp_world_size,
@ -238,7 +238,8 @@ if __name__ == "__main__":
step_duration = time.time() - step_start_time
tokens_per_second = tokens_per_step / step_duration
mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
tokens_per_second_per_gpu = tokens_per_second / world_size
mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)
if is_wandb_rank:
print(
@ -247,7 +248,7 @@ if __name__ == "__main__":
f"Loss: {loss:6.4f} | "
f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''} | "
f"MFU: {mfu:5.2f}% | "
f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
@ -259,6 +260,8 @@ if __name__ == "__main__":
"loss": loss,
"tokens_per_step": tokens_per_step,
"tokens_per_second": tokens_per_step / step_duration,
"mfu": mfu,
"tokens_per_second_per_gpu": tokens_per_second_per_gpu,
"memory_usage": torch.cuda.memory_reserved() / 1e9,
"trained_tokens": trained_tokens
})
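
For context on the mfu value logged above: a common back-of-the-envelope estimate (whether picotron's get_mfu uses exactly this formula is an assumption) is that training costs roughly 6 FLOPs per parameter per token, so MFU compares 6 * num_params * tokens/s/GPU against the GPU's peak throughput:

def approx_mfu(tokens_per_second_per_gpu: float, num_params: float,
               peak_flops_per_gpu: float = 989e12) -> float:
    """Rough model FLOPs utilization in percent: achieved training FLOPs / peak hardware FLOPs.

    Assumes ~6 * num_params FLOPs per trained token (forward + backward, ignoring the
    attention term). The default peak is an H100's ~989 TFLOPs dense bf16, which is an
    assumption here, not necessarily picotron's value.
    """
    achieved_flops = 6 * num_params * tokens_per_second_per_gpu
    return 100 * achieved_flops / peak_flops_per_gpu

# e.g. a 7B-parameter model training at 3,000 tokens/s/GPU:
print(f"{approx_mfu(3_000, 7e9):.1f}%")  # ~12.7%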