Remove unused parts in Megatron-LM code and add copyright notice (#110)

This commit is contained in:
Zhuohan Li 2023-05-20 09:11:34 -06:00 committed by GitHub
parent b7955ef17b
commit 7297fa6f7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 23 additions and 297 deletions

View File

@ -1,6 +1,5 @@
import cacheflow.model_executor.parallel_utils.parallel_state
import cacheflow.model_executor.parallel_utils.tensor_parallel
import cacheflow.model_executor.parallel_utils.utils
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
@ -8,5 +7,4 @@ mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
]

View File

@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
@ -5,8 +7,6 @@
import torch
from typing import Optional
from .utils import GlobalMemoryBuffer
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
@ -44,9 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
def initialize_model_parallel(
@ -199,13 +196,6 @@ def initialize_model_parallel(
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
# Initialize global memory buffer
# This isn't really "parallel state" but there isn't another good place to
# put this. If we end up with a more generic initialization of megatron-core
# we could stick it there
_set_global_memory_buffer()
def initialize_all_reduce_launcher(
max_num_tokens: int,
hidden_size: int,
@ -495,17 +485,6 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def _set_global_memory_buffer():
"""Initialize global buffer"""
global _GLOBAL_MEMORY_BUFFER
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer():
"""Return the global GlobalMemoryBuffer object"""
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER
def get_all_reduce_launcher() -> 'GraphAllReduce':
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
return _ALL_REDUCE_LAUNCHER
@ -536,8 +515,6 @@ def destroy_model_parallel():
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
class GraphAllReduce:

View File

@ -17,15 +17,12 @@ from .mappings import (
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
__all__ = [
@ -45,11 +42,8 @@ __all__ = [
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
]

View File

@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch

View File

@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch

View File

@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
@ -8,22 +10,11 @@ import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from cacheflow.model_executor.parallel_utils.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .utils import (
split_tensor_into_1d_equal_chunks,
gather_split_1d_tensor,
)
from cacheflow.model_executor.parallel_utils.utils import safely_set_viewless_tensor_data
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
@ -171,83 +162,3 @@ def model_parallel_cuda_manual_seed(seed):
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
tensor_model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_saved_activations \
= distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
with torch.no_grad():
outputs = run_function(*args)
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_saved_activations:
ctx.input_0_shape = args[0].data.shape
safely_set_viewless_tensor_data(
args[0],
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
inputs = ctx.saved_tensors
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(
inputs[0],
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None, None) + grads
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function,
distribute_saved_activations, *args)

View File

@ -1,10 +1,23 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from typing import List, Sequence
from cacheflow.model_executor.parallel_utils.utils import divide
from cacheflow.model_executor.parallel_utils import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
@ -33,57 +46,6 @@ def split_tensor_along_last_dim(
return tensor_list
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // \
parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=parallel_state.get_tensor_model_parallel_group())
return gathered
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first

View File

@ -1,120 +0,0 @@
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch
from cacheflow.model_executor.parallel_utils import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor