[Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (#3869)
This commit is contained in:
parent
18de883489
commit
54951ac4bf
@ -39,14 +39,15 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
# This model requires the hf_olmo package.
|
||||
from hf_olmo import OLMoConfig
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
    """SwiGLU activation: gate one half of the last dimension with the other.

    The input's last dimension is split into two chunks. The first chunk is
    the value and the second chunk is the gate; the result is
    ``silu(gate) * value``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``silu(gate) * value`` for the two halves of ``x``.

        Args:
            x: Tensor whose last dimension holds the value half followed
                by the gate half.

        Returns:
            A tensor whose last dimension is half that of ``x`` (for an
            even-sized last dimension).
        """
        # Order matters: the first chunk is multiplied, the second is gated.
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    @property
    def output_multiplier(self) -> float:
        """Ratio of output width to input width along the last dimension."""
        return 0.5
|
||||
|
||||
|
||||
class OlmoAttention(nn.Module):
|
||||
"""
|
||||
This is the attention block where the output is computed as
|
||||
@ -174,17 +164,16 @@ class OlmoMLP(nn.Module):
|
||||
bias=False)
|
||||
|
||||
# Feed-forward input projection.
|
||||
self.ff_proj = ColumnParallelLinear(
|
||||
self.ff_proj = MergedColumnParallelLinear(
|
||||
config.d_model,
|
||||
self.hidden_size,
|
||||
[self.hidden_size // 2] * 2,
|
||||
bias=config.include_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
# Activation function.
|
||||
# self.act = SiluAndMul()
|
||||
# self.act.output_multiplier = 0.5
|
||||
self.act = SwiGLU()
|
||||
self.act = SiluAndMul()
|
||||
self.act.output_multiplier = 0.5
|
||||
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
|
||||
|
||||
# Feed-forward output projection.
|
||||
@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module):
|
||||
if ".att" in name:
|
||||
name = name.replace(".att", ".attn.att")
|
||||
# mlp
|
||||
if ".ff" in name and "transformer.ff_out" not in name:
|
||||
name = name.replace(".ff", ".mlp.ff")
|
||||
if ".ff_proj" in name:
|
||||
name = name.replace(".ff_proj", ".mlp.ff_proj")
|
||||
# Swap the two halves of the weight for MergedColumnParallelLinear
|
||||
loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
|
||||
if ".ff_out" in name and "transformer.ff_out" not in name:
|
||||
name = name.replace(".ff_out", ".mlp.ff_out")
|
||||
# there is no bias in olmo
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user