[Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (#3869)
This commit is contained in:
parent 18de883489
commit 54951ac4bf
@@ -39,14 +39,15 @@
 from typing import List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 # this model must need this dependency
 from hf_olmo import OLMoConfig
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -62,17 +63,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
 from vllm.sequence import SamplerOutput
 
 
-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        return F.silu(gate) * x
-
-    @property
-    def output_multiplier(self) -> float:
-        return 0.5
-
-
 class OlmoAttention(nn.Module):
     """
     This is the attention block where the output is computed as
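Note on the removal above: the deleted SwiGLU gates on the second half of the projection (silu(gate) * x, with gate taken from the back), whereas vLLM's SiluAndMul applies SiLU to the first half and multiplies by the second. A minimal sketch of that mismatch, assuming SiluAndMul's first-half/second-half split; the helper names are illustrative only:

    import torch
    import torch.nn.functional as F

    def old_swiglu(x: torch.Tensor) -> torch.Tensor:
        # Removed SwiGLU: the gate is the *second* half of the projection.
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # Sketch of SiluAndMul's behaviour: SiLU on the *first* half.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    x = torch.randn(2, 8)
    swapped = torch.cat(x.chunk(2, dim=-1)[::-1], dim=-1)
    # The two agree only after the halves are swapped, which is exactly
    # what the weight reordering in load_weights (last hunk) achieves.
    assert torch.allclose(old_swiglu(x), silu_and_mul(swapped))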
@@ -174,17 +164,16 @@ class OlmoMLP(nn.Module):
             bias=False)
 
         # Feed-forward input projection.
-        self.ff_proj = ColumnParallelLinear(
+        self.ff_proj = MergedColumnParallelLinear(
             config.d_model,
-            self.hidden_size,
+            [self.hidden_size // 2] * 2,
             bias=config.include_bias,
             linear_method=linear_method,
         )
 
         # Activation function.
-        # self.act = SiluAndMul()
-        # self.act.output_multiplier = 0.5
-        self.act = SwiGLU()
+        self.act = SiluAndMul()
+        self.act.output_multiplier = 0.5
         assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
 
         # Feed-forward output projection.
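Why ff_proj has to become a MergedColumnParallelLinear: SiluAndMul chunks its local input in half, and a plain column-parallel shard of the 2*d output does not keep matching slices of both halves on every tensor-parallel rank. A toy sketch of the two sharding schemes; shapes and helpers are illustrative, not the vLLM implementation:

    import torch

    def naive_shard(weight: torch.Tensor, tp: int, rank: int) -> torch.Tensor:
        # Plain ColumnParallelLinear: slice the full output dim contiguously,
        # so rank 0 ends up with rows only from the first logical half.
        return weight.chunk(tp, dim=0)[rank]

    def merged_shard(weight: torch.Tensor, tp: int, rank: int) -> torch.Tensor:
        # MergedColumnParallelLinear (sketch): shard each logical half
        # separately, so every rank owns matching slices of both halves
        # and chunking the local output in SiluAndMul stays consistent.
        half_a, half_b = weight.chunk(2, dim=0)
        return torch.cat([half_a.chunk(tp, dim=0)[rank],
                          half_b.chunk(tp, dim=0)[rank]], dim=0)

    w = torch.arange(8.0).unsqueeze(-1)             # toy weight, output dim 8
    print(naive_shard(w, tp=2, rank=0).flatten())   # rows 0-3: first half only
    print(merged_shard(w, tp=2, rank=0).flatten())  # rows 0-1 and 4-5: one slice per half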
@@ -374,8 +363,12 @@ class OLMoForCausalLM(nn.Module):
             if ".att" in name:
                 name = name.replace(".att", ".attn.att")
             # mlp
-            if ".ff" in name and "transformer.ff_out" not in name:
-                name = name.replace(".ff", ".mlp.ff")
+            if ".ff_proj" in name:
+                name = name.replace(".ff_proj", ".mlp.ff_proj")
+                # Reverse the weight for the MergedColumnParallelLinear
+                loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
+            if ".ff_out" in name and "transformer.ff_out" not in name:
+                name = name.replace(".ff_out", ".mlp.ff_out")
             # there is no bias in olmo
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
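The reordering added in the loader above compensates for the half-swap: per the removed SwiGLU, the OLMo checkpoint stores ff_proj with the "x" rows first and the "gate" rows second, so the two row blocks are flipped before the weight reaches MergedColumnParallelLinear. A toy illustration with made-up values:

    import torch

    # Stand-in for a checkpoint ff_proj weight of shape [2 * hidden, d_model]:
    # rows 0-1 produce the "x" half, rows 2-3 the "gate" half.
    loaded_weight = torch.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])

    # chunk(2) splits along dim 0; reversing and concatenating swaps the
    # row blocks so SiluAndMul applies SiLU to the former gate rows.
    reordered = torch.concat(loaded_weight.chunk(2)[::-1])
    # rows are now [3., 3.], [4., 4.], [1., 1.], [2., 2.]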