Merge pull request #9 from huggingface/async_tp

Async tp
Haojun Zhao 2024-12-14 07:24:35 -05:00 committed by GitHub
commit 55efb321f9
5 changed files with 132 additions and 12 deletions

View File

@@ -12,10 +12,9 @@ from typing import Callable, Optional
import picotron.process_group_manager as pgm
from functools import partial
import torch.nn.init as init
from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
from picotron.tensor_parallel.tp_communications import gather_from_model_parallel_region, linear_with_all_reduce, linear_with_async_all_reduce, reduce_from_model_parallel_region
def apply_tensor_parallel(model, init_method):
def _replace_module(_module, _linear_proj_name, _style, _init_method, args={}):
assert _style in ["column", "row", 'vocab']
linear_layer = getattr(_module, _linear_proj_name)
@@ -135,6 +134,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
gather_output: bool = False,
async_all_reduce: bool = True,
) -> None:
super(ColumnParallelLinear, self).__init__()
@@ -143,7 +143,7 @@ class ColumnParallelLinear(torch.nn.Module):
assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
self.gather_output = gather_output
self.async_all_reduce = async_all_reduce
# Allocate space for the weight and bias
# Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
@@ -166,8 +166,10 @@ class ColumnParallelLinear(torch.nn.Module):
)
def forward(self, input_: torch.Tensor) -> torch.Tensor:
input_parallel = copy_to_model_parallel_region(input_)
output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
if self.async_all_reduce:
output = linear_with_async_all_reduce(input_, self.weight, self.bias)
else:
output = linear_with_all_reduce(input_, self.weight, self.bias)
if self.gather_output:
output = gather_from_model_parallel_region(output)
return output
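A minimal usage sketch of the new flag (toy sizes, illustration only; assumes torch.distributed and picotron's process group manager are already initialized, as in the test script added below):

from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear

# Toy dimensions; async_all_reduce=True selects the overlapped backward path.
layer = ColumnParallelLinear(
    8,                       # in_features
    16,                      # out_features
    bias=True,
    gather_output=True,
    async_all_reduce=True,   # overlap the backward all-reduce with grad_weight/grad_bias
)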

View File

@@ -2,10 +2,11 @@
Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.tp_utils import split_tensor_along_last_dim
from picotron.tensor_parallel.tp_utils import merge_first_two_dims, split_tensor_along_last_dim
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm
import torch.nn.functional as F
def _reduce(input_):
"""All-reduce the input tensor across model parallel(Tensor Parallel) group."""
@@ -93,4 +94,41 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
return _split(grad_output)
def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_)
def linear_with_all_reduce(input_, weight, bias):
input_parallel = copy_to_model_parallel_region(input_)
output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
return output
class _LinearWithAsyncAllReduce(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, weight, bias):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
output = input_ @ weight.t() + bias if bias is not None else input_ @ weight.t()
return output
@staticmethod
def backward(ctx, grad_output):
"""
The key difference from "linear_with_all_reduce" is that the all-reduce of the input gradient
is launched before the weight and bias gradients are computed, instead of after, so the
communication overlaps with that computation. This is only applicable to ColumnParallelLinear.
Before: grad_output -> grad_input, grad_weight, grad_bias -> all-reduce grad_input
Now:    grad_output -> grad_input -> async all-reduce grad_input -> grad_weight, grad_bias -> wait
"""
input_, weight = ctx.saved_tensors
grad_input = grad_output @ weight # (b, s, out_size) @ (out_size, input_size) = (b, s, input_size)
# Launch the all-reduce of the input gradient asynchronously; it overlaps with the grad_weight/grad_bias computation below.
input_gradient_all_reduce_handle = dist.all_reduce(grad_input, group=pgm.process_group_manager.tp_group, async_op=True)
# merge first two dims to allow matrix multiplication
grad_output, input_ = merge_first_two_dims(grad_output, input_) # grad_output, input_: (b, s, out_size), (b, s, input_size) -> (b*s, out_size), (b*s, input_size)
grad_weight = grad_output.t() @ input_ # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
grad_bias = grad_output.sum(0) if ctx.use_bias else None
input_gradient_all_reduce_handle.wait()
return grad_input, grad_weight, grad_bias
def linear_with_async_all_reduce(input_, weight, bias):
return _LinearWithAsyncAllReduce.apply(input_, weight, bias)
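As a sanity check, the two paths can be compared directly on each rank; a minimal sketch, assuming torch.distributed (NCCL) and picotron's process group manager are already initialized and the current CUDA device is set to the local rank:

import torch
from picotron.tensor_parallel.tp_communications import (
    linear_with_all_reduce,
    linear_with_async_all_reduce,
)

x = torch.randn(2, 4, 8, device="cuda", requires_grad=True)   # (batch, seq, input_size)
x_async = x.detach().clone().requires_grad_(True)
w = torch.randn(16, 8, device="cuda", requires_grad=True)     # (out_size, input_size)
w_async = w.detach().clone().requires_grad_(True)

# Forward results are identical; only the backward schedule differs.
y = linear_with_all_reduce(x, w, None)
y_async = linear_with_async_all_reduce(x_async, w_async, None)
torch.testing.assert_close(y, y_async)

y.sum().backward()
y_async.sum().backward()
torch.testing.assert_close(x.grad, x_async.grad)   # grad_input, all-reduced in both paths
torch.testing.assert_close(w.grad, w_async.grad)   # local grad_weight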

View File

@@ -5,6 +5,10 @@ Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
from typing import Tuple
import torch
def merge_first_two_dims(grad_output: torch.Tensor, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Merge the first two dimensions of tensors."""
return grad_output.contiguous().view(-1, *grad_output.shape[2:]), input_.contiguous().view(-1, *input_.shape[2:])
def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
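A quick shape check of the new merge_first_two_dims helper (toy shapes matching the comments in the backward above):

import torch
from picotron.tensor_parallel.tp_utils import merge_first_two_dims

grad_output = torch.randn(2, 4, 16)   # (batch, seq, out_size)
input_ = torch.randn(2, 4, 8)         # (batch, seq, input_size)
g, x = merge_first_two_dims(grad_output, input_)
assert g.shape == (8, 16) and x.shape == (8, 8)   # (batch*seq, out_size), (batch*seq, input_size)
grad_weight = g.t() @ x                           # (out_size, input_size) == (16, 8)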

View File

@@ -0,0 +1,75 @@
"""
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
"""
from picotron.process_group_manager import setup_process_group_manager
from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from picotron.utils import set_all_seed
import torch
import os
import torch.distributed as dist
import datetime
import picotron.process_group_manager as pgm
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device("cuda", local_rank)
dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)
set_all_seed(42)
batch_size, seq_len = 2, 4
input_size, output_size = 8, 16
bias = True # whether the linear layers use a bias term
async_all_reduce = False # whether ColumnParallelLinear uses the async all-reduce path
# Initialize input tensor
tensor_shape = (batch_size, seq_len, input_size)
tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)
# Initialize column/row parallel layers
column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)
# copy weight and bias from reference linear layer to column/row parallel layers
column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
if bias:
column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)
### forward pass ###
output_reference = linear_layer(tensor)
output_column_parallel = column_parallel_linear(column_parallel_tensor)
output_row_parallel = row_parallel_linear(row_parallel_tensor)
# check forward output consistency
assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference"
torch.testing.assert_close(output_reference, output_row_parallel) # not bitwise equal: the all-reduce changes the floating-point summation order
### backward pass ###
output_reference.backward(torch.ones_like(output_reference))
output_column_parallel.backward(torch.ones_like(output_column_parallel))
output_row_parallel.backward(torch.ones_like(output_row_parallel))
# check backward weight gradient, bias gradient, and input gradient consistency
# column parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
if bias:
torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)
# row parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
if bias:
torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)
print(f"Rank {dist.get_rank()}: All tests passed")

View File

@@ -1,7 +1,7 @@
"""Training script for LLaMA model.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
@@ -147,7 +147,7 @@ if __name__ == "__main__":
if is_wandb_rank and USE_WANDB:
wandb.init(
project="picotron",
name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
config={
"tensor_parallel_size": pgm.process_group_manager.tp_size,
"context_parallel_size": pgm.process_group_manager.cp_size,
@@ -243,7 +243,8 @@ if __name__ == "__main__":
step_duration = time.time() - step_start_time
tokens_per_second = tokens_per_step / step_duration
mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
tokens_per_second_per_gpu = tokens_per_second / world_size
mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)
if is_wandb_rank:
print(
@@ -252,7 +253,7 @@ if __name__ == "__main__":
f"Loss: {loss:6.4f} | "
f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''} | "
f"MFU: {mfu:5.2f}% | "
f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
@@ -261,7 +262,7 @@ if __name__ == "__main__":
if USE_WANDB:
wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\
"memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
"mfu": mfu, "tokens_per_second_per_gpu": tokens_per_second_per_gpu, "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
if step % CHECKPOINT_FREQ == 0:
save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")