Fix 1D query issue from _prune_hidden_states (#3539)
This commit is contained in:
parent
6ebd02bdef
commit
3bbff9e5ab
@ -77,7 +77,6 @@ def _prune_hidden_states(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
|
||||||
return hidden_states.index_select(0,
|
return hidden_states.index_select(0,
|
||||||
sampling_metadata.selected_token_indices)
|
sampling_metadata.selected_token_indices)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user