[Misc] Add get_name method to attention backends (#4685)

commit 5510cf0e8a (parent 0f9a6e3d22)
Author: Woosuk Kwon
Date: 2024-05-08 09:59:31 -07:00 (committed by GitHub)
7 changed files with 30 additions and 12 deletions

vllm/attention/backends/abstract.py

@@ -9,6 +9,11 @@ import torch

 class AttentionBackend(ABC):
     """Abstract class for attention backends."""

+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
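For orientation, get_name is declared as a staticmethod, so callers can read a backend's identifier from the class itself without instantiating it (FlashAttentionBackend.get_name() simply returns "flash-attn"). A minimal, self-contained sketch of the pattern, mirroring the abstract class above rather than importing vLLM, with a hypothetical MyBackend used purely for illustration:

from abc import ABC, abstractmethod


class AttentionBackend(ABC):
    """Abstract base, reduced to the new method for this sketch."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError


class MyBackend(AttentionBackend):
    """Hypothetical backend, only here to illustrate the new method."""

    @staticmethod
    def get_name() -> str:
        return "my-backend"


# Staticmethods resolve on the class itself, which is how the model runner
# uses get_name() below (its attn_backend attribute holds a backend class).
assert MyBackend.get_name() == "my-backend"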

vllm/attention/backends/flash_attn.py

@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,

 class FlashAttentionBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl

vllm/attention/backends/flashinfer.py

@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type

-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,

@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,

 class FlashInferBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl

vllm/attention/backends/rocm_flash_attn.py

@@ -17,6 +17,10 @@ logger = init_logger(__name__)

 class ROCmFlashAttentionBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl

vllm/attention/backends/torch_sdpa.py

@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,

 class TorchSDPABackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl

vllm/attention/backends/xformers.py

@@ -20,6 +20,10 @@ logger = init_logger(__name__)

 class XFormersBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl

vllm/worker/model_runner.py

@@ -9,7 +9,6 @@ import torch.nn as nn

 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce

@@ -395,7 +394,7 @@ class ModelRunner:
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])

-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             attn_metadata = self.attn_backend.make_metadata(
                 is_prompt=True,
                 use_cuda_graph=False,

@@ -556,7 +555,7 @@ class ModelRunner:
             device=self.device,
         )

-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             if not hasattr(self, "flashinfer_workspace_buffer"):
                 # Allocate 16MB workspace buffer
                 # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
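The net effect in the runner: the FlashInfer special case is keyed on the backend's name, so model_runner.py no longer imports FlashInferBackend (and, transitively, the flashinfer package) just to perform the check. A hedged sketch of the equivalent check, using a hypothetical helper name and assuming attn_backend is the backend class returned by get_attn_backend (the old code compared it against the class with `is`):

from typing import Type

from vllm.attention.backends.abstract import AttentionBackend


def is_flashinfer_backend(attn_backend: Type[AttentionBackend]) -> bool:
    # Hypothetical helper mirroring the check above. Comparing the string
    # returned by the abstract get_name() works for any backend class and
    # avoids importing vllm.attention.backends.flashinfer here.
    return attn_backend.get_name() == "flashinfer"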