This commit is contained in:
Woosuk Kwon 2023-02-23 20:23:47 +00:00
parent 7f985166f7
commit fdd0f2f472
3 changed files with 4 additions and 3 deletions

View File

@@ -216,7 +216,7 @@ class Scheduler:
self.block_manager.fork(parent_seq, seq)
# Append a new token to the sequence.
seq.append(next_token)
seq.append([next_token])
# Check if the sequence has generated a stop token.
if next_token in stop_token_ids:

View File

@@ -13,7 +13,7 @@ class Sampler(nn.Module):
embedding: torch.Tensor,
) -> None:
super().__init__()
self.embedding = embedding.t() # [hidden_size, vocab_size]
self.embedding = embedding # [vocab_size, hidden_size]
def forward(
self,
@@ -31,7 +31,7 @@ class Sampler(nn.Module):
hidden_states = hidden_states[last_token_indicies]
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, self.embedding)
logits = torch.matmul(hidden_states, self.embedding.t())
# Sample the next tokens.
# TODO(woosuk): Implement other sampling methods.

View File

@@ -165,6 +165,7 @@ class Worker:
output = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=self.gpu_cache,
input_metadata=input_metadata,
cache_events=cache_events,
)