diff --git a/cacheflow/model_executor/parallel_utils/__init__.py b/cacheflow/model_executor/parallel_utils/__init__.py
index 11b6fa22..0ae09bf8 100644
--- a/cacheflow/model_executor/parallel_utils/__init__.py
+++ b/cacheflow/model_executor/parallel_utils/__init__.py
@@ -1,6 +1,5 @@
 import cacheflow.model_executor.parallel_utils.parallel_state
 import cacheflow.model_executor.parallel_utils.tensor_parallel
-import cacheflow.model_executor.parallel_utils.utils
 
 # Alias parallel_state as mpu, its legacy name
 mpu = parallel_state
@@ -8,5 +7,4 @@ mpu = parallel_state
 __all__ = [
     "parallel_state",
     "tensor_parallel",
-    "utils",
 ]
diff --git a/cacheflow/model_executor/parallel_utils/parallel_state.py b/cacheflow/model_executor/parallel_utils/parallel_state.py
index 8bb8c402..aacd6509 100644
--- a/cacheflow/model_executor/parallel_utils/parallel_state.py
+++ b/cacheflow/model_executor/parallel_utils/parallel_state.py
@@ -1,3 +1,5 @@
+# Copyright 2023 The CacheFlow team.
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 """Model and data parallel groups."""
@@ -5,8 +7,6 @@
 import torch
 from typing import Optional
 
-from .utils import GlobalMemoryBuffer
-
 # Intra-layer model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Inter-layer model parallel group that the current rank belongs to.
@@ -44,9 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None
 
