From de0fabbc5c84e6771d70b92014ae06fe82654ff0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 23 Feb 2023 20:30:12 +0000 Subject: [PATCH] Fix sampler --- cacheflow/models/opt.py | 5 +++-- cacheflow/models/sample.py | 9 +++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 5637e852..422ec263 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -227,7 +227,7 @@ class OPTForCausalLM(OPTPreTrainedModel): self.model = OPTModel(config) # the lm_head weight is automatically tied to the embed tokens weight self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) - self.sampler = Sampler(embedding=self.lm_head.weight) + self.sampler = Sampler() # Initialize weights and apply final processing self.post_init() @@ -242,5 +242,6 @@ class OPTForCausalLM(OPTPreTrainedModel): ) -> Dict[int, Tuple[int, int]]: hidden_states = self.model( input_ids, positions, kv_caches, input_metadata, cache_events) - next_tokens = self.sampler(hidden_states, input_metadata) + next_tokens = self.sampler( + self.lm_head.weight, hidden_states, input_metadata) return next_tokens diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index d75e5011..5c984d39 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -8,15 +8,12 @@ from cacheflow.models import InputMetadata class Sampler(nn.Module): - def __init__( - self, - embedding: torch.Tensor, - ) -> None: + def __init__(self) -> None: super().__init__() - self.embedding = embedding # [vocab_size, hidden_size] def forward( self, + embedding: torch.Tensor, hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> Dict[int, Tuple[int, int]]: @@ -31,7 +28,7 @@ class Sampler(nn.Module): hidden_states = hidden_states[last_token_indicies] # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, self.embedding.t()) + logits = torch.matmul(hidden_states, embedding.t()) # Sample the next tokens. # TODO(woosuk): Implement other sampling methods.