[Model] Fix and clean commandr (#3671)
This commit is contained in:
parent
6d9aa00fc4
commit
10e6322283
@@ -26,7 +26,6 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import CohereConfig
|
||||
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -46,8 +45,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
- KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
@@ -70,9 +67,6 @@ class LayerNorm(nn.Module):
|
||||
return hidden_states.to(input_dtype), residuals
|
||||
|
||||
|
||||
- ALL_LAYERNORM_LAYERS.append(LayerNorm)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
|
||||
class CohereMLP(nn.Module):
|
||||
|
||||
@@ -137,7 +131,6 @@ class CohereAttention(nn.Module):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||
- self.is_causal = True
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
@@ -171,7 +164,7 @@ class CohereAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
- kv_cache: KVCache,
|
||||
+ kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
@@ -200,7 +193,7 @@ class CohereDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
- kv_cache: KVCache,
|
||||
+ kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -242,7 +235,7 @@ class CohereModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
- kv_caches: List[KVCache],
|
||||
+ kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
@@ -269,7 +262,6 @@ class CohereForCausalLM(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
- self.unpadded_vocab_size = config.vocab_size
|
||||
self.linear_method = linear_method
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
scale=config.logit_scale)
|
||||
@@ -281,7 +273,7 @@ class CohereForCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
- kv_caches: List[KVCache],
|
||||
+ kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user