[Model][Bugfix] Support TP for PixtralHF ViT (#10405)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author: Michael Goin <michael@neuralmagic.com>
Date: 2024-11-18 13:04:14 -05:00
parent 4f686d139f
commit 281cc4b3cd

@@ -17,6 +17,7 @@ from transformers.models.pixtral.modeling_pixtral import (
 from vllm.attention import AttentionMetadata
 from vllm.config import ModelConfig, VllmConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
@@ -843,17 +844,20 @@ class PixtralHFAttention(nn.Module):
         self.config = config
         assert not config.hidden_size % config.num_attention_heads
-        self.n_heads = config.num_attention_heads
+        self.total_num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.n_heads = divide(config.num_attention_heads, tp_size)
         self.head_dim = config.hidden_size // config.num_attention_heads
         self.qkv_proj = QKVParallelLinear(
             hidden_size=config.hidden_size,
             head_size=self.head_dim,
-            total_num_heads=self.n_heads,
+            total_num_heads=self.total_num_heads,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+        assert self.total_num_heads * self.head_dim == config.hidden_size
         self.o_proj = RowParallelLinear(
             input_size=config.hidden_size,
             output_size=config.hidden_size,
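
For context: the fix shards the ViT attention heads across tensor-parallel ranks instead of giving every rank the full head count, which previously made the per-rank shapes inconsistent under TP. Below is a minimal sketch of the partitioning arithmetic, illustrative only: this divide mirrors vllm.distributed.divide, and the Pixtral-like sizes (hidden_size=1024, 16 heads) are assumptions for the example, not taken from this diff.

    # Sketch of the head partitioning this commit applies: each TP rank
    # owns an equal slice of the attention heads, so the head count must
    # divide evenly by the TP world size.
    def divide(numerator: int, denominator: int) -> int:
        # Mirrors vllm.distributed.divide: exact integer division.
        assert numerator % denominator == 0, (
            f"{numerator} is not evenly divisible by {denominator}")
        return numerator // denominator

    # Assumed Pixtral-like vision config: hidden_size=1024, 16 heads,
    # so head_dim = 1024 // 16 = 64.
    total_num_heads = 16
    hidden_size = 1024
    head_dim = hidden_size // total_num_heads

    for tp_size in (1, 2, 4, 8):
        n_heads = divide(total_num_heads, tp_size)  # heads per rank
        # Fused QKV output width on each rank: 3 projections per head.
        print(f"tp_size={tp_size}: {n_heads} heads/rank, "
              f"qkv width/rank={3 * n_heads * head_dim}")

Before the change, self.n_heads held the global head count even though QKVParallelLinear only materializes each rank's shard, so the shapes disagreed whenever tp_size > 1; the fix keeps the global count in total_num_heads for the sharded projections and the per-rank count in n_heads for the rest of the attention arithmetic.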