[Neuron] Adding support for context-length, token-gen buckets. (#7885)

Author: Harsha vardhan manoj Bikki
Date: 2024-08-29 13:58:14 -07:00 (committed by GitHub)
Co-authored-by: Harsha Bikki <harbikh@amazon.com>
parent 86a677de42
commit 257afc37c5
2 changed files with 33 additions and 11 deletions

examples/offline_inference_neuron.py

@@ -1,5 +1,12 @@
+import os
+
 from vllm import LLM, SamplingParams
 
+# creates XLA hlo graphs for all the context length buckets.
+os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
+# creates XLA hlo graphs for all the token gen buckets.
+os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
+
 # Sample prompts.
 prompts = [
     "Hello, my name is",
@@ -19,8 +26,8 @@ llm = LLM(
     # Currently, this is a known limitation in continuous batching support
     # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
-    max_model_len=128,
-    block_size=128,
+    max_model_len=2048,
+    block_size=2048,
     # The device can be automatically detected when AWS Neuron SDK is installed.
     # The device argument can be either unspecified for automated detection,
     # or explicitly assigned.

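The two environment variables above make the Neuron backend precompile one XLA HLO graph per listed bucket value, instead of a single graph sized at max_model_len. A minimal sketch of the bucketing idea, assuming the smallest-bucket-that-fits routing used by transformers-neuronx (pick_bucket is a hypothetical illustration, not part of this commit):

    from typing import List

    def pick_bucket(buckets: List[int], length: int) -> int:
        # Serve a request with the smallest compiled graph that fits it.
        for bucket in sorted(buckets):
            if length <= bucket:
                return bucket
        raise ValueError(f"length {length} exceeds largest bucket {max(buckets)}")

    buckets = [128, 512, 1024, 2048]
    print(pick_bucket(buckets, 100))  # 128: a short prompt runs on the smallest graph
    print(pick_bucket(buckets, 600))  # 1024

Smaller graphs pad less and run faster, which is the motivation for compiling several buckets rather than always padding to the maximum length.
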
vllm/model_executor/model_loader/neuron.py

@ -1,7 +1,7 @@
"""Utilities for selecting and loading neuron models.""" """Utilities for selecting and loading neuron models."""
import importlib import importlib
import os import os
from typing import Dict, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
f"{list(_NEURON_SUPPORTED_MODELS.keys())}") f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
env_value = os.getenv(env)
if env_value is None:
return default_value
buckets_remove_empty = filter(
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
buckets_int = map(int, buckets_remove_empty)
buckets_list = list(buckets_int)
return buckets_list
def get_neuron_model(model_config: ModelConfig, def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig) -> nn.Module:
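
The new _get_buckets helper falls back to a single default bucket when the variable is unset, and tolerates whitespace and empty segments when parsing. A standalone check of that behavior (the function body is copied from the hunk above; the env values in the assertions are illustrative):

    import os
    from typing import List

    def _get_buckets(env: str, default_value: List[int]) -> List[int]:
        env_value = os.getenv(env)
        if env_value is None:
            return default_value
        buckets_remove_empty = filter(
            lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
        buckets_int = map(int, buckets_remove_empty)
        buckets_list = list(buckets_int)
        return buckets_list

    # Stray whitespace and empty segments are ignored; int() strips spaces.
    os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128, 512,,1024,"
    assert _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [2048]) == [128, 512, 1024]

    # Unset variable: fall back to the single default bucket.
    os.environ.pop("NEURON_CONTEXT_LENGTH_BUCKETS", None)
    assert _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [2048]) == [2048]
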
@@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
     neuron_config = NeuronConfig(
         continuous_batching=continuous_batching_config)
 
+    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
+                                            [scheduler_config.max_model_len])
+    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
+                               [scheduler_config.max_model_len])
+
     # Load the weights from the cached or downloaded files.
-    model.load_weights(
-        model_config.model,
-        tp_degree=parallel_config.tensor_parallel_size,
-        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
-        neuron_config=neuron_config,
-        context_length_estimate=[scheduler_config.max_model_len],
-        n_positions=[scheduler_config.max_model_len],
-        batch_size=scheduler_config.max_num_seqs)
+    model.load_weights(model_config.model,
+                       tp_degree=parallel_config.tensor_parallel_size,
+                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
+                       neuron_config=neuron_config,
+                       context_length_estimate=context_length_estimates,
+                       n_positions=n_positions,
+                       batch_size=scheduler_config.max_num_seqs)
 
     return model.eval()
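
This loader change also explains why the example raises max_model_len and block_size to 2048: a bucket larger than the model's maximum length would never be selected, so the largest bucket should fit within max_model_len. A small pre-flight check along these lines may help when choosing bucket values (validate_buckets is a hypothetical helper, not part of this commit):

    from typing import List

    def validate_buckets(buckets: List[int], max_model_len: int) -> List[int]:
        # Reject bucket values the engine could never schedule.
        if not buckets:
            raise ValueError("at least one bucket is required")
        too_large = [b for b in buckets if b > max_model_len]
        if too_large:
            raise ValueError(
                f"buckets {too_large} exceed max_model_len={max_model_len}")
        return sorted(buckets)

    print(validate_buckets([128, 512, 1024, 2048], max_model_len=2048))
    # [128, 512, 1024, 2048]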