diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index 37e73b21..9f2c6a49 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -216,7 +216,7 @@ class Scheduler:
                     self.block_manager.fork(parent_seq, seq)
 
                 # Append a new token to the sequence.
-                seq.append(next_token)
+                seq.append([next_token])
 
                 # Check if the sequence has generated a stop token.
                 if next_token in stop_token_ids:
diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py
index 321df607..d75e5011 100644
--- a/cacheflow/models/sample.py
+++ b/cacheflow/models/sample.py
@@ -13,7 +13,7 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
     ) -> None:
         super().__init__()
-        self.embedding = embedding.t()  # [hidden_size, vocab_size]
+        self.embedding = embedding  # [vocab_size, hidden_size]
 
     def forward(
         self,
@@ -31,7 +31,7 @@
         hidden_states = hidden_states[last_token_indicies]
 
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, self.embedding)
+        logits = torch.matmul(hidden_states, self.embedding.t())
 
         # Sample the next tokens.
         # TODO(woosuk): Implement other sampling methods.
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index 52a84128..3a0600c2 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -165,6 +165,7 @@ class Worker:
         output = self.model(
             input_ids=input_tokens,
             positions=input_positions,
+            kv_caches=self.gpu_cache,
             input_metadata=input_metadata,
             cache_events=cache_events,
         )
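
For context, a minimal standalone sketch (not the actual cacheflow Sampler; tensor sizes are made-up example values) of the shape convention the sample.py change adopts: the embedding weight is stored untransposed as [vocab_size, hidden_size] and only transposed at matmul time, so the logits come out as [num_tokens, vocab_size].

# Sketch of the post-change convention; dimensions are illustrative only.
import torch

vocab_size, hidden_size, num_tokens = 32000, 4096, 3
embedding = torch.randn(vocab_size, hidden_size)      # stored as [vocab_size, hidden_size]
hidden_states = torch.randn(num_tokens, hidden_size)  # one row per sampled position

# Transpose inside the matmul rather than at construction time.
logits = torch.matmul(hidden_states, embedding.t())   # [num_tokens, vocab_size]
assert logits.shape == (num_tokens, vocab_size)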