[Model] Fix and clean commandr (#3671)

This commit is contained in:
Roy 2024-03-28 08:20:00 +08:00 committed by GitHub
parent 6d9aa00fc4
commit 10e6322283
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -26,7 +26,6 @@ import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,8 +45,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
-KVCache = Tuple[torch.Tensor, torch.Tensor]
class LayerNorm(nn.Module):
@@ -70,9 +67,6 @@ class LayerNorm(nn.Module):
return hidden_states.to(input_dtype), residuals
-ALL_LAYERNORM_LAYERS.append(LayerNorm)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
@@ -137,7 +131,6 @@ class CohereAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.rope_scaling = getattr(config, "rope_scaling", None)
-self.is_causal = True
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
self.head_dim,
@@ -171,7 +164,7 @@ class CohereAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
-kv_cache: KVCache,
+kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
@@ -200,7 +193,7 @@ class CohereDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
-kv_cache: KVCache,
+kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -242,7 +235,7 @@ class CohereModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-kv_caches: List[KVCache],
+kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
@@ -269,7 +262,6 @@ class CohereForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
-self.unpadded_vocab_size = config.vocab_size
self.linear_method = linear_method
self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale)
@@ -281,7 +273,7 @@ class CohereForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-kv_caches: List[KVCache],
+kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,