[Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (#3869)

This commit is contained in:
Isotr0py 2024-04-06 03:02:09 +08:00 committed by GitHub
parent 18de883489
commit 54951ac4bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,14 +39,15 @@
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn.functional as F
# this model requires this dependency # this model requires this dependency
from hf_olmo import OLMoConfig from hf_olmo import OLMoConfig
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
class SwiGLU(nn.Module):
    """Gated SiLU activation over the two halves of the last dimension.

    The input is split into two chunks along its last dimension. The
    first chunk is the value and the second chunk is the gate; the
    result is ``silu(gate) * value``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``silu(gate) * value`` for the two chunks of ``x``."""
        # Chunk order matters: value comes first, gate comes second.
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

    @property
    def output_multiplier(self) -> float:
        """Multiplier applied to the input width to get the output width."""
        # The two chunks are combined into one, so the width is halved.
        return 0.5
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
""" """
This is the attention block where the output is computed as This is the attention block where the output is computed as
@ -174,17 +164,16 @@ class OlmoMLP(nn.Module):
bias=False) bias=False)
# Feed-forward input projection. # Feed-forward input projection.
self.ff_proj = ColumnParallelLinear( self.ff_proj = MergedColumnParallelLinear(
config.d_model, config.d_model,
self.hidden_size, [self.hidden_size // 2] * 2,
bias=config.include_bias, bias=config.include_bias,
linear_method=linear_method, linear_method=linear_method,
) )
# Activation function. # Activation function.
# self.act = SiluAndMul() self.act = SiluAndMul()
# self.act.output_multiplier = 0.5 self.act.output_multiplier = 0.5
self.act = SwiGLU()
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Feed-forward output projection. # Feed-forward output projection.
@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module):
if ".att" in name: if ".att" in name:
name = name.replace(".att", ".attn.att") name = name.replace(".att", ".attn.att")
# mlp # mlp
if ".ff" in name and "transformer.ff_out" not in name: if ".ff_proj" in name:
name = name.replace(".ff", ".mlp.ff") name = name.replace(".ff_proj", ".mlp.ff_proj")
# Swap the two halves of the weight for the MergedColumnParallelLinear
loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
if ".ff_out" in name and "transformer.ff_out" not in name:
name = name.replace(".ff_out", ".mlp.ff_out")
# there is no bias in olmo # there is no bias in olmo
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",