Add memory analyzer & automatically configure KV cache size (#6)
commit e9d3f2ff77
parent 1a7eb7da61
README.md
@@ -3,7 +3,7 @@
 ## Installation
 
 ```bash
-pip install cmake torch transformers
+pip install psutil numpy torch transformers
 pip install flash-attn # This may take up to 10 mins.
 pip install -e .
 ```
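The dependency swap above is what the rest of this commit leans on: `psutil` backs the new `get_cpu_memory` helper and `numpy` the relocated `set_seed` (both in `cacheflow/models/utils.py` below). A quick post-install smoke test, as a sketch with no CacheFlow imports assumed:

```python
# Illustrative check that the newly listed dependencies import and report sanely.
import numpy as np
import psutil
import torch

print('NumPy:', np.__version__)
print('Total host RAM (bytes):', psutil.virtual_memory().total)
print('CUDA available:', torch.cuda.is_available())
```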
cacheflow/master/scheduler.py
@@ -9,8 +9,6 @@ from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.sequence import SequenceStatus
 
-_MAX_NUM_BATCHED_TOKENS = 2048
-
 
 class Scheduler:
 
@@ -21,12 +19,14 @@ class Scheduler:
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
+        max_num_batched_tokens: int,
    ) -> None:
         self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
+        self.max_num_batched_tokens = max_num_batched_tokens
 
         # Create the block space manager.
         self.block_manager = BlockSpaceManager(
@@ -164,7 +164,7 @@ class Scheduler:
             num_prompt_tokens = seq_group.seqs[0].get_len()
             if self.block_manager.can_allocate(seq_group):
                 if (num_batched_tokens + num_prompt_tokens
-                    <= _MAX_NUM_BATCHED_TOKENS):
+                    <= self.max_num_batched_tokens):
                     self._allocate(seq_group)
                     num_batched_tokens += num_prompt_tokens
                     continue
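The scheduler change replaces the module-level `_MAX_NUM_BATCHED_TOKENS` constant with a `max_num_batched_tokens` constructor argument, so the token budget can be configured per server (see `--max-batch-size` in `server.py` below). A minimal sketch of the admission rule this budget drives; the loop and prompt lengths here are hypothetical, not the scheduler's actual control flow:

```python
# Hypothetical illustration of the batched-token budget check shown in the hunk above.
max_num_batched_tokens = 2048            # now passed into Scheduler(...) instead of a constant
num_batched_tokens = 0
waiting_prompt_lens = [512, 768, 1024]   # made-up prompt lengths of waiting sequence groups

for num_prompt_tokens in waiting_prompt_lens:
    if num_batched_tokens + num_prompt_tokens <= max_num_batched_tokens:
        num_batched_tokens += num_prompt_tokens   # admit this prompt into the batch
    else:
        break                                     # budget exhausted; the rest keep waiting

print(num_batched_tokens)  # 1280 -- the 1024-token prompt would push past the 2048 budget
```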
cacheflow/models/__init__.py
@@ -1,10 +1,12 @@
 from cacheflow.models.input_metadata import InputMetadata
+from cacheflow.models.model_utils import get_memory_analyzer
 from cacheflow.models.model_utils import get_model
-from cacheflow.models.model_utils import set_seed
+from cacheflow.models.utils import set_seed
 
 
 __all__ = [
     'InputMetadata',
+    'get_memory_analyzer',
     'get_model',
-    'set_seed'
+    'set_seed',
 ]
cacheflow/models/memory_analyzer.py (new file, 125 lines)
@@ -0,0 +1,125 @@
+import torch
+from transformers import AutoConfig
+
+from cacheflow.models.utils import get_cpu_memory
+from cacheflow.models.utils import get_dtype_size
+from cacheflow.models.utils import get_gpu_memory
+
+_GiB = 1 << 30
+
+
+class CacheFlowMemoryAnalyzer:
+
+    def get_max_num_gpu_blocks(
+        self,
+        max_num_batched_tokens: int,
+        memory_utilization: float,
+    ) -> int:
+        raise NotImplementedError()
+
+    def get_max_num_cpu_blocks(
+        self,
+        memory_utilization: float,
+    ) -> int:
+        raise NotImplementedError()
+
+
+class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+        dtype: torch.dtype,
+    ) -> None:
+        self.model_name = model_name
+        self.block_size = block_size
+        self.dtype = dtype
+
+        # TODO(woosuk): Support tensor parallelism.
+        config = AutoConfig.from_pretrained(model_name)
+        self.num_layers = config.num_hidden_layers
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // self.num_heads
+        self.ffn_size = config.ffn_dim
+        self.embedding_size = config.word_embed_proj_dim
+        self.vocab_size = config.vocab_size
+        self.max_position = config.max_position_embeddings
+
+    def _get_param_size(self) -> int:
+        # TODO(woosuk): Support tensor parallelism.
+        word_embedding = self.vocab_size * self.embedding_size
+        if self.embedding_size != self.vocab_size:
+            # Project in/out.
+            word_embedding += 2 * self.embedding_size * self.vocab_size
+        position_embedding = self.max_position * self.hidden_size
+
+        ln1 = 2 * self.hidden_size
+        q = self.hidden_size * self.hidden_size + self.hidden_size
+        k = self.hidden_size * self.hidden_size + self.hidden_size
+        v = self.hidden_size * self.hidden_size + self.hidden_size
+        out = self.hidden_size * self.hidden_size + self.hidden_size
+        mha = ln1 + q + k + v + out
+
+        ln2 = 2 * self.hidden_size
+        ffn1 = self.hidden_size * self.ffn_size + self.ffn_size
+        ffn2 = self.ffn_size * self.hidden_size + self.hidden_size
+        ffn = ln2 + ffn1 + ffn2
+
+        total = (word_embedding + position_embedding +
+                 self.num_layers * (mha + ffn))
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def _get_max_act_size(
+        self,
+        max_num_batched_tokens: int,
+    ) -> int:
+        # TODO(woosuk): Support tensor parallelism.
+        # NOTE: We approximately calculate the maximum activation size by
+        # 1) estimating the maximum activation tensor size during inference, and
+        # 2) multiplying it by 4.
+        # Here, we assume that FlashAttention is used and
+        # thus the attention maps are never materialized in GPU DRAM.
+        qkv = 3 * (max_num_batched_tokens * self.hidden_size)
+        ffn = max_num_batched_tokens * self.ffn_size
+        max_act = 4 * max(qkv, ffn)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * max_act
+
+    def _get_workspace_size(self) -> int:
+        return 1 * _GiB
+
+    def _get_cache_block_size(self) -> int:
+        key_cache_block = self.block_size * self.num_heads * self.head_size
+        value_cache_block = self.block_size * self.num_heads * self.head_size
+        total = self.num_layers * (key_cache_block + value_cache_block)
+        dtype_size = get_dtype_size(self.dtype)
+        return dtype_size * total
+
+    def get_max_num_gpu_blocks(
+        self,
+        max_num_batched_tokens: int,
+        memory_utilization: float = 0.95,
+    ) -> int:
+        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
+        gpu_memory = get_gpu_memory()
+        usable_memory = int(memory_utilization * gpu_memory)
+
+        param_size = self._get_param_size()
+        act_size = self._get_max_act_size(max_num_batched_tokens)
+        workspace_size = self._get_workspace_size()
+
+        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
+        max_num_blocks = max_cache_size // self._get_cache_block_size()
+        return max_num_blocks
+
+    def get_max_num_cpu_blocks(
+        self,
+        memory_utilization: float = 0.25,
+    ) -> int:
+        cpu_memory = get_cpu_memory()
+        usable_memory = int(memory_utilization * cpu_memory)
+        max_num_blocks = usable_memory // self._get_cache_block_size()
+        return max_num_blocks
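To make the sizing formulas above concrete, here is a back-of-the-envelope sketch using the published `facebook/opt-125m` configuration (12 layers, hidden size 768, 12 heads, FFN size 3072) and an assumed 16 GiB GPU; the parameter and activation figures are rough stand-ins, not the analyzer's exact output:

```python
# Worked example of _get_cache_block_size() and get_max_num_gpu_blocks()
# for an assumed facebook/opt-125m config in fp16 with block_size=8.
num_layers, hidden_size, num_heads, block_size = 12, 768, 12, 8
head_size = hidden_size // num_heads      # 64
dtype_size = 2                            # bytes per fp16 element

# Each cache block stores keys and values for `block_size` tokens in every layer.
key_block = block_size * num_heads * head_size       # 6144 elements
value_block = block_size * num_heads * head_size     # 6144 elements
cache_block_size = dtype_size * num_layers * (key_block + value_block)
print(cache_block_size)                   # 294912 bytes, i.e. 288 KiB per block

# The GPU block count is whatever is left after weights, activations, and workspace.
gpu_memory = 16 * (1 << 30)               # hypothetical 16 GiB GPU
usable = int(0.95 * gpu_memory)           # memory_utilization = 0.95
param_size = 125_000_000 * dtype_size     # ~250 MB of fp16 weights (rough estimate)
act_size = 4 * 2048 * 3072 * dtype_size   # 4 * max_batched_tokens * ffn_size (FFN term dominates)
workspace = 1 << 30                       # 1 GiB, matching _get_workspace_size()
num_gpu_blocks = (usable - param_size - act_size - workspace) // cache_block_size
print(num_gpu_blocks)                     # on the order of 50k blocks, ~400k cached tokens
```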
cacheflow/models/model_utils.py
@@ -1,21 +1,20 @@
-import random
 from typing import Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 
+from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
+from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
 from cacheflow.models.opt import OPTForCausalLM
+from cacheflow.models.utils import get_torch_dtype
 
-MODEL_CLASSES = {
+_MODELS = {
     'opt': OPTForCausalLM,
 }
 
-STR_DTYPE_TO_TORCH_DTYPE = {
-    'half': torch.half,
-    'float': torch.float,
-    'float16': torch.float16,
-    'float32': torch.float32,
+_MEMORY_ANALYZERS = {
+    'opt': OPTMemoryAnalyzer,
 }
 
 
@@ -23,20 +22,23 @@ def get_model(
     model_name: str,
     dtype: Union[torch.dtype, str],
 ) -> nn.Module:
-    if isinstance(dtype, str):
-        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
-    else:
-        torch_dtype = dtype
-    for model_class, hf_model in MODEL_CLASSES.items():
+    torch_dtype = get_torch_dtype(dtype)
+    for model_class, hf_model in _MODELS.items():
         if model_class in model_name:
-            model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
+            model = hf_model.from_pretrained(
+                model_name, torch_dtype=torch_dtype)
             return model.eval()
-    raise ValueError(f'Invalid model name: {model_name}')
+    raise ValueError(f'Unsupported model name: {model_name}')
 
 
-def set_seed(seed: int) -> None:
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
+def get_memory_analyzer(
+    model_name: str,
+    block_size: int,
+    dtype: Union[torch.dtype, str],
+) -> CacheFlowMemoryAnalyzer:
+    torch_dtype = get_torch_dtype(dtype)
+    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
+        if model_class in model_name:
+            return memory_analyzer(
+                model_name, block_size, torch_dtype)
+    raise ValueError(f'Unsupported model name: {model_name}')
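Both `get_model` and the new `get_memory_analyzer` dispatch by checking whether a registry key (`'opt'`) appears in the model name, so `'facebook/opt-125m'` matches the OPT entries. A short usage sketch; the argument values are illustrative:

```python
# Sketch of the registry-based dispatch defined above.
from cacheflow.models import get_memory_analyzer, get_model

memory_analyzer = get_memory_analyzer(
    model_name='facebook/opt-125m',   # contains 'opt' -> OPTMemoryAnalyzer
    block_size=8,
    dtype='half',
)
num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()   # defaults to 25% of host RAM

model = get_model('facebook/opt-125m', dtype='half')        # OPTForCausalLM in eval mode
# Any name without a registered key raises, e.g.:
#   ValueError: Unsupported model name: bert-base-uncased
```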
cacheflow/models/utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from typing import Union
+
+import random
+
+import numpy as np
+import psutil
+import torch
+
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    'half': torch.half,
+    'float': torch.float,
+    'float16': torch.float16,
+    'float32': torch.float32,
+}
+
+
+def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
+    if isinstance(dtype, str):
+        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
+    else:
+        torch_dtype = dtype
+    return torch_dtype
+
+
+def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
+    torch_dtype = get_torch_dtype(dtype)
+    return torch.tensor([], dtype=torch_dtype).element_size()
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+def get_gpu_memory(gpu: int = 0) -> int:
+    return torch.cuda.get_device_properties(gpu).total_memory
+
+
+def get_cpu_memory() -> int:
+    return psutil.virtual_memory().total
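These helpers centralize the dtype table (formerly `STR_DTYPE_TO_TORCH_DTYPE` in `model_utils.py`), the relocated `set_seed`, and the memory probes used by the analyzer. A few illustrative checks:

```python
# Quick sanity checks of the new utility helpers (values depend on your machine).
import torch
from cacheflow.models.utils import (
    get_cpu_memory, get_dtype_size, get_gpu_memory, get_torch_dtype, set_seed)

assert get_torch_dtype('half') is torch.float16
assert get_dtype_size('half') == 2        # bytes per fp16 element
assert get_dtype_size(torch.float32) == 4

set_seed(0)                               # seeds random, numpy, and torch (and CUDA if present)
print(get_cpu_memory())                   # total host RAM in bytes, via psutil
if torch.cuda.is_available():
    print(get_gpu_memory())               # total memory of GPU 0 in bytes
```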
server.py (24 changed lines)
@@ -3,6 +3,7 @@ from typing import List
 
 from cacheflow.master.frontend import Frontend
 from cacheflow.master.scheduler import Scheduler
+from cacheflow.models import get_memory_analyzer
 from cacheflow.worker.controller import Controller
 
 parser = argparse.ArgumentParser(description='CacheFlow server')
@@ -10,17 +11,25 @@ parser.add_argument('--model', type=str, default='facebook/opt-125m', help='mode
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
 parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
-# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
-parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
-parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)')
 # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
 parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
 args = parser.parse_args()
 
 
 def main():
+    memory_analyzer = get_memory_analyzer(
+        model_name=args.model,
+        block_size=args.block_size,
+        dtype=args.dtype,
+    )
+    num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
+        max_num_batched_tokens=args.max_batch_size)
+    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
+    print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')
+
     # Create a controller for each node.
     controllers: List[Controller] = []
     for i in range(args.num_nodes):
@@ -29,8 +38,8 @@ def main():
         num_workers=args.num_workers,
         model_name=args.model,
         block_size=args.block_size,
-        num_gpu_blocks=args.num_gpu_blocks,
-        num_cpu_blocks=args.num_cpu_blocks,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=num_cpu_blocks,
         dtype=args.dtype,
         seed=args.seed,
     )
@@ -47,8 +56,9 @@ def main():
         frontend=frontend,
         controllers=controllers,
         block_size=args.block_size,
-        num_gpu_blocks=args.num_gpu_blocks,
-        num_cpu_blocks=args.num_cpu_blocks,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=num_cpu_blocks,
+        max_num_batched_tokens=args.max_batch_size,
     )
     # Connect the controllers.
     for i in range(len(controllers) - 1):
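With this wiring, `--num-gpu-blocks` and `--num-cpu-blocks` disappear: the server derives both counts from the memory analyzer at startup and forwards `--max-batch-size` to the analyzer and the scheduler alike. A hypothetical launch, kept in comments since the exact environment varies:

```python
# Hypothetical invocation after this change (flag names come from the argparse
# definitions above; the model and values are only examples):
#
#   python server.py --model facebook/opt-125m --block-size 8 \
#       --dtype half --max-batch-size 2048
#
# On startup, main() now prints the auto-configured cache size:
#   # GPU blocks: <derived from GPU memory>, # CPU blocks: <derived from host RAM>
```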