[Neuron] Adding support for context-lenght, token-gen buckets. (#7885)
Co-authored-by: Harsha Bikki <harbikh@amazon.com>
This commit is contained in:
parent
86a677de42
commit
257afc37c5
@ -1,5 +1,12 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# creates XLA hlo graphs for all the context length buckets.
|
||||||
|
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
|
||||||
|
# creates XLA hlo graphs for all the token gen buckets.
|
||||||
|
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
|
||||||
|
|
||||||
# Sample prompts.
|
# Sample prompts.
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@ -19,8 +26,8 @@ llm = LLM(
|
|||||||
# Currently, this is a known limitation in continuous batching support
|
# Currently, this is a known limitation in continuous batching support
|
||||||
# in transformers-neuronx.
|
# in transformers-neuronx.
|
||||||
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
||||||
max_model_len=128,
|
max_model_len=2048,
|
||||||
block_size=128,
|
block_size=2048,
|
||||||
# The device can be automatically detected when AWS Neuron SDK is installed.
|
# The device can be automatically detected when AWS Neuron SDK is installed.
|
||||||
# The device argument can be either unspecified for automated detection,
|
# The device argument can be either unspecified for automated detection,
|
||||||
# or explicitly assigned.
|
# or explicitly assigned.
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""Utilities for selecting and loading neuron models."""
|
"""Utilities for selecting and loading neuron models."""
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
|
|||||||
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
|
||||||
|
env_value = os.getenv(env)
|
||||||
|
if env_value is None:
|
||||||
|
return default_value
|
||||||
|
buckets_remove_empty = filter(
|
||||||
|
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
|
||||||
|
buckets_int = map(int, buckets_remove_empty)
|
||||||
|
buckets_list = list(buckets_int)
|
||||||
|
return buckets_list
|
||||||
|
|
||||||
|
|
||||||
def get_neuron_model(model_config: ModelConfig,
|
def get_neuron_model(model_config: ModelConfig,
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||||
@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
|
|||||||
neuron_config = NeuronConfig(
|
neuron_config = NeuronConfig(
|
||||||
continuous_batching=continuous_batching_config)
|
continuous_batching=continuous_batching_config)
|
||||||
|
|
||||||
|
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||||
|
[scheduler_config.max_model_len])
|
||||||
|
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||||
|
[scheduler_config.max_model_len])
|
||||||
|
|
||||||
# Load the weights from the cached or downloaded files.
|
# Load the weights from the cached or downloaded files.
|
||||||
model.load_weights(
|
model.load_weights(model_config.model,
|
||||||
model_config.model,
|
tp_degree=parallel_config.tensor_parallel_size,
|
||||||
tp_degree=parallel_config.tensor_parallel_size,
|
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
neuron_config=neuron_config,
|
||||||
neuron_config=neuron_config,
|
context_length_estimate=context_length_estimates,
|
||||||
context_length_estimate=[scheduler_config.max_model_len],
|
n_positions=n_positions,
|
||||||
n_positions=[scheduler_config.max_model_len],
|
batch_size=scheduler_config.max_num_seqs)
|
||||||
batch_size=scheduler_config.max_num_seqs)
|
|
||||||
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user