[Misc] Add get_name method to attention backends (#4685)
commit 5510cf0e8a
parent 0f9a6e3d22
@@ -9,6 +9,11 @@ import torch
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
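For orientation: the hunk above adds a second abstract @staticmethod to AttentionBackend, so every concrete backend must now report a short string identifier alongside its implementation class. A minimal sketch of a conforming subclass, assuming only what this hunk shows (the class name and the "my-attn" string are hypothetical, and the real abstract class declares further abstract methods not repeated here):

from typing import Type

from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl


class MyAttentionBackend(AttentionBackend):
    """Hypothetical backend, shown only to illustrate the new interface."""

    @staticmethod
    def get_name() -> str:
        # Callers compare this string instead of the class object itself.
        return "my-attn"

    @staticmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        # A real backend would return its own AttentionImpl subclass here.
        raise NotImplementedError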
@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl
@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
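A consequence of the import change above, worth calling out: flashinfer, flash_attn_varlen_func, and BatchDecodeWithPagedKVCacheWrapper are now imported unconditionally, so importing vllm.attention.backends.flashinfer raises ImportError when those optional packages are missing, instead of silently binding the names to None. Callers that only might use this backend should therefore import the module lazily; a hedged sketch of that pattern (the helper name is invented, not part of this commit):

def load_flashinfer_backend():
    # Deferring the import keeps the optional flashinfer / flash_attn
    # dependencies off the unconditional import path; an ImportError only
    # surfaces if this backend is actually requested.
    from vllm.attention.backends.flashinfer import FlashInferBackend
    return FlashInferBackend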
@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 
 class FlashInferBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl
@@ -17,6 +17,10 @@ logger = init_logger(__name__)
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl
@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
 
 class TorchSDPABackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
@@ -20,6 +20,10 @@ logger = init_logger(__name__)
 
 class XFormersBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
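With every backend now answering get_name() ("flash-attn", "flashinfer", "rocm-flash-attn", "torch-sdpa", "xformers"), the ModelRunner hunks below can drop their module-level import of FlashInferBackend and compare names instead of class identity. A hedged distillation of that pattern (the helper function is illustrative, not code from this commit):

from typing import Type

from vllm.attention.backends.abstract import AttentionBackend


def uses_flashinfer(attn_backend: Type[AttentionBackend]) -> bool:
    # Comparing the name string avoids importing FlashInferBackend (and with
    # it the optional flashinfer package) just to perform an identity check.
    return attn_backend.get_name() == "flashinfer"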
@@ -9,7 +9,6 @@ import torch.nn as nn
 
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -395,7 +394,7 @@ class ModelRunner:
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             attn_metadata = self.attn_backend.make_metadata(
                 is_prompt=True,
                 use_cuda_graph=False,
@@ -556,7 +555,7 @@ class ModelRunner:
             device=self.device,
         )
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             if not hasattr(self, "flashinfer_workspace_buffer"):
                 # Allocate 16MB workspace buffer
                 # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
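A final note on the ModelRunner changes above: get_name() is declared as a @staticmethod, so the comparison behaves the same whether self.attn_backend holds the backend class returned by get_attn_backend or an instance of it. A tiny hedged illustration (DummyBackend is invented for this example):

class DummyBackend:
    @staticmethod
    def get_name() -> str:
        return "flashinfer"


# Both forms evaluate identically, which is why the diff can call get_name()
# on whatever get_attn_backend() returned without instantiating anything.
assert DummyBackend.get_name() == "flashinfer"
assert DummyBackend().get_name() == "flashinfer"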