Support Mistral Model Inference with transformers-neuronx (#3153)
parent c9415c19d3
commit 654865e21d
examples/offline_inference_neuron.py (10 changed lines; Normal file → Executable file)
@@ -14,14 +14,16 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="openlm-research/open_llama_3b",
     max_num_seqs=8,
-    # The max_model_len and block_size arguments are required to be same as max sequence length,
-    # when targeting neuron device. Currently, this is a known limitation in continuous batching
-    # support in transformers-neuronx.
+    # The max_model_len and block_size arguments are required to be same as
+    # max sequence length when targeting neuron device.
+    # Currently, this is a known limitation in continuous batching support
+    # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
     max_model_len=128,
     block_size=128,
     # The device can be automatically detected when AWS Neuron SDK is installed.
-    # The device argument can be either unspecified for automated detection, or explicitly assigned.
+    # The device argument can be either unspecified for automated detection,
+    # or explicitly assigned.
     device="neuron")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
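With MistralForCausalLM now registered for the Neuron backend (see the registry change below), the same offline example can point at a Mistral checkpoint. A minimal sketch, not part of this commit: the checkpoint name, prompts, and length settings are illustrative assumptions, and the max_model_len == block_size == max sequence length constraint from the comment above still applies.

from vllm import LLM, SamplingParams

# Assumed prompts and checkpoint name; only device="neuron" and the
# max_model_len/block_size rule come from the diff above.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",  # assumed Mistral checkpoint
    max_num_seqs=8,
    max_model_len=128,
    block_size=128,
    device="neuron")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)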
vllm/model_executor/models/__init__.py (7 changed lines; Normal file → Executable file)
@@ -62,8 +62,11 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
-# Models not supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {
+    "LlamaForCausalLM": "neuron.llama",
+    "MistralForCausalLM": "neuron.mistral"
+}
 
 
 class ModelRegistry:
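For context, each entry in _NEURON_SUPPORTED_MODELS maps a HuggingFace architecture name to the Neuron-specific module that implements it, so the new "MistralForCausalLM": "neuron.mistral" entry routes Mistral checkpoints to the file added below. A simplified resolution sketch (illustrative only, not the exact ModelRegistry code in vLLM):

import importlib

_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
    "MistralForCausalLM": "neuron.mistral",
}

def load_neuron_model_cls(architecture: str):
    # e.g. "MistralForCausalLM" -> vllm.model_executor.models.neuron.mistral
    module_name = _NEURON_SUPPORTED_MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, architecture)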
vllm/model_executor/models/neuron/mistral.py (82 lines added; new executable file)
@@ -0,0 +1,82 @@
+"""Inference-only Mistral model compatible with HuggingFace weights."""
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import MistralConfig
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
+import os
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class MistralForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: MistralConfig,
+        linear_method=None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = None
+        self.lm_head = None
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> SamplerOutput:
+        with torch.inference_mode():
+            seq_ids = []
+            block_size = self.model.context_buckets[-1]
+            if input_metadata.is_prompt:
+                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
+            else:
+                seq_ids = input_metadata.block_tables
+
+            logits = self.model(input_ids,
+                                cache_ids=positions,
+                                start_ids=seq_ids)
+        return logits
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
+                                   hidden_states, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None,
+                     **kwargs):
+        from transformers_neuronx.mistral.model import MistralForSampling
+
+        split_model_dir = f"{model_name_or_path}-split"
+        if os.path.isdir(os.path.join(model_name_or_path,
+                                      "pytorch_model.bin")):
+            split_model_dir = model_name_or_path
+        elif not os.path.exists(f"{model_name_or_path}-split"):
+            from transformers import MistralForCausalLM
+            from transformers_neuronx.module import save_pretrained_split
+
+            hf_model = MistralForCausalLM.from_pretrained(
+                model_name_or_path, low_cpu_mem_usage=True)
+            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
+
+        self.model = MistralForSampling.from_pretrained(
+            split_model_dir, **kwargs)
+        self.model.to_neuron()
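The load_weights method above follows the standard transformers-neuronx flow: export the HuggingFace checkpoint into the split per-layer layout if it is not already available, load it with MistralForSampling, then compile for the NeuronCores with to_neuron(). A hedged sketch of the same flow outside vLLM; the checkpoint name and the batch_size/tp_degree/amp compile options are assumptions for illustration, not values taken from this commit.

import torch
from transformers import AutoTokenizer, MistralForCausalLM
from transformers_neuronx.mistral.model import MistralForSampling
from transformers_neuronx.module import save_pretrained_split

model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # assumed checkpoint

# 1. Export the checkpoint in the split layout expected by transformers-neuronx.
hf_model = MistralForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name}-split")

# 2. Load the split weights and compile the model for the NeuronCores
#    (batch_size/tp_degree/amp are assumed example values).
neuron_model = MistralForSampling.from_pretrained(
    f"{model_name}-split", batch_size=1, tp_degree=2, amp="f32")
neuron_model.to_neuron()

# 3. Generate directly through transformers-neuronx.
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.inference_mode():
    generated = neuron_model.sample(input_ids, sequence_length=128)
print(tokenizer.decode(generated[0]))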