Support Mistral Model Inference with transformers-neuronx (#3153)

DAIZHENWEI 2024-03-11 13:19:51 -07:00 committed by GitHub
parent c9415c19d3
commit 654865e21d
3 changed files with 93 additions and 6 deletions

examples/offline_inference_neuron.py Normal file → Executable file

@@ -14,14 +14,16 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 llm = LLM(
     model="openlm-research/open_llama_3b",
     max_num_seqs=8,
-    # The max_model_len and block_size arguments are required to be same as max sequence length,
-    # when targeting neuron device. Currently, this is a known limitation in continuous batching
-    # support in transformers-neuronx.
+    # The max_model_len and block_size arguments are required to be same as
+    # max sequence length when targeting neuron device.
+    # Currently, this is a known limitation in continuous batching support
+    # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
     max_model_len=128,
     block_size=128,
     # The device can be automatically detected when AWS Neuron SDK is installed.
-    # The device argument can be either unspecified for automated detection, or explicitly assigned.
+    # The device argument can be either unspecified for automated detection,
+    # or explicitly assigned.
     device="neuron")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
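
This hunk only reflows comments; the example script itself is unchanged. With the new registry entry below, the same script can target a Mistral checkpoint. The following is a sketch rather than part of the commit: the model name and the tensor_parallel_size value are assumptions about the target Neuron instance.

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    max_num_seqs=8,
    # As above, max_model_len and block_size must equal the maximum
    # sequence length when targeting the neuron device.
    max_model_len=128,
    block_size=128,
    # Number of NeuronCores to shard the model across (assumed value).
    tensor_parallel_size=2,
    device="neuron")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")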

vllm/model_executor/models/__init__.py Normal file → Executable file

@@ -62,8 +62,11 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
-# Models not supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {
+    "LlamaForCausalLM": "neuron.llama",
+    "MistralForCausalLM": "neuron.mistral"
+}
 
 
 class ModelRegistry:
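
Each value in this table is a module path relative to vllm.model_executor.models, so "neuron.mistral" resolves to the new file added below. As a minimal sketch (illustrative only, not vLLM's actual ModelRegistry code), a table like this is typically consumed as follows:

import importlib

_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
    "MistralForCausalLM": "neuron.mistral",
}

def load_neuron_model_cls(arch: str):
    # e.g. "MistralForCausalLM" ->
    #   vllm.model_executor.models.neuron.mistral.MistralForCausalLM
    module_name = _NEURON_SUPPORTED_MODELS[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, arch)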

vllm/model_executor/models/neuron/mistral.py New file

@@ -0,0 +1,82 @@
"""Inference-only Mistral model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
import os
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MistralForCausalLM(nn.Module):
def __init__(
self,
config: MistralConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.lm_head = None
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> SamplerOutput:
with torch.inference_mode():
seq_ids = []
block_size = self.model.context_buckets[-1]
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids)
return logits
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.mistral.model import MistralForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers import MistralForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = MistralForCausalLM.from_pretrained(
model_name_or_path, low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = MistralForSampling.from_pretrained(
split_model_dir, **kwargs)
self.model.to_neuron()
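
For orientation, a minimal sketch of how this class is driven. In vLLM itself the Neuron model loader performs this wiring; the import path follows the "neuron.mistral" registry entry, and the keyword arguments passed to load_weights (batch_size, tp_degree, amp, n_positions) are assumptions about what gets forwarded to MistralForSampling.from_pretrained, shown only to illustrate the call sequence.

from transformers import MistralConfig

from vllm.model_executor.models.neuron.mistral import MistralForCausalLM

config = MistralConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
model = MistralForCausalLM(config)

# load_weights() splits the HuggingFace checkpoint if needed, builds
# MistralForSampling, and compiles it onto the NeuronCores; the kwargs
# below are illustrative assumptions, not values fixed by this commit.
model.load_weights("mistralai/Mistral-7B-Instruct-v0.1",
                   batch_size=8,
                   tp_degree=2,
                   amp="f32",
                   n_positions=128)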