[Model] Fix and clean commandr (#3671)
This commit is contained in: parent 6d9aa00fc4 · commit 10e6322283
@@ -26,7 +26,6 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from transformers import CohereConfig
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,8 +45,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 class LayerNorm(nn.Module):
 
@@ -70,9 +67,6 @@ class LayerNorm(nn.Module):
         return hidden_states.to(input_dtype), residuals
 
 
-ALL_LAYERNORM_LAYERS.append(LayerNorm)
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
 
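For orientation, below is a rough sketch of the bias-free LayerNorm this hunk touches, reconstructed only from the `return hidden_states.to(input_dtype), residuals` context line and from the general shape of Cohere-style layer norms; the class name, argument names, and body details are assumptions, not the file's actual code. The removed `ALL_LAYERNORM_LAYERS.append(LayerNorm)` registration appears to matter only to transformers' own Trainer utilities, so dropping it (together with the import in the first hunk) does not change model behavior here.

```python
import torch
from torch import nn


class CohereStyleLayerNorm(nn.Module):
    """Illustrative sketch: bias-free layer norm computed in float32,
    returning the normalized activations plus the untouched residual
    stream so the caller can handle the add+norm pattern itself."""

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor, residuals=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(
            variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype), residuals
```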
@@ -137,7 +131,6 @@ class CohereAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
-        self.is_causal = True
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
@@ -171,7 +164,7 @@ class CohereAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
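This signature hunk, and the similar ones below, all make the same change: the module-level `KVCache = Tuple[torch.Tensor, torch.Tensor]` alias removed earlier is replaced by plain `torch.Tensor` annotations, matching what the attention layers are actually handed. A minimal sketch of the before/after (the function names are illustrative only, not from the file):

```python
from typing import List, Tuple

import torch

# Before: a module-level alias suggesting each layer's cache is a
# separate (key, value) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]


def forward_before(kv_caches: List[KVCache]) -> None:
    ...


# After: each layer's cache is annotated as a single torch.Tensor,
# so the type hints match the objects passed through the model.
def forward_after(kv_caches: List[torch.Tensor]) -> None:
    ...
```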
@@ -200,7 +193,7 @@ class CohereDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -242,7 +235,7 @@ class CohereModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
+        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -269,7 +262,6 @@ class CohereForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.unpadded_vocab_size = config.vocab_size
         self.linear_method = linear_method
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 scale=config.logit_scale)
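The removed `self.unpadded_vocab_size` attribute was unused, since the LogitsProcessor already takes `config.vocab_size` directly. Command-R also scales its output logits by `config.logit_scale`, which is what the `scale=` argument requests. A rough sketch of the effect, assuming the processor applies a simple multiplicative scale before sampling (the helper name and the example value are illustrative, not from vLLM):

```python
import torch


def apply_logit_scale(logits: torch.Tensor, logit_scale: float) -> torch.Tensor:
    # Illustrative only: multiply the LM-head logits by the model's
    # configured logit_scale before they reach the sampler.
    return logits * logit_scale


# Example: a (batch, vocab) logits tensor scaled by a small constant.
scaled = apply_logit_scale(torch.randn(2, 8), logit_scale=0.0625)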
@@ -281,7 +273,7 @@ class CohereForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
+        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,