Merge remote-tracking branch 'origin/main' into loading_big_model
commit 43f39ff9ec

picotron/data.py

@@ -10,8 +10,7 @@ from picotron.utils import print
import picotron.process_group_manager as pgm

class MicroBatchDataLoader(DataLoader):
    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None):
    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
        self.micro_batch_size = micro_batch_size
        self.seq_length = seq_length
        self.grad_acc_steps = grad_acc_steps
@@ -50,7 +49,7 @@ class MicroBatchDataLoader(DataLoader):
            self.tokenized_dataset,
            batch_size=micro_batch_size,
            collate_fn=self.collate_batch,
            pin_memory=True,
            pin_memory=pin_memory,
            num_workers=num_workers,
            sampler=self.sampler,
            shuffle=False
@@ -109,14 +108,11 @@ class MicroBatchDataLoader(DataLoader):
        input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
        target_ids = batch_input_ids[:, start_idx+1:end_idx+1].contiguous()
        position_ids = torch.arange(start_idx, end_idx, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()
        local_attn_mask = torch.tril(torch.ones((self.seq_length_per_gpu, self.seq_length_per_gpu), dtype=torch.bool))
        attn_mask = local_attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "position_ids": position_ids,
            "attn_mask": attn_mask,
            "hidden_states": None
        }

@@ -131,6 +127,12 @@ class MicroBatchDataLoader(DataLoader):
        try:
            batch = next(self._iterator)
        except StopIteration:
            self._iterator = None
            raise StopIteration
            # Reinitialize the sampler and iterator
            self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
            self._iterator = super().__iter__()
            try:
                batch = next(self._iterator)
            except StopIteration:
                self._iterator = None
                raise StopIteration
        return batch
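
The reworked __next__ above makes the loader effectively infinite: when the underlying iterator is exhausted, the DistributedSampler epoch is bumped and a fresh iterator is built before retrying, and StopIteration only propagates if the retry fails too. A minimal standalone sketch of the same pattern, independent of picotron (the class name and arguments are illustrative, not from the repo):

from torch.utils.data import DataLoader, DistributedSampler

class CyclingDataLoader(DataLoader):
    """Illustrative sketch: cycles forever by bumping the sampler epoch on exhaustion."""
    def __init__(self, dataset, sampler, **kwargs):
        super().__init__(dataset, sampler=sampler, **kwargs)
        self._iterator = None

    def __iter__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        return self

    def __next__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        try:
            return next(self._iterator)
        except StopIteration:
            # Advance the epoch so DistributedSampler keeps all ranks in sync, then restart;
            # a second StopIteration here would mean the dataset is genuinely empty.
            self.sampler.set_epoch(getattr(self.sampler, "epoch", -1) + 1)
            self._iterator = super().__iter__()
            return next(self._iterator)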

picotron/tensor_parallel/tensor_parallel.py

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import picotron.process_group_manager as pgm
from picotron.tensor_parallel.tp_communications import Reduce, Gather, Copy, split_tensor_along_last_dim
from picotron.tensor_parallel.tp_communications import Reduce, Gather, linear_with_all_reduce, linear_with_async_all_reduce

def apply_tensor_parallel(model):

@@ -51,10 +51,26 @@ def apply_tensor_parallel(model):

    return model

class ColumnParallelLinear(nn.Module):
class ColumnParallelLinear(torch.nn.Module):
    """Column Parallel Linear layer
    Y = XW + b, where the weight matrix W is parallelized along its second dimension: W = [W_1, ..., W_p].
    The forward method returns Y_i = XW_i + b_i; the output is parallelized along its second dimension.
    Arguments:
        in_features: first dimension of weight matrix W.
        out_features: second dimension of weight matrix W.
        bias: If true, add bias
        init_method: method to initialize weights
        gather_output: If true, gather the output from all the partitions. This is used for the last linear layer
    """

    def __init__(self, in_features: int, out_features: int, bias: bool, gather_output: bool = False):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        gather_output: bool = False,
        async_all_reduce: bool = True,
    ) -> None:
        super(ColumnParallelLinear, self).__init__()

        self.tp_world_size = pgm.process_group_manager.tp_world_size
@@ -65,7 +81,8 @@ class ColumnParallelLinear(nn.Module):
        assert out_features % self.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // self.tp_world_size
        self.gather_output = gather_output

        self.async_all_reduce = async_all_reduce
        # Allocate space for the weight and bias
        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
        self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))  # W_i
        if bias:
@@ -95,17 +112,34 @@ class ColumnParallelLinear(nn.Module):
        # Split the master weight into partitions of size self.output_size_per_partition
        weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        input_parallel = Copy.apply(input)
        # XW_i^T + b, output is Y_i
        output = F.linear(input_parallel, self.weight, self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.async_all_reduce:
            output = linear_with_async_all_reduce(input, self.weight, self.bias)
        else:
            output = linear_with_all_reduce(input, self.weight, self.bias)
        if self.gather_output:
            output = Gather.apply(output)
        return output

class RowParallelLinear(nn.Module):
    """Linear layer with row parallelism.
    Y = XW + b. W is parallelized along its first dimension and X along its second dimension as:
               -   -
              | W_1 |
              |  .  |
          W = |  .  |     X = [X_1, ..., X_p]
              |  .  |
              | W_p |
               -   -
    We assume that X is already parallelized. This is the case after ColumnParallelLinear.
    The forward method returns Y = sum(X_i * W_i + b_i).
    Arguments:
        in_features: first dimension of matrix W.
        out_features: second dimension of matrix W.
        bias: If true, add bias
        init_method: method to initialize weights.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super(RowParallelLinear, self).__init__()
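
Read together, the two docstrings describe the usual Megatron-style pairing: ColumnParallelLinear produces activations that stay sharded along the feature dimension when gather_output=False, and RowParallelLinear consumes those shards and sums the partial results across the tensor parallel group. A hedged sketch of that composition, assuming the process group manager is already set up as in tests/test_tensor_parallel.py (the MLP wrapper, its sizes, and the GELU are illustrative, not part of this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear

class TensorParallelMLP(nn.Module):
    """Illustrative only: column-parallel up-projection followed by a row-parallel down-projection."""
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # gather_output=False keeps the intermediate activation sharded across TP ranks,
        # which is exactly the layout RowParallelLinear expects for its input X = [X_1, ..., X_p].
        self.up = ColumnParallelLinear(hidden_size, intermediate_size, bias=False, gather_output=False)
        # RowParallelLinear sums the partial products so every rank ends up with the full output.
        self.down = RowParallelLinear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.gelu(self.up(x)))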

picotron/tensor_parallel/tp_communications.py

@@ -1,6 +1,13 @@
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm
import torch.nn.functional as F

from typing import Tuple

def merge_first_two_dims(grad_output: torch.Tensor, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge the first two dimensions of tensors."""
    return grad_output.contiguous().view(-1, *grad_output.shape[2:]), input_.contiguous().view(-1, *input_.shape[2:])

def split_tensor_along_last_dim(tensor, num_partitions):
    """Split a tensor along its last dimension into num_partitions chunks."""
@@ -45,7 +52,7 @@ class Gather(torch.autograd.Function):
        chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size)
        return chunks[pgm.process_group_manager.tp_rank].contiguous()

class Copy(torch.autograd.Function):
class Identity(torch.autograd.Function):
    """Identity in forward pass, all-reduce in backward pass."""
    @staticmethod
    def forward(ctx, input):
@@ -56,4 +63,41 @@ class Copy(torch.autograd.Function):
        if pgm.process_group_manager.tp_world_size == 1:
            return grad_output
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return grad_output

def linear_with_all_reduce(input, weight, bias):
    input_parallel = Identity.apply(input)
    output = F.linear(input_parallel, weight, bias)  # XW_i^T + b, output is Y_i
    return output

def linear_with_async_all_reduce(input, weight, bias):
    return LinearWithAsyncAllReduce.apply(input, weight, bias)

class LinearWithAsyncAllReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, bias):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        output = input_ @ weight.t() + bias if bias is not None else input_ @ weight.t()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        The key difference from "linear_with_all_reduce" is that the all-reduce of the input gradient happens
        before the gradients of the weights and bias are computed, instead of after, so the communication can
        overlap with that computation. This is only applicable to Column Parallel Linear.

        Before: grad_output -> grad_input, grad_weight, grad_bias -> grad_input all-reduce
        Now:    grad_output -> grad_input -> grad_input all-reduce (async) -> grad_weight, grad_bias
        """
        input_, weight = ctx.saved_tensors
        grad_input = grad_output @ weight  # (b, s, out_size) @ (out_size, input_size) = (b, s, input_size)
        # All-reduce the input gradient asynchronously so it overlaps with the weight/bias gradient computation
        input_gradient_all_reduce_handle = dist.all_reduce(grad_input, group=pgm.process_group_manager.tp_group, async_op=True)
        # Merge the first two dims to allow matrix multiplication
        grad_output, input_ = merge_first_two_dims(grad_output, input_)  # (b, s, out_size), (b, s, input_size) -> (b*s, out_size), (b*s, input_size)
        grad_weight = grad_output.t() @ input_  # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
        grad_bias = grad_output.sum(0) if ctx.use_bias else None
        input_gradient_all_reduce_handle.wait()
        return grad_input, grad_weight, grad_bias
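
The shape comments in this backward can be checked without any distributed setup. A small self-contained sketch (tensor sizes are made up for illustration) that reproduces the same matmuls on a single device; in the real backward the all_reduce on grad_input is issued with async_op=True right after the first matmul so it overlaps with the weight and bias gradient computation:

import torch

b, s, input_size, out_size = 2, 4, 8, 16             # illustrative sizes for one TP shard
input_ = torch.randn(b, s, input_size)
weight = torch.randn(out_size, input_size)
grad_output = torch.randn(b, s, out_size)

# Gradient wrt the input: (b, s, out_size) @ (out_size, input_size) -> (b, s, input_size)
grad_input = grad_output @ weight
# Merge the first two dims so the weight gradient is a single matmul, as merge_first_two_dims does
grad_output_2d = grad_output.reshape(-1, out_size)   # (b*s, out_size)
input_2d = input_.reshape(-1, input_size)            # (b*s, input_size)
# Gradient wrt the weight: (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
grad_weight = grad_output_2d.t() @ input_2d
grad_bias = grad_output_2d.sum(0)                    # (out_size,)

assert grad_input.shape == (b, s, input_size)
assert grad_weight.shape == weight.shape
assert grad_bias.shape == (out_size,)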

tests/test_dataloader.py (new file, 213 lines)

@@ -0,0 +1,213 @@
"""
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 test_dataloader.py
"""
from picotron.data import MicroBatchDataLoader
import torch.distributed as dist
import os
import datetime
from picotron.process_group_manager import setup_process_group_manager

import torch
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
from functools import partial
from datasets import Features, Sequence, Value, load_dataset
from transformers import AutoTokenizer

import picotron.process_group_manager as pgm

# Reference dataloader: same pipeline but without the context parallelism split.
class DummyDataLoader(DataLoader):
    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
        self.micro_batch_size = micro_batch_size
        self.seq_length = seq_length
        self.grad_acc_steps = grad_acc_steps
        self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size
        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size

        self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.dataset = load_dataset(dataset_name, split=split)
        if num_samples:
            self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))

        # Tokenize and chunk the dataset
        self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length, num_proc)

        self.sampler = DistributedSampler(
            self.tokenized_dataset,
            num_replicas=pgm.process_group_manager.dp_world_size,
            rank=pgm.process_group_manager.dp_rank,
            shuffle=False
        )

        super().__init__(
            self.tokenized_dataset,
            batch_size=micro_batch_size,
            collate_fn=self.collate_batch,
            pin_memory=True,
            num_workers=num_workers,
            sampler=self.sampler,
            shuffle=False
        )

    @staticmethod
    def tokenizer_group_text(examples, tokenizer, sequence_length):
        """Tokenize a list of texts and group them in chunks of sequence_length + 1"""
        tokenized_text_batch = tokenizer.batch_encode_plus(
            examples,
            return_attention_mask=False,
            return_token_type_ids=False,
            return_tensors='np'
        )
        concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
        total_length = len(concatenated_tokens['input_ids'])
        if total_length >= sequence_length + 1:
            total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
        result = {
            'input_ids': [
                concatenated_tokens['input_ids'][i : i + sequence_length + 1]
                for i in range(0, total_length - sequence_length, sequence_length)
            ]
        }
        return result

    def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc):
        """Tokenize the dataset and group texts in chunks of sequence_length + 1"""
        # Create a partial function with fixed arguments
        tokenizer_func = partial(
            self.tokenizer_group_text,
            tokenizer=self.tokenizer,
            sequence_length=sequence_length
        )

        tokenized_dataset = dataset.map(
            tokenizer_func,
            input_columns=text_column_name,
            remove_columns=dataset.column_names,
            features=Features({
                "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)
            }),
            batched=True,
            num_proc=num_proc,
            load_from_cache_file=True,
            desc=f"Grouping texts in chunks of {sequence_length+1}",
        )

        return tokenized_dataset

    def collate_batch(self, batch):
        batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch])
        batch_size = batch_input_ids.size(0)
        input_ids = batch_input_ids[:, :self.seq_length].contiguous()
        target_ids = batch_input_ids[:, 1:self.seq_length+1].contiguous()
        position_ids = torch.arange(0, self.seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous()

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "position_ids": position_ids,
            "hidden_states": None
        }

    def __iter__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        return self

    def __next__(self):
        if self._iterator is None:
            self._iterator = super().__iter__()
        try:
            batch = next(self._iterator)
        except StopIteration:
            # Reinitialize the sampler and iterator
            self.sampler.set_epoch(self.sampler.epoch + 1 if hasattr(self.sampler, 'epoch') else 0)
            self._iterator = super().__iter__()
            try:
                batch = next(self._iterator)
            except StopIteration:
                self._iterator = None
                raise StopIteration
        return batch

# Test that tokens are split correctly across context parallel ranks
# TODO: test zigzag behavior
def test_cp_behavior(TP_SIZE, CP_SIZE, PP_SIZE, DP_SIZE, SEQ_LEN=8):
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    backend = "nccl"

    assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=3))
    setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)

    data_loader = MicroBatchDataLoader(
        micro_batch_size=2,
        seq_length=SEQ_LEN,
        dataset_name="roneneldan/TinyStories",
        tokenizer_name="HuggingFaceTB/SmolLM-135M",
        grad_acc_steps=1,
        num_workers=1,
        num_proc=1,
        num_samples=10,
        pin_memory=False
    )

    ref_data_loader = DummyDataLoader(
        micro_batch_size=2,
        seq_length=SEQ_LEN,
        dataset_name="roneneldan/TinyStories",
        tokenizer_name="HuggingFaceTB/SmolLM-135M",
        grad_acc_steps=1,
        num_workers=1,
        num_proc=1,
        num_samples=10,
        pin_memory=False
    )

    for i in range(1):
        ref_batch = next(ref_data_loader)
        batch = next(data_loader)
        split_size = ref_batch["input_ids"].shape[1] // pgm.process_group_manager.cp_world_size
        start_idx = split_size * global_rank
        end_idx = start_idx + split_size
        assert torch.equal(ref_batch["input_ids"][:, start_idx:end_idx], batch["input_ids"]), "input_ids are not equal"

# Test the infinite loop behavior
def test_infinite_loop():
    local_rank = 0
    global_rank = 0
    world_size = 1
    backend = "nccl"

    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=3))
    setup_process_group_manager(tp_size=1, cp_size=1, pp_size=1, dp_size=1)

    data_loader = MicroBatchDataLoader(
        micro_batch_size=2,
        seq_length=256,
        dataset_name="roneneldan/TinyStories",
        tokenizer_name="HuggingFaceTB/SmolLM-135M",
        grad_acc_steps=1,
        num_workers=1,
        num_proc=1,
        num_samples=2,
    )

    s = set()
    for i in range(10):
        batch = next(data_loader)
        # Convert the nested list to a tuple of tuples
        batch_tuple = tuple(tuple(x) for x in batch["input_ids"].tolist())
        if batch_tuple in s:
            # Seeing a repeated batch proves the loader cycles instead of stopping
            return
        s.add(batch_tuple)
    assert False, "expected the dataloader to loop and serve a repeated batch"

if __name__ == "__main__":
    # test_infinite_loop()
    test_cp_behavior(TP_SIZE=1, CP_SIZE=2, PP_SIZE=1, DP_SIZE=1, SEQ_LEN=8)
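
With the arguments used above (SEQ_LEN=8, CP_SIZE=2), each context parallel rank should see one contiguous half of every sequence, which is what the torch.equal assertion in test_cp_behavior checks. A tiny single-process illustration of that indexing (values are made up):

import torch

SEQ_LEN, CP_SIZE = 8, 2
full = torch.arange(SEQ_LEN).unsqueeze(0)       # one "sequence": [[0, 1, 2, 3, 4, 5, 6, 7]]
split_size = SEQ_LEN // CP_SIZE                 # 4 tokens per context parallel rank

shards = [full[:, r * split_size:(r + 1) * split_size] for r in range(CP_SIZE)]
assert torch.equal(shards[0], torch.tensor([[0, 1, 2, 3]]))   # what cp rank 0 should receive
assert torch.equal(shards[1], torch.tensor([[4, 5, 6, 7]]))   # what cp rank 1 should receive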

tests/test_tensor_parallel.py (new file, 75 lines)

@@ -0,0 +1,75 @@
"""
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
"""

from picotron.process_group_manager import setup_process_group_manager
from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from picotron.utils import set_all_seed
import torch
import os
import torch.distributed as dist
import datetime
import picotron.process_group_manager as pgm

local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device("cuda", local_rank)

dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method="env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)

set_all_seed(42)

batch_size, seq_len = 2, 4
input_size, output_size = 8, 16
bias = True  # test the linear layers with or without bias
async_all_reduce = False  # whether ColumnParallelLinear uses the async all-reduce backward

# Initialize input tensor
tensor_shape = (batch_size, seq_len, input_size)
tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)

# Initialize column/row parallel layers
column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)

# Copy weight and bias from the reference linear layer to the column/row parallel layers
column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
if bias:
    column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
    row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)

### forward pass ###
output_reference = linear_layer(tensor)
output_column_parallel = column_parallel_linear(column_parallel_tensor)
output_row_parallel = row_parallel_linear(row_parallel_tensor)

# Check forward output consistency
assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference"
torch.testing.assert_close(output_reference, output_row_parallel)  # not strictly equal because of floating-point precision

### backward pass ###
output_reference.backward(torch.ones_like(output_reference))
output_column_parallel.backward(torch.ones_like(output_column_parallel))
output_row_parallel.backward(torch.ones_like(output_row_parallel))

# Check weight gradient, bias gradient, and input gradient consistency
# Column parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)

# Row parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)

print(f"Rank {dist.get_rank()}: All tests passed")

train.py (11 lines changed)

@@ -1,7 +1,7 @@
"""Training script for LLaMA model.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
@@ -125,7 +125,7 @@ if __name__ == "__main__":
    if is_wandb_rank and config["logging"]["use_wandb"]:
        wandb.init(
            project="picotron",
            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
            name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_world_size,
                "context_parallel_size": pgm.process_group_manager.cp_world_size,
@@ -238,7 +238,8 @@ if __name__ == "__main__":

        step_duration = time.time() - step_start_time
        tokens_per_second = tokens_per_step / step_duration
        mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
        tokens_per_second_per_gpu = tokens_per_second / world_size
        mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)

        if is_wandb_rank:
            print(
@@ -247,7 +248,7 @@ if __name__ == "__main__":
                f"Loss: {loss:6.4f} | "
                f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
                f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
                f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
                f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
                f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''} | "
                f"MFU: {mfu:5.2f}% | "
                f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
@@ -259,6 +260,8 @@ if __name__ == "__main__":
                "loss": loss,
                "tokens_per_step": tokens_per_step,
                "tokens_per_second": tokens_per_step / step_duration,
                "mfu": mfu,
                "tokens_per_second_per_gpu": tokens_per_second_per_gpu,
                "memory_usage": torch.cuda.memory_reserved() / 1e9,
                "trained_tokens": trained_tokens
            })
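
The remaining hunks just factor the per-GPU throughput out into tokens_per_second_per_gpu so the same value feeds get_mfu, the console log, and the wandb metrics. A toy numeric check of that bookkeeping (the numbers are made up; get_mfu's internals are not shown in this diff):

tokens_per_step = 1_048_576     # illustrative global batch size in tokens
step_duration = 4.0             # illustrative seconds per optimizer step
world_size = 8                  # illustrative number of GPUs

tokens_per_second = tokens_per_step / step_duration           # 262144.0 tokens/s for the whole job
tokens_per_second_per_gpu = tokens_per_second / world_size    # 32768.0 tokens/s on each GPU
assert tokens_per_second_per_gpu * world_size == tokens_per_second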