Remove unused parts in Megatron-LM code and add copyright notice (#110)
parent b7955ef17b
commit 7297fa6f7c
@@ -1,6 +1,5 @@
import cacheflow.model_executor.parallel_utils.parallel_state
import cacheflow.model_executor.parallel_utils.tensor_parallel
import cacheflow.model_executor.parallel_utils.utils

# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
@@ -8,5 +7,4 @@ mpu = parallel_state
__all__ = [
    "parallel_state",
    "tensor_parallel",
    "utils",
]

@@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Model and data parallel groups."""
@@ -5,8 +7,6 @@
import torch
from typing import Optional

from .utils import GlobalMemoryBuffer

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
@@ -44,9 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None

def initialize_model_parallel(
@@ -199,13 +196,6 @@ def initialize_model_parallel(
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()


def initialize_all_reduce_launcher(
    max_num_tokens: int,
    hidden_size: int,
@@ -495,17 +485,6 @@ def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())

def _set_global_memory_buffer():
    """Initialize global buffer"""
    global _GLOBAL_MEMORY_BUFFER
    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object"""
    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
    return _GLOBAL_MEMORY_BUFFER

def get_all_reduce_launcher() -> 'GraphAllReduce':
    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
    return _ALL_REDUCE_LAUNCHER
@@ -536,8 +515,6 @@ def destroy_model_parallel():
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None


class GraphAllReduce:

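For illustration only, not part of the committed diff: a minimal sketch of how the process-group helpers touched above are typically used once torch.distributed is up. The exact signature of initialize_model_parallel is not shown in these hunks, so the tensor_model_parallel_size keyword below is an assumption.

import torch
import torch.distributed as dist

from cacheflow.model_executor.parallel_utils import parallel_state

# Assumes the script was launched with one process per GPU (e.g. torchrun).
dist.init_process_group(backend="nccl")
# Assumed keyword argument; the real parameter list is not in this diff.
parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)

# All-reduce a tensor only across the tensor-model-parallel group.
x = torch.ones(4, device="cuda")
dist.all_reduce(x, group=parallel_state.get_tensor_model_parallel_group())

# Per-group rank bookkeeping.
tp_rank = parallel_state.get_tensor_model_parallel_rank()
dp_rank = parallel_state.get_data_parallel_rank()
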
@@ -17,15 +17,12 @@ from .mappings import (
)

from .random import (
    checkpoint,
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)

from .utils import (
    split_tensor_along_last_dim,
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

__all__ = [
@@ -45,11 +42,8 @@ __all__ = [
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py
    "checkpoint",
    "get_cuda_rng_tracker",
    "model_parallel_cuda_manual_seed",
    # utils.py
    "split_tensor_along_last_dim",
    "split_tensor_into_1d_equal_chunks",
    "gather_split_1d_tensor",
]

@@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch

@@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch

@@ -1,3 +1,5 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
@@ -8,22 +10,11 @@ import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from cacheflow.model_executor.parallel_utils.parallel_state import (
    get_data_parallel_rank,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

from .utils import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

from cacheflow.model_executor.parallel_utils.utils import safely_set_viewless_tensor_data

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'

@@ -171,83 +162,3 @@ def model_parallel_cuda_manual_seed(seed):
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       two main changes:
       1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
       2) the states in the model parallel tracker are also properly
          tracked/set/reset.
    """
    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        ctx.run_function = run_function
        ctx.distribute_saved_activations \
            = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads


def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function,
                                    distribute_saved_activations, *args)

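For illustration only, not part of the committed diff: a minimal sketch of how the checkpoint wrapper shown above is called. The linear layer is just a stand-in, and distribute_saved_activations=False keeps the example off the tensor-parallel split path so no model-parallel initialization is needed.

import torch

from cacheflow.model_executor.parallel_utils.tensor_parallel import checkpoint

layer = torch.nn.Linear(16, 16).cuda()

def run(x):
    return layer(x)

x = torch.randn(2, 16, device="cuda", requires_grad=True)
# The forward pass runs under torch.no_grad(); the inputs plus CPU/CUDA RNG
# states are stashed, and the forward is replayed with the same random state
# during backward to recompute activations.
y = checkpoint(run, False, x)
y.sum().backward()
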
@@ -1,10 +1,23 @@
# Copyright 2023 The CacheFlow team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from typing import List, Sequence

from cacheflow.model_executor.parallel_utils.utils import divide
from cacheflow.model_executor.parallel_utils import parallel_state
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
@@ -33,57 +46,6 @@ def split_tensor_along_last_dim(

    return tensor_list

def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """ Break a tensor into equal 1D chunks across tensor parallel ranks.

        Returns a Tensor or View with this rank's portion of the data.

        Arguments:
            tensor: The tensor to split

        Keyword Arguments:
            new_buffer (bool): If True, returns a new Tensor.
                               If False, returns a view into the existing Tensor.
                               Default is False

    """
    partition_size = torch.numel(tensor) // \
        parallel_state.get_tensor_model_parallel_world_size()
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        data = tensor.view(-1)[start_index:end_index]
    return data


def gather_split_1d_tensor(tensor):
    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
        model parallel ranks.

        Returns a new Tensor with the gathered data.

        Arguments:
            tensor: A Tensor or view of this rank's portion of the data.
    """
    numel_gathered = torch.numel(tensor) * \
        parallel_state.get_tensor_model_parallel_world_size()
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    # TODO: This API is experimental in pytorch (as of Feb 2022) and
    # this might break in future pytorch releases. We chose this API
    # as opposed to torch.distributed.all_gather for efficiency reasons.
    # This API calls directly NCCL all-gather versus the former does
    # internal copies and can potentially cause slow down.
    torch.distributed._all_gather_base(gathered, tensor,
                                       group=parallel_state.get_tensor_model_parallel_group())
    return gathered


class VocabUtility:
    """ Split the vocabulary into `world_size` chunks and return the first

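For illustration only, not part of the committed diff: a sketch of the split/gather round trip defined above. It assumes model parallelism has already been initialized on every rank and that the tensor's element count divides evenly by the tensor-parallel world size; since every rank builds the same input, the gather reconstructs the original tensor.

import torch

from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

full = torch.arange(16, dtype=torch.float32, device="cuda")
shard = split_tensor_into_1d_equal_chunks(full)          # this rank's 1D slice
restored = gather_split_1d_tensor(shard).view_as(full)   # NCCL all-gather
assert torch.equal(restored, full)                       # round trip is lossless
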
@@ -1,120 +0,0 @@
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator

import torch

from cacheflow.model_executor.parallel_utils import parallel_state


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.
       Caller should ensure that buffers of the same name
       are not used concurrently."""

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        required_len = reduce(operator.mul, tensor_shape, 1)
        if self.buffer.get((name, dtype), None) is None or \
                self.buffer[(name, dtype)].numel() < required_len:
            self.buffer[(name, dtype)] = \
                torch.empty(required_len,
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)

        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)

def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-affect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    '''
    out = torch.empty(
        (1,),
        dtype = inp.dtype,
        device = inp.device,
        requires_grad = requires_grad,
    )
    out.data = inp.data
    return out

class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''
    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    else:
        return _kernel_make_viewless_tensor(inp, requires_grad)

def assert_viewless_tensor(tensor, extra_msg = None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        [ assert_viewless_tensor(t) for t in tensor ]
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % extra_msg
    return tensor

def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
    tensor.data = new_data_tensor
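For illustration only, not part of the committed diff: a sketch of how the removed GlobalMemoryBuffer behaved. Requests under the same (name, dtype) key reuse one backing allocation, which is why callers must not keep two live tensors under the same name. It assumes the deleted class definition above has been pasted into the session and that a CUDA device is available.

import torch

buf = GlobalMemoryBuffer()  # class from the deleted file above, not importable after this commit
a = buf.get_tensor((4, 8), torch.float16, "scratch")
b = buf.get_tensor((2, 8), torch.float16, "scratch")  # smaller request, same key
assert a.data_ptr() == b.data_ptr()                   # both view the same storage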