commit 55efb321f9

picotron/tensor_parallel/tensor_parallel.py

@@ -12,10 +12,9 @@ from typing import Callable, Optional
import picotron.process_group_manager as pgm
from functools import partial
import torch.nn.init as init
from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
from picotron.tensor_parallel.tp_communications import gather_from_model_parallel_region, linear_with_all_reduce, linear_with_async_all_reduce, reduce_from_model_parallel_region

def apply_tensor_parallel(model, init_method):

def _replace_module(_module, _linear_proj_name, _style, _init_method, args={}):
    assert _style in ["column", "row", 'vocab']
    linear_layer = getattr(_module, _linear_proj_name)

@@ -135,6 +134,7 @@ class ColumnParallelLinear(torch.nn.Module):
        bias: bool = False,
        init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
        gather_output: bool = False,
        async_all_reduce: bool = True,
    ) -> None:
        super(ColumnParallelLinear, self).__init__()

@@ -143,7 +143,7 @@ class ColumnParallelLinear(torch.nn.Module):
        assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
        self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
        self.gather_output = gather_output

        self.async_all_reduce = async_all_reduce
        # Allocate space for the weight and bias
        # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
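
A quick aside on the comment above: torch.nn.functional.linear stores the weight as (out_features, in_features) and computes X @ W^T, which is why the partitioned weight is allocated with the output dimension first. A tiny self-contained check, with arbitrary shapes not taken from the commit:

import torch
import torch.nn.functional as F

x, w = torch.randn(3, 8), torch.randn(16, 8)   # weight laid out as (out_features, in_features)
assert torch.allclose(F.linear(x, w), x @ w.t())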

@@ -166,8 +166,10 @@ class ColumnParallelLinear(torch.nn.Module):
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        input_parallel = copy_to_model_parallel_region(input_)
        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
        if self.async_all_reduce:
            output = linear_with_async_all_reduce(input_, self.weight, self.bias)
        else:
            output = linear_with_all_reduce(input_, self.weight, self.bias)
        if self.gather_output:
            output = gather_from_model_parallel_region(output)
        return output
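
For orientation, the column-parallel math behind this forward can be checked on a single process: the weight and bias are sharded along out_features, each shard produces a partial output Y_i, and concatenating the shards along the last dimension reproduces the dense result (the gather_output=True path). The shapes and the two-way split below are illustrative assumptions, not picotron code:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=2, seq=4, in_features=8, out_features=16, tp_world_size=2
x = torch.randn(2, 4, 8)
w = torch.randn(16, 8)                               # dense weight, (out_features, in_features)
b = torch.randn(16)

w0, w1 = w.chunk(2, dim=0)                           # each "rank" keeps a shard of the output dim
b0, b1 = b.chunk(2, dim=0)

y_dense = F.linear(x, w, b)                          # (2, 4, 16)
y0, y1 = F.linear(x, w0, b0), F.linear(x, w1, b1)    # per-rank partial outputs Y_i

torch.testing.assert_close(y_dense, torch.cat([y0, y1], dim=-1))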

picotron/tensor_parallel/tp_communications.py

@@ -2,10 +2,11 @@
Inspired by FairScale/Megatron's tensor parallelism implementation
Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
"""
from picotron.tensor_parallel.tp_utils import split_tensor_along_last_dim
from picotron.tensor_parallel.tp_utils import merge_first_two_dims, split_tensor_along_last_dim
import torch.distributed as dist
import torch
import picotron.process_group_manager as pgm
import torch.nn.functional as F

def _reduce(input_):
    """All-reduce the input tensor across the model parallel (tensor parallel) group."""
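
The body of _reduce is cut off by this hunk. As a hedged sketch (not the actual picotron implementation), an all-reduce helper of this kind typically looks like the following, where tp_group stands in for pgm.process_group_manager.tp_group:

import torch.distributed as dist

def _reduce_sketch(input_, tp_group):
    """Sum input_ across the tensor parallel group in place; no-op for a single rank."""
    if dist.get_world_size(group=tp_group) == 1:
        return input_
    dist.all_reduce(input_, op=dist.ReduceOp.SUM, group=tp_group)
    return input_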

@@ -93,4 +94,41 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
        return _split(grad_output)

def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _GatherFromModelParallelRegion.apply(input_)

def linear_with_all_reduce(input_, weight, bias):
    input_parallel = copy_to_model_parallel_region(input_)
    output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
    return output

class _LinearWithAsyncAllReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, bias):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        output = input_ @ weight.t() + bias if bias is not None else input_ @ weight.t()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        The key difference from "linear_with_all_reduce" is that the all-reduce of the input gradient
        is launched before the weight and bias gradients are computed, instead of after, so the
        communication can overlap with that computation. This is only applicable to ColumnParallelLinear.

        Before: grad_output -> grad_input, grad_weight, grad_bias -> grad_input all-reduce
        Now:    grad_output -> grad_input -> grad_input all-reduce (async) -> grad_weight, grad_bias -> wait
        """
        input_, weight = ctx.saved_tensors
        grad_input = grad_output @ weight # (b, s, out_size) @ (out_size, input_size) = (b, s, input_size)
        # Launch the input-gradient all-reduce asynchronously so it overlaps with the weight/bias gradient computation
        input_gradient_all_reduce_handle = dist.all_reduce(grad_input, group=pgm.process_group_manager.tp_group, async_op=True)
        # Merge the first two dims to allow matrix multiplication
        grad_output, input_ = merge_first_two_dims(grad_output, input_) # (b, s, out_size), (b, s, input_size) -> (b*s, out_size), (b*s, input_size)
        grad_weight = grad_output.t() @ input_ # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
        grad_bias = grad_output.sum(0) if ctx.use_bias else None
        input_gradient_all_reduce_handle.wait()
        return grad_input, grad_weight, grad_bias

def linear_with_async_all_reduce(input_, weight, bias):
    return _LinearWithAsyncAllReduce.apply(input_, weight, bias)
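
The gradient algebra used in backward above can be sanity-checked on a single process against autograd. This sketch drops the distributed all-reduce and is not part of the commit; shapes are illustrative:

import torch

b, s, in_size, out_size = 2, 3, 4, 5
x = torch.randn(b, s, in_size, requires_grad=True)
w = torch.randn(out_size, in_size, requires_grad=True)
bias = torch.randn(out_size, requires_grad=True)

y = x @ w.t() + bias
grad_output = torch.ones_like(y)
y.backward(grad_output)

# Same algebra as _LinearWithAsyncAllReduce.backward, minus the all-reduce
with torch.no_grad():
    grad_input = grad_output @ w                               # (b, s, in_size)
    go, xm = grad_output.reshape(-1, out_size), x.reshape(-1, in_size)
    grad_weight = go.t() @ xm                                  # (out_size, in_size)
    grad_bias = go.sum(0)                                      # (out_size,)

torch.testing.assert_close(grad_input, x.grad)
torch.testing.assert_close(grad_weight, w.grad)
torch.testing.assert_close(grad_bias, bias.grad)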

picotron/tensor_parallel/tp_utils.py

@@ -5,6 +5,10 @@ Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
from typing import Tuple
import torch

def merge_first_two_dims(grad_output: torch.Tensor, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge the first two dimensions of tensors."""
    return grad_output.contiguous().view(-1, *grad_output.shape[2:]), input_.contiguous().view(-1, *input_.shape[2:])

def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
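
A minimal usage example for merge_first_two_dims, assuming picotron is importable; the shapes are illustrative:

import torch
from picotron.tensor_parallel.tp_utils import merge_first_two_dims

grad_output = torch.randn(2, 4, 16)   # (b, s, out_size)
input_ = torch.randn(2, 4, 8)         # (b, s, in_size)
go, x = merge_first_two_dims(grad_output, input_)
print(go.shape, x.shape)              # torch.Size([8, 16]) torch.Size([8, 8])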

tests/test_tensor_parallel.py  (new file, 75 lines)

@@ -0,0 +1,75 @@
"""
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
"""

from picotron.process_group_manager import setup_process_group_manager
from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from picotron.utils import set_all_seed
import torch
import os
import torch.distributed as dist
import datetime
import picotron.process_group_manager as pgm

local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device("cuda", local_rank)

dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method=f"env://", timeout=datetime.timedelta(minutes=3))
setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)

set_all_seed(42)

batch_size, seq_len = 2, 4
input_size, output_size = 8, 16
bias = True # linear layer with/without bias
async_all_reduce = False # whether the column parallel linear layer uses the async all-reduce path

# Initialize input tensor
tensor_shape = (batch_size, seq_len, input_size)
tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)

# Initialize column/row parallel layers
column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)

# Copy weight and bias from the reference linear layer to the column/row parallel layers
column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
if bias:
    column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
    row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)

### forward pass ###
output_reference = linear_layer(tensor)
output_column_parallel = column_parallel_linear(column_parallel_tensor)
output_row_parallel = row_parallel_linear(row_parallel_tensor)

# Check forward output consistency
assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference"
torch.testing.assert_close(output_reference, output_row_parallel) # not strictly equal because of floating point precision in the reduction

### backward pass ###
output_reference.backward(torch.ones_like(output_reference))
output_column_parallel.backward(torch.ones_like(output_column_parallel))
output_row_parallel.backward(torch.ones_like(output_row_parallel))

# Check weight gradient, bias gradient, and input gradient consistency
# Column parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)

# Row parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)

print(f"Rank {dist.get_rank()}: All tests passed")
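
To make the row-parallel checks above concrete: each rank holds a slice of the input and the matching slice of the weight along in_features, the dense output is the sum of the partial products (the all-reduce), and the bias is added once. Because the summation order differs from the dense matmul, the test uses assert_close rather than exact equality. A single-process sketch with illustrative shapes, not part of the test file:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)

x0, x1 = x.chunk(2, dim=-1)           # input slices: (2, 4, 4) each
w0, w1 = w.chunk(2, dim=1)            # weight slices along in_features: (16, 4) each

y_dense = F.linear(x, w, b)
y_partial_sum = F.linear(x0, w0) + F.linear(x1, w1) + b   # the all-reduce, done by hand

torch.testing.assert_close(y_dense, y_partial_sum)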

train.py  (11 changed lines)

@@ -1,7 +1,7 @@
"""Training script for LLaMA model.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2

@@ -147,7 +147,7 @@ if __name__ == "__main__":
    if is_wandb_rank and USE_WANDB:
        wandb.init(
            project="picotron",
            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
            name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "context_parallel_size": pgm.process_group_manager.cp_size,

@@ -243,7 +243,8 @@ if __name__ == "__main__":

        step_duration = time.time() - step_start_time
        tokens_per_second = tokens_per_step / step_duration
        mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
        tokens_per_second_per_gpu = tokens_per_second / world_size
        mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)

        if is_wandb_rank:
            print(

@@ -252,7 +253,7 @@ if __name__ == "__main__":
                f"Loss: {loss:6.4f} | "
                f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
                f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
                f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
                f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
                f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''} | "
                f"MFU: {mfu:5.2f}% | "
                f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",

@@ -261,7 +262,7 @@ if __name__ == "__main__":

        if USE_WANDB:
            wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\
                "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
                "mfu": mfu, "tokens_per_second_per_gpu": tokens_per_second_per_gpu, "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})

        if step % CHECKPOINT_FREQ == 0:
            save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}")
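
For context on the get_mfu value logged above: model FLOPs utilization is conventionally the achieved FLOPs per second divided by the GPU's peak FLOPs, with roughly 6 * num_params FLOPs per trained token (forward plus backward). The function below is an illustrative stand-in, not picotron's get_mfu, and the peak-FLOPs default (A100 bf16) is an assumption:

def estimate_mfu(tokens_per_second_per_gpu: float, num_params: int,
                 peak_flops_per_gpu: float = 312e12) -> float:
    """Rough MFU estimate: ~6 FLOPs per parameter per token, versus peak GPU throughput."""
    achieved_flops = 6 * num_params * tokens_per_second_per_gpu
    return 100 * achieved_flops / peak_flops_per_gpu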