-# Memory buffers to avoid dynamic memory allocation
-_GLOBAL_MEMORY_BUFFER = None
-
 _ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
 
 def initialize_model_parallel(
@@ -199,13 +196,6 @@
         if rank in ranks:
             _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
 
-    # Initialize global memory buffer
-    # This isn't really "parallel state" but there isn't another good place to
-    # put this. If we end up with a more generic initialization of megatron-core
-    # we could stick it there
-    _set_global_memory_buffer()
-
-
 def initialize_all_reduce_launcher(
     max_num_tokens: int,
     hidden_size: int,
@@ -495,17 +485,6 @@ def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
     return torch.distributed.get_rank(group=get_data_parallel_group())
 
-def _set_global_memory_buffer():
-    """Initialize global buffer"""
-    global _GLOBAL_MEMORY_BUFFER
-    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
-    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
-
-def get_global_memory_buffer():
-    """Return the global GlobalMemoryBuffer object"""
-    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
-    return _GLOBAL_MEMORY_BUFFER
-
 def get_all_reduce_launcher() -> 'GraphAllReduce':
     assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
     return _ALL_REDUCE_LAUNCHER
@@ -536,8 +515,6 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
-    global _GLOBAL_MEMORY_BUFFER
-    _GLOBAL_MEMORY_BUFFER = None
 
 
 class GraphAllReduce:
diff --git a/cacheflow/model_executor/parallel_utils/tensor_parallel/__init__.py b/cacheflow/model_executor/parallel_utils/tensor_parallel/__init__.py
index 246f5f6f..da0ce2a1 100644
--- a/cacheflow/model_executor/parallel_utils/tensor_parallel/__init__.py
+++ b/cacheflow/model_executor/parallel_utils/tensor_parallel/__init__.py
@@ -17,15 +17,12 @@ from .mappings import (
 )
 
 from .random import (
-    checkpoint,
     get_cuda_rng_tracker,
     model_parallel_cuda_manual_seed,
 )
 
 from .utils import (
     split_tensor_along_last_dim,
-    split_tensor_into_1d_equal_chunks,
-    gather_split_1d_tensor,
 )
 
 __all__ = [
@@ -45,11 +42,8 @@ __all__ = [
     "scatter_to_tensor_model_parallel_region",
     "scatter_to_sequence_parallel_region",
     # random.py
-    "checkpoint",
     "get_cuda_rng_tracker",
     "model_parallel_cuda_manual_seed",
     # utils.py
     "split_tensor_along_last_dim",
-    "split_tensor_into_1d_equal_chunks",
-    "gather_split_1d_tensor",
 ]
diff --git a/cacheflow/model_executor/parallel_utils/tensor_parallel/layers.py b/cacheflow/model_executor/parallel_utils/tensor_parallel/layers.py
index fef8ff74..2ec8312f 100644
--- a/cacheflow/model_executor/parallel_utils/tensor_parallel/layers.py
+++ b/cacheflow/model_executor/parallel_utils/tensor_parallel/layers.py
@@ -1,3 +1,5 @@
+# Copyright 2023 The CacheFlow team.
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 # Parts of the code here are adapted from PyTorch
diff --git a/cacheflow/model_executor/parallel_utils/tensor_parallel/mappings.py b/cacheflow/model_executor/parallel_utils/tensor_parallel/mappings.py
index 4352514b..fe7de641 100644
--- a/cacheflow/model_executor/parallel_utils/tensor_parallel/mappings.py
+++ b/cacheflow/model_executor/parallel_utils/tensor_parallel/mappings.py
@@ -1,3 +1,5 @@
+# Copyright 2023 The CacheFlow team.
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 import torch
diff --git a/cacheflow/model_executor/parallel_utils/tensor_parallel/random.py b/cacheflow/model_executor/parallel_utils/tensor_parallel/random.py
index 1374f13b..ab57f946 100644
--- a/cacheflow/model_executor/parallel_utils/tensor_parallel/random.py
+++ b/cacheflow/model_executor/parallel_utils/tensor_parallel/random.py
@@ -1,3 +1,5 @@
+# Copyright 2023 The CacheFlow team.
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 # Parts of the code here are adapted from PyTorch
@@ -8,22 +10,11 @@ import contextlib
 import torch
 from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
-from torch.utils.checkpoint import detach_variable
 
 from cacheflow.model_executor.parallel_utils.parallel_state import (
-    get_data_parallel_rank,
-    get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
 )
 
-from .utils import (
-    split_tensor_into_1d_equal_chunks,
-    gather_split_1d_tensor,
-)
-
-from cacheflow.model_executor.parallel_utils.utils import safely_set_viewless_tensor_data
-
 
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@@ -171,83 +162,3 @@ def model_parallel_cuda_manual_seed(seed):
     # and model parallel state.
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                 tensor_model_parallel_seed)
-
-
-class CheckpointFunction(torch.autograd.Function):
-    """This function is adapted from torch.utils.checkpoint with
-       two main changes:
-           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
-           2) the states in the model parallel tracker are also properly
-              tracked/set/reset.
-    """
-    @staticmethod
-    def forward(ctx, run_function, distribute_saved_activations, *args):
-        ctx.run_function = run_function
-        ctx.distribute_saved_activations \
-            = distribute_saved_activations
-
-        # Copy the rng states.
-        ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
-        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
-
-        with torch.no_grad():
-            outputs = run_function(*args)
-
-        # Divide hidden states across model parallel group and only keep
-        # the chunk corresponding to the current rank.
-        if distribute_saved_activations:
-            ctx.input_0_shape = args[0].data.shape
-            safely_set_viewless_tensor_data(
-                args[0],
-                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
-
-        # Store everything.
-        ctx.save_for_backward(*args)
-
-        return outputs
-
-    @staticmethod
-    def backward(ctx, *args):
-        if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError("Checkpointing is not compatible with .grad(), "
-                               "please use .backward() if possible")
-        inputs = ctx.saved_tensors
-        if ctx.distribute_saved_activations:
-            safely_set_viewless_tensor_data(
-                inputs[0],
-                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
-
-        # Store the current states.
-        bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = torch.cuda.get_rng_state()
-        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
-
-        # Set the states to what it used to be before the forward pass.
-        torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
-        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
-
-        # Compute the forward pass.
-        detached_inputs = detach_variable(inputs)
-        with torch.enable_grad():
-            outputs = ctx.run_function(*detached_inputs)
-
-        # Set the states back to what it was at the start of this function.
-        torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state)
-        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
-
-        if isinstance(outputs, torch.Tensor):
-            outputs = (outputs,)
-        torch.autograd.backward(outputs, args)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
-                      for inp in detached_inputs)
-        return (None, None) + grads
-
-
-def checkpoint(function, distribute_saved_activations, *args):
-    """Checkpoint a model or part of the model.
-    This has been directly copied from torch.utils.checkpoint."""
-    return CheckpointFunction.apply(function,
-                                    distribute_saved_activations, *args)
diff --git a/cacheflow/model_executor/parallel_utils/tensor_parallel/utils.py b/cacheflow/model_executor/parallel_utils/tensor_parallel/utils.py
index e8e6c81b..892a5453 100644
--- a/cacheflow/model_executor/parallel_utils/tensor_parallel/utils.py
+++ b/cacheflow/model_executor/parallel_utils/tensor_parallel/utils.py
@@ -1,10 +1,23 @@
+# Copyright 2023 The CacheFlow team.
+# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 import torch
 from typing import List, Sequence
 
-from cacheflow.model_executor.parallel_utils.utils import divide
-from cacheflow.model_executor.parallel_utils import parallel_state
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(
+        numerator, denominator
+    )
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
 
 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
@@ -33,57 +46,6 @@ def split_tensor_along_last_dim(
     return tensor_list
 
 
-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
-    """ Break a tensor into equal 1D chunks across tensor parallel ranks.
-
-        Returns a Tensor or View with this rank's portion of the data.
-
-        Arguments:
-            tensor: The tensor to split
-
-        Keyword Arguments:
-            new_buffer (bool): If True, returns a new Tensor.
-                               If False, returns a view into the existing Tensor.
-                               Default is False
-
-    """
-    partition_size = torch.numel(tensor) // \
-        parallel_state.get_tensor_model_parallel_world_size()
-    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
-    end_index = start_index + partition_size
-    if new_buffer:
-        data = torch.empty(partition_size, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-        data.copy_(tensor.view(-1)[start_index:end_index])
-    else:
-        data = tensor.view(-1)[start_index:end_index]
-    return data
-
-
-def gather_split_1d_tensor(tensor):
-    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
-        model parallel ranks.
-
-        Returns a new Tensor with the gathered data.
-
-        Arguments:
-            tensor: A Tensor or view of this rank's portion of the data.
-    """
-    numel_gathered = torch.numel(tensor) * \
-        parallel_state.get_tensor_model_parallel_world_size()
-    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
-                           device=torch.cuda.current_device(),
-                           requires_grad=False)
-    # TODO: This API is experimental in pytorch (as of Feb 2022) and
-    # this might break in future pytorch releases. We chose this API
-    # as opposed to torch.distributed.all_gather for efficiency reasons.
-    # This API calls directly NCCL all-gather versus the former does
-    # internal copies and can potentially cause slow down.
-    torch.distributed._all_gather_base(gathered, tensor,
-        group=parallel_state.get_tensor_model_parallel_group())
-    return gathered
-
 
 class VocabUtility:
     """ Split the vocabulary into `world_size` chunks and return the first
diff --git a/cacheflow/model_executor/parallel_utils/utils.py b/cacheflow/model_executor/parallel_utils/utils.py
deleted file mode 100644
index c81e0afa..00000000
--- a/cacheflow/model_executor/parallel_utils/utils.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""Utility functions used throughout Megatron core"""
-from functools import reduce
-import operator
-
-import torch
-
-from cacheflow.model_executor.parallel_utils import parallel_state
-
-
-def ensure_divisibility(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, "{} is not divisible by {}".format(
-        numerator, denominator
-    )
-
-
-def divide(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator and return
-    the division value."""
-    ensure_divisibility(numerator, denominator)
-    return numerator // denominator
-
-
-class GlobalMemoryBuffer:
-    """Global buffer to avoid dynamic memory allocations.
-    Caller should ensure that buffers of the same name
-    are not used concurrently."""
-
-    def __init__(self):
-        self.buffer = {}
-
-    def get_tensor(self, tensor_shape, dtype, name):
-        required_len = reduce(operator.mul, tensor_shape, 1)
-        if self.buffer.get((name, dtype), None) is None or \
-                self.buffer[(name, dtype)].numel() < required_len:
-            self.buffer[(name, dtype)] = \
-                torch.empty(required_len,
-                            dtype=dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
-
-        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
-
-def _kernel_make_viewless_tensor(inp, requires_grad):
-    '''Make a viewless tensor.
-
-    View tensors have the undesirable side-affect of retaining a reference
-    to the originally-viewed tensor, even after manually setting the '.data'
-    field. This method creates a new tensor that links to the old tensor's
-    data, without linking the viewed tensor, referenced via the '._base'
-    field.
-    '''
-    out = torch.empty(
-        (1,),
-        dtype = inp.dtype,
-        device = inp.device,
-        requires_grad = requires_grad,
-    )
-    out.data = inp.data
-    return out
-
-class MakeViewlessTensor(torch.autograd.Function):
-    '''
-    Autograd function to make a viewless tensor.
-
-    This function should be used in cases where the computation graph needs
-    to be propagated, but we only want a viewless tensor (e.g.,
-    ParallelTransformer's hidden_states). Call this function by passing
-    'keep_graph = True' to 'make_viewless_tensor()'.
-    '''
-    @staticmethod
-    def forward(ctx, inp, requires_grad):
-        return _kernel_make_viewless_tensor(inp, requires_grad)
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output, None
-
-def make_viewless_tensor(inp, requires_grad, keep_graph):
-    '''
-    Entry-point for creating viewless tensors.
-
-    This method should be used, rather than calling 'MakeViewlessTensor'
-    or '_kernel_make_viewless_tensor' directly. This method acts as a
-    switch for determining if an autograd function or a regular method
-    should be used to create the tensor.
-    '''
-
-    # return tensor as-is, if not a 'view'
-    if inp._base is None:
-        return inp
-
-    # create viewless tensor
-    if keep_graph:
-        return MakeViewlessTensor.apply(inp, requires_grad)
-    else:
-        return _kernel_make_viewless_tensor(inp, requires_grad)
-
-def assert_viewless_tensor(tensor, extra_msg = None):
-    '''Assert that a tensor is not a view (i.e., its '._base' field is
-    not set).'''
-    if isinstance(tensor, list):
-        [ assert_viewless_tensor(t) for t in tensor ]
-        return tensor
-    if not isinstance(tensor, torch.Tensor):
-        return tensor
-    assert tensor._base is None, (
-        "Ensure tensor._base is None before setting tensor.data or storing "
-        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
-        "likely accumulate over iterations). %s"
-    ) % extra_msg
-    return tensor
-
-def safely_set_viewless_tensor_data(tensor, new_data_tensor):
-    '''Safely set tensor's '.data' field.
-
-    Check first that the tensor is viewless (i.e., '._base' not set). If not,
-    raise an exception.
-    '''
-    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
-    tensor.data = new_data_tensor