diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 42f4284c..c2ec4376 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,4 +1,5 @@
 from functools import lru_cache
+from typing import Type
 
 import torch
 
@@ -10,7 +11,7 @@ logger = init_logger(__name__)
 
 
 @lru_cache(maxsize=None)
-def get_attn_backend(dtype: torch.dtype) -> AttentionBackend:
+def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
     if _can_use_flash_attn(dtype):
         logger.info("Using FlashAttention backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401