picotron/picotron/parallel/data_parallel/bucket.py
2024-11-04 15:29:26 +00:00

158 lines
7.3 KiB
Python

from typing import List
import torch
import torch.distributed as dist
from torch import nn
class Bucket:
def __init__(self, params: List[torch.nn.Parameter], grad_data: torch.Tensor, process_group: torch.distributed.ProcessGroup) -> None:
"""
Initializes a Bucket instance.
Args:
params (List[torch.nn.Parameter]): List of parameters assigned to this bucket.
grad_data (torch.Tensor): Tensor to store the gradients for this bucket.
process_group (torch.distributed.ProcessGroup): Process group used for synchronizing gradients.
"""
self.params = set(params) # Set of parameters in this bucket.
self.params_with_grad_ready = set() # Parameters that have their gradients ready for synchronization. launch all reduce when all parameters have gradients ready
self.grad_data = grad_data # Tensor that stores gradients for all parameters in this bucket.
self.process_group = process_group # Process group for gradient synchronization.
self.process_group_size = dist.get_world_size(group=self.process_group)
self.handle = None # Handle for the async allreduce operation.
self.reset()
def sync_gradient(self) -> None:
"""
Launch an asynchronous all-reduce operation to synchronize gradients across processes.
"""
assert self.handle is None
self.grad_data /= self.process_group_size
self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
def reset(self) -> None:
"""
Reset the bucket to its initial state. Typically called after the gradient synchronization is finished.
"""
self.handle = None
self.params_with_grad_ready.clear() # Clear the set of parameters ready for gradient synchronization.
self.grad_data.zero_() # Zero the gradient tensor.
def wait(self) -> None:
"""
wait for the allreduce operation to finish
"""
assert self.handle is not None, "You should launch an allreduce operation before waiting for it to finish"
self.handle.wait() # Block until the all-reduce operation finishes.
def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
"""
Mark a parameter as ready for gradient synchronization. Launches synchronization when all parameters in the
bucket have their gradients ready.
"""
assert param in self.params and param not in self.params_with_grad_ready
self.params_with_grad_ready.add(param)
# When all parameters in the bucket have their gradients ready, synchronize gradients
if len(self.params_with_grad_ready) == len(self.params):
self.sync_gradient()
class BucketManager:
def __init__(self, params: List[torch.nn.Parameter], process_group: torch.distributed.ProcessGroup, bucket_size: int, grad_type: torch.dtype = torch.float32) -> None:
"""
Initializes the BucketManager.
Args:
params (List[torch.nn.Parameter]): List of model parameters.
process_group (torch.distributed.ProcessGroup): Process group used for gradient synchronization.
bucket_size (int): Maximum size of each bucket in terms of gradient elements.
grad_type (torch.dtype, optional): Data type of gradients, defaults to torch.float32.
"""
self.params = list(params) # Convert parameter generator to a list.
self.buckets = [] # List of buckets.
self.process_group = process_group
self.process_group_size = dist.get_world_size(group=self.process_group)
self.params_to_bucket_location = {} # Map each parameter to its corresponding bucket/place (start, end, bucket_idx).
self.bucket_size = bucket_size
self.bucket_sizes = None # Actual sizes of each bucket.
self.grad_data_list = [] # List of tensors to store gradients, one tensor per bucket.
self.grad_type = grad_type
# Divide gradients into buckets based on the provided bucket size.
self._initialize_buckets()
def _initialize_buckets(self) -> None:
"""
Divides model parameters into buckets for gradient synchronization based on the bucket size.
"""
cur_bucket_size = 0
cur_bucket_idx = 0
# Assign parameters to buckets.
for param in self.params:
if not param.requires_grad:
continue
# If the bucket is empty, add the parameter to the bucket.
if cur_bucket_size == 0:
self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
cur_bucket_size = param.numel()
continue
# If the parameter cannot fit in the current bucket, create a new bucket
if cur_bucket_size + param.numel() > self.bucket_size:
cur_bucket_idx += 1
self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
cur_bucket_size = param.numel()
else:
self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
cur_bucket_size += param.numel()
# Gather information about the bucket sizes and the parameters in each bucket
bucket_sizes = [0] * (cur_bucket_idx + 1)
buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
for param, (_, end, idx) in self.params_to_bucket_location.items():
bucket_sizes[idx] = max(bucket_sizes[idx], end)
buckets_to_params[idx].append(param)
# Create tensors for storing gradients and initialize Bucket objects.
for i in range(len(bucket_sizes)):
self.grad_data_list.append(torch.zeros(bucket_sizes[i], dtype=self.grad_type, device='cuda'))
self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
# Create gradient views for each parameter.
for param in self.params[::-1]:
if not param.requires_grad:
continue
data_start_index, data_end_index, bucket_id = self.params_to_bucket_location[param]
# param.main_grad is used for gradient calculation
param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_id], param.shape, data_start_index, data_end_index)
def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int) -> torch.Tensor:
"""
Create a view of the given tensor with the specified shape from start to end indices.
"""
return tensor[start:end].view(shape)
def reset(self) -> None:
"""
Reset all buckets by clearing the gradients and internal states.
"""
for bucket in self.buckets:
bucket.reset()
def wait(self) -> None:
"""
Wait for all buckets to complete their gradient synchronization
"""
for bucket in self.buckets:
bucket.wait()
def mark_param_as_ready(self, param: torch.nn.Parameter) -> None:
"""
Mark a parameter's gradient as ready for synchronization.
"""
bucket_idx = self.params_to_bucket_location[param][2]
self.buckets[bucket_idx].mark_param_as_ready(param)