From b98d89efd4b1a09c11c4d0cf30c9af0e93514764 Mon Sep 17 00:00:00 2001 From: Sky Lee <46676799+skylee-01@users.noreply.github.com> Date: Sun, 17 Nov 2024 00:33:01 +0800 Subject: [PATCH] [Misc] Medusa supports custom bias (#10361) --- vllm/model_executor/models/medusa.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index de5b2d89..b05360b5 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -14,11 +14,14 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata class ResidualBlock(nn.Module): - def __init__(self, hidden_size: int, num_layers: int) -> None: + def __init__(self, config: VllmConfig, hidden_size: int, + num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList([ - nn.Linear(hidden_size, hidden_size, bias=False) + nn.Linear(hidden_size, + hidden_size, + bias=getattr(config, "medusa_fc_bias", False)) for _ in range(num_layers) ]) self.act = nn.SiLU() @@ -49,7 +52,8 @@ class Medusa(nn.Module): super().__init__() self.config = config self.blocks = nn.ModuleList([ - ResidualBlock(hidden_size=self.config.hidden_size, + ResidualBlock(config=config, + hidden_size=self.config.hidden_size, num_layers=self.config.num_hidden_layers) for _ in range(self.config.num_heads) ])