[TPU] Async output processing for TPU (#8011)
This commit is contained in:
parent
428dd1445e
commit
80c7b089b1
@ -347,10 +347,10 @@ class ModelConfig:
|
|||||||
self.use_async_output_proc = False
|
self.use_async_output_proc = False
|
||||||
return
|
return
|
||||||
|
|
||||||
if device_config.device_type != "cuda":
|
if device_config.device_type not in ("cuda", "tpu"):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Async output processing is only supported for CUDA."
|
"Async output processing is only supported for CUDA or TPU. "
|
||||||
" Disabling it for other platforms.")
|
"Disabling it for other platforms.")
|
||||||
self.use_async_output_proc = False
|
self.use_async_output_proc = False
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||||
|
Type, Union)
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -51,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
|
|||||||
best_of: List[int]
|
best_of: List[int]
|
||||||
seq_groups: List[List[int]]
|
seq_groups: List[List[int]]
|
||||||
virtual_engine: int = 0
|
virtual_engine: int = 0
|
||||||
|
async_callback: Optional[Callable] = None
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(
|
def as_broadcastable_tensor_dict(
|
||||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||||
@ -562,6 +564,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
model_input.attn_metadata, model_input.input_lens[i:i + 1],
|
model_input.attn_metadata, model_input.input_lens[i:i + 1],
|
||||||
model_input.t[i:i + 1], model_input.p[i:i + 1],
|
model_input.t[i:i + 1], model_input.p[i:i + 1],
|
||||||
model_input.num_samples, kv_caches)
|
model_input.num_samples, kv_caches)
|
||||||
|
if i == 0 and model_input.async_callback is not None:
|
||||||
|
model_input.async_callback()
|
||||||
# Retrieve the outputs to CPU.
|
# Retrieve the outputs to CPU.
|
||||||
next_token_ids += output_token_ids.cpu().tolist()
|
next_token_ids += output_token_ids.cpu().tolist()
|
||||||
start_idx = end_idx
|
start_idx = end_idx
|
||||||
@ -572,6 +576,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
model_input.attn_metadata, model_input.input_lens,
|
model_input.attn_metadata, model_input.input_lens,
|
||||||
model_input.t, model_input.p, model_input.num_samples,
|
model_input.t, model_input.p, model_input.num_samples,
|
||||||
kv_caches)
|
kv_caches)
|
||||||
|
if model_input.async_callback is not None:
|
||||||
|
model_input.async_callback()
|
||||||
# Retrieve the outputs to CPU.
|
# Retrieve the outputs to CPU.
|
||||||
next_token_ids = output_token_ids.cpu().tolist()
|
next_token_ids = output_token_ids.cpu().tolist()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user