[Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (#3869)

Isotr0py 2024-04-06 03:02:09 +08:00 committed by GitHub
parent 18de883489
commit 54951ac4bf

vllm/model_executor/models/olmo.py

@@ -39,14 +39,15 @@
 from typing import List, Optional, Tuple

 import torch
-import torch.nn.functional as F
 # this model requires this dependency
 from hf_olmo import OLMoConfig
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput


-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        return F.silu(gate) * x
-
-    @property
-    def output_multiplier(self) -> float:
-        return 0.5
-
-
 class OlmoAttention(nn.Module):
     """
     This is the attention block where the output is computed as
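Note on the removed class: SwiGLU splits the fused projection into (x, gate) and returns silu(gate) * x, whereas vLLM's SiluAndMul applies silu to the first half and multiplies by the second half. The two agree once the halves are swapped, which is exactly what the load-time weight reversal in the last hunk does. A minimal standalone sketch of that equivalence (illustrative PyTorch, not vLLM code):

import torch
import torch.nn.functional as F

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # OLMo's removed activation: the second half gates the first.
    x, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * x

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Semantics of vLLM's SiluAndMul: silu on the first half,
    # multiplied elementwise by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)
# Swapping the two halves of the input reconciles the two conventions.
swapped = torch.cat(x.chunk(2, dim=-1)[::-1], dim=-1)
assert torch.allclose(swiglu(x), silu_and_mul(swapped))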
@@ -174,17 +164,16 @@ class OlmoMLP(nn.Module):
                                 bias=False)

         # Feed-forward input projection.
-        self.ff_proj = ColumnParallelLinear(
+        self.ff_proj = MergedColumnParallelLinear(
             config.d_model,
-            self.hidden_size,
+            [self.hidden_size // 2] * 2,
             bias=config.include_bias,
             linear_method=linear_method,
         )

         # Activation function.
-        # self.act = SiluAndMul()
-        # self.act.output_multiplier = 0.5
-        self.act = SwiGLU()
+        self.act = SiluAndMul()
+        self.act.output_multiplier = 0.5
         assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

         # Feed-forward output projection.
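This switch is the core of the tensor-parallelism fix. A plain ColumnParallelLinear shards the fused gate/up weight into contiguous row blocks, so with tp_size=2 one rank would receive only one half and the other rank only the gate half, and silu(gate) * x could no longer be formed rank-locally. MergedColumnParallelLinear shards each of the two [self.hidden_size // 2] sub-matrices independently, so every rank keeps matching slices of both halves. A rough sketch of the two sharding schemes (toy shapes, not the vLLM implementation):

import torch

tp_size = 2
d_model, hidden_size = 8, 16           # toy sizes for illustration
w = torch.randn(hidden_size, d_model)  # fused ff_proj weight: [half_0; half_1]

# ColumnParallelLinear-style sharding: contiguous row slices. With
# tp_size=2, rank 0 gets only half_0 and rank 1 only half_1, so neither
# rank can compute silu(one_half) * other_half locally.
naive_shards = w.chunk(tp_size, dim=0)

# MergedColumnParallelLinear-style sharding: shard each half separately,
# then stack the matching slices, so every rank holds both halves.
half_0, half_1 = w.chunk(2, dim=0)
merged_shards = [
    torch.cat([half_0.chunk(tp_size, dim=0)[r],
               half_1.chunk(tp_size, dim=0)[r]], dim=0)
    for r in range(tp_size)
]

The output_multiplier = 0.5 assigned to the SiluAndMul instance just preserves the existing sizing logic: SiluAndMul halves the width of its input, and the assert above reads that attribute.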
@@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module):
             if ".att" in name:
                 name = name.replace(".att", ".attn.att")
             # mlp
-            if ".ff" in name and "transformer.ff_out" not in name:
-                name = name.replace(".ff", ".mlp.ff")
+            if ".ff_proj" in name:
+                name = name.replace(".ff_proj", ".mlp.ff_proj")
+                # Reverse the weight halves for the MergedColumnParallelLinear
+                loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
+            if ".ff_out" in name and "transformer.ff_out" not in name:
+                name = name.replace(".ff_out", ".mlp.ff_out")
             # there is no bias in olmo
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",