Support Mistral Model Inference with transformers-neuronx (#3153)
parent c9415c19d3
commit 654865e21d
examples/offline_inference_neuron.py | 10 changed lines (Normal file → Executable file)
@@ -14,14 +14,16 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="openlm-research/open_llama_3b",
     max_num_seqs=8,
-    # The max_model_len and block_size arguments are required to be same as max sequence length,
-    # when targeting neuron device. Currently, this is a known limitation in continuous batching
-    # support in transformers-neuronx.
+    # The max_model_len and block_size arguments are required to be same as
+    # max sequence length when targeting neuron device.
+    # Currently, this is a known limitation in continuous batching support
+    # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
     max_model_len=128,
     block_size=128,
     # The device can be automatically detected when AWS Neuron SDK is installed.
-    # The device argument can be either unspecified for automated detection, or explicitly assigned.
+    # The device argument can be either unspecified for automated detection,
+    # or explicitly assigned.
     device="neuron")

 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
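Since this commit registers MistralForCausalLM for Neuron, the example above can in principle be pointed at a Mistral checkpoint instead of OpenLLaMA. A minimal sketch follows; it is not part of the diff, the checkpoint name is an assumption, and the constraint described in the comments above (max_model_len and block_size equal to the maximum sequence length) still applies.

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="mistralai/Mistral-7B-v0.1",  # assumed checkpoint, not from the diff
    max_num_seqs=8,
    # Same Neuron limitation as in the example: max_model_len and block_size
    # must equal the maximum sequence length.
    max_model_len=128,
    block_size=128,
    device="neuron")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)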
vllm/model_executor/models/__init__.py | 7 changed lines (Normal file → Executable file)
@@ -62,8 +62,11 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }

-# Models not supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {
+    "LlamaForCausalLM": "neuron.llama",
+    "MistralForCausalLM": "neuron.mistral"
+}


 class ModelRegistry:
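For context, the architecture-to-module mapping above is what lets vLLM route a config whose architecture is MistralForCausalLM to the Neuron implementation. Below is a minimal sketch of that resolution, assuming the "neuron.*" paths are relative to vllm.model_executor.models; the helper resolve_neuron_model is illustrative only, not vLLM's actual ModelRegistry code.

import importlib

_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
    "MistralForCausalLM": "neuron.mistral",
}


def resolve_neuron_model(arch: str):
    """Map an architecture name (from the HF config) to its Neuron model class."""
    module_path = _NEURON_SUPPORTED_MODELS[arch]
    # Assumption: the registry paths are relative to this package.
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_path}")
    return getattr(module, arch)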
vllm/model_executor/models/neuron/mistral.py | 82 added lines (new Executable file)
@@ -0,0 +1,82 @@
"""Inference-only Mistral model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import MistralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
import os

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MistralForCausalLM(nn.Module):

    def __init__(
        self,
        config: MistralConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = None
        self.lm_head = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> SamplerOutput:
        with torch.inference_mode():
            seq_ids = []
            block_size = self.model.context_buckets[-1]
            if input_metadata.is_prompt:
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                seq_ids = input_metadata.block_tables

            logits = self.model(input_ids,
                                cache_ids=positions,
                                start_ids=seq_ids)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        from transformers_neuronx.mistral.model import MistralForSampling

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            from transformers import MistralForCausalLM
            from transformers_neuronx.module import save_pretrained_split

            hf_model = MistralForCausalLM.from_pretrained(
                model_name_or_path, low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = MistralForSampling.from_pretrained(
            split_model_dir, **kwargs)
        self.model.to_neuron()
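The load_weights path above first converts the HuggingFace checkpoint into transformers-neuronx's split layout (via save_pretrained_split) the first time it is needed, then compiles the model for the NeuronCores with to_neuron(). A minimal sketch of driving this class directly is shown below; the checkpoint name and the keyword arguments forwarded to MistralForSampling.from_pretrained (batch_size, tp_degree, amp, n_positions) are assumptions about the transformers-neuronx configuration surface and may differ by version, and in normal use vLLM's Neuron worker supplies them rather than the user.

from transformers import MistralConfig

from vllm.model_executor.models.neuron.mistral import MistralForCausalLM

# Assumed checkpoint; any Mistral-architecture HF checkpoint should work.
MODEL = "mistralai/Mistral-7B-v0.1"

config = MistralConfig.from_pretrained(MODEL)
model = MistralForCausalLM(config)

# These kwargs are passed straight through to MistralForSampling.from_pretrained;
# the exact names and values are assumptions, not taken from the diff.
model.load_weights(
    MODEL,
    batch_size=8,
    tp_degree=2,       # number of NeuronCores used for tensor parallelism
    amp="f16",
    n_positions=128,   # must cover the maximum sequence length
)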