diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 459f11d1..611a48a9 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -39,14 +39,15 @@
 from typing import List, Optional, Tuple

 import torch
-import torch.nn.functional as F  # this model must need this dependency
 from hf_olmo import OLMoConfig
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput


-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        return F.silu(gate) * x
-
-    @property
-    def output_multiplier(self) -> float:
-        return 0.5
-
-
 class OlmoAttention(nn.Module):
     """
     This is the attention block where the output is computed as
@@ -174,17 +164,16 @@ class OlmoMLP(nn.Module):
                                     bias=False)

         # Feed-forward input projection.
-        self.ff_proj = ColumnParallelLinear(
+        self.ff_proj = MergedColumnParallelLinear(
             config.d_model,
-            self.hidden_size,
+            [self.hidden_size // 2] * 2,
             bias=config.include_bias,
             linear_method=linear_method,
         )

         # Activation function.
-        # self.act = SiluAndMul()
-        # self.act.output_multiplier = 0.5
-        self.act = SwiGLU()
+        self.act = SiluAndMul()
+        self.act.output_multiplier = 0.5
         assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

         # Feed-forward output projection.
@@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module):
             if ".att" in name:
                 name = name.replace(".att", ".attn.att")
             # mlp
-            if ".ff" in name and "transformer.ff_out" not in name:
-                name = name.replace(".ff", ".mlp.ff")
+            if ".ff_proj" in name:
+                name = name.replace(".ff_proj", ".mlp.ff_proj")
+                # Reverse the weight chunks for MergedColumnParallelLinear.
+                loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
+            if ".ff_out" in name and "transformer.ff_out" not in name:
+                name = name.replace(".ff_out", ".mlp.ff_out")
             # there is no bias in olmo
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
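
The subtlety in this diff is an activation-convention mismatch: OLMo's original `SwiGLU` (deleted above) takes the *second* half of the `ff_proj` output as the gate (`x, gate = x.chunk(2, dim=-1); silu(gate) * x`), while vLLM's `SiluAndMul` applies SiLU to the *first* half and multiplies by the second. Reversing the two halves of the `ff_proj` weight at load time reconciles the two conventions. Below is a minimal sketch checking that equivalence, assuming only PyTorch; the sizes and the names `d_model`, `hidden_size`, `W`, `x` are illustrative, not from the source:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, hidden_size = 8, 32
W = torch.randn(hidden_size, d_model)  # stands in for the HF ff_proj weight
x = torch.randn(3, d_model)

# OLMo's original SwiGLU: the gate is the *second* half of the projection.
up, gate = (x @ W.T).chunk(2, dim=-1)
ref = F.silu(gate) * up

# vLLM's SiluAndMul gates on the *first* half, so the weight halves are
# swapped at load time; this is the same chunk-reversal as in load_weights.
W_swapped = torch.concat(W.chunk(2)[::-1])
gate2, up2 = (x @ W_swapped.T).chunk(2, dim=-1)
out = F.silu(gate2) * up2

assert torch.allclose(ref, out)
```

Declaring `ff_proj` as a `MergedColumnParallelLinear` with output sizes `[self.hidden_size // 2] * 2` keeps the gate and up halves sharded in lockstep under tensor parallelism, so each rank holds matching slices of both halves and `SiluAndMul` remains correct per rank.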