# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence, Tuple

import torch

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
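

# Illustrative usage sketch (comments only, not part of the upstream module):
# `divide` performs exact integer division, while `ensure_divisibility`
# asserts when the operands do not divide evenly.
#
#   divide(10, 2)   # -> 5
#   divide(10, 3)   # raises AssertionError: "10 is not divisible by 3"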


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor into.
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A sequence of Tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
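

# Illustrative usage sketch (comments only, not part of the upstream module):
# splitting a [4, 8] tensor into 4 partitions yields 4 views of shape [4, 2].
# Pass contiguous_split_chunks=True when downstream ops need contiguous memory.
#
#   x = torch.randn(4, 8)
#   chunks = split_tensor_along_last_dim(x, num_partitions=4)
#   assert len(chunks) == 4 and chunks[0].shape == (4, 2)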


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Try to evenly distribute layers across partitions.
    If the number of layers is not divisible by the number of partitions,
    the last partition will have the remaining layers.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION
    if partition_list_str is not None:
        try:
            partitions = [
                int(layer) for layer in partition_list_str.split(",")
            ]
        except ValueError as err:
            raise ValueError("Invalid partition string: {}".format(
                partition_list_str)) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
        start_layer = sum(partitions[:pp_rank])
        end_layer = start_layer + partitions[pp_rank]
    else:
        layers_per_partition = num_hidden_layers // pp_size
        start_layer = pp_rank * layers_per_partition
        end_layer = start_layer + layers_per_partition

        if pp_rank == pp_size - 1:
            end_layer = num_hidden_layers

    return (start_layer, end_layer)
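

# Illustrative usage sketch (comments only, not part of the upstream module):
# with 7 layers over 2 pipeline stages and no VLLM_PP_LAYER_PARTITION
# override, each rank gets 7 // 2 = 3 layers and the last rank absorbs the
# remainder:
#
#   get_pp_indices(7, pp_rank=0, pp_size=2)  # -> (0, 3)
#   get_pp_indices(7, pp_rank=1, pp_size=2)  # -> (3, 7)
#
# With VLLM_PP_LAYER_PARTITION="3,4", rank 0 covers layers [0, 3) and
# rank 1 covers layers [3, 7).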