[Doc] Consistent naming of attention backends (#9498)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell, 2024-10-21 16:29:57 +02:00, committed by GitHub
parent 696b01af8f
commit 496e991da8
14 changed files with 23 additions and 19 deletions
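
For reference, every hunk below renames the string returned by get_name() from a lowercase, hyphenated form to an uppercase identifier. The mapping below is illustrative only and is not part of the commit; it simply collects the renames in one place (PallasAttentionBackend gains a new get_name() returning "PALLAS" rather than renaming an existing one):

# Illustrative summary only -- this dict does not exist in the code base;
# it just restates the get_name() renames from the hunks below.
OLD_TO_NEW_BACKEND_NAME = {
    "flash-attn": "FLASH_ATTN",
    "flashinfer": "FLASHINFER",
    "ipex-attn": "IPEX",
    "openvino": "OPENVINO",
    "placeholder-attn": "NO_ATTENTION",
    "rocm-flash-attn": "ROCM_FLASH",
    "torch-sdpa": "TORCH_SDPA",
    "xformers": "XFORMERS",
}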


@@ -32,7 +32,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flash-attn"
+        return "FLASH_ATTN"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:


@@ -40,7 +40,7 @@ class FlashInferBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flashinfer"
+        return "FLASHINFER"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:


@@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "ipex-attn"
+        return "IPEX"
 
     @staticmethod
     def get_impl_cls() -> Type["IpexAttnBackendImpl"]:


@@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "openvino"
+        return "OPENVINO"
 
     @staticmethod
     def get_impl_cls():


@@ -11,6 +11,10 @@ from vllm.attention.backends.utils import CommonAttentionState
 
 class PallasAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "PALLAS"
+
     @staticmethod
     def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
         return PallasAttentionBackendImpl


@@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "placeholder-attn"
+        return "NO_ATTENTION"
 
     @staticmethod
     def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:


@@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "rocm-flash-attn"
+        return "ROCM_FLASH"
 
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:


@@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "torch-sdpa"
+        return "TORCH_SDPA"
 
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:


@@ -317,8 +317,8 @@ class CommonAttentionState(AttentionState):
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
+            assert self.runner.attn_backend.get_name() == "XFORMERS", \
+                f"Expected attn_backend name to be 'XFORMERS', but "\
                 f" got '{self.runner.attn_backend.get_name()}'"
             self._update_captured_metadata_for_enc_dec_model(
                 batch_size=batch_size, attn_metadata=attn_metadata)
@@ -337,8 +337,8 @@ class CommonAttentionState(AttentionState):
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
+            assert self.runner.attn_backend.get_name() == "XFORMERS", \
+                f"Expected attn_backend name to be 'XFORMERS', but "\
                 f" got '{self.runner.attn_backend.get_name()}'"
             self._add_additonal_input_buffers_for_enc_dec_model(
                 attn_metadata=attn_metadata, input_buffers=input_buffers)
@@ -356,8 +356,8 @@ class CommonAttentionState(AttentionState):
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers backend.
             # Assert the same.
-            assert self.runner.attn_backend.get_name() == "xformers", \
-                f"Expected attn_backend name to be 'xformers', but "\
+            assert self.runner.attn_backend.get_name() == "XFORMERS", \
+                f"Expected attn_backend name to be 'XFORMERS', but "\
                 f" got '{self.runner.attn_backend.get_name()}'"
             self._prepare_input_buffers_for_enc_dec_model(
                 attn_metadata, input_buffers)
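
The three hunks above repeat the same backend-name assertion. As a hedged illustration of the pattern only (the helper below is hypothetical and not introduced by this commit), the check amounts to:

# Hypothetical helper, not part of this commit: the encoder-decoder paths
# above all perform this exact check against the new backend name.
def _assert_xformers_backend(attn_backend) -> None:
    name = attn_backend.get_name()
    assert name == "XFORMERS", (
        f"Expected attn_backend name to be 'XFORMERS', but got '{name}'")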


@@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "xformers"
+        return "XFORMERS"
 
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:


@@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() != "flash-attn":
+        if self.attn_backend.get_name() != "FLASH_ATTN":
             return False
 
         # TODO: Add support for LORA


@@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 
         if not disable_mqa_scorer:
             if scorer_worker.model_runner.attn_backend.get_name(
-            ) != "flash-attn":
+            ) != "FLASH_ATTN":
                 disable_mqa_scorer = True
                 logger.info(
                     "[Speculative Decoding] Disabling MQA scorer as the "


@@ -1855,7 +1855,7 @@ class CUDAGraphRunner(nn.Module):
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
-        if self.backend_name != "placeholder-attn":
+        if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
                 attn_metadata.slot_mapping, non_blocking=True)


@@ -29,8 +29,8 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
-MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
-MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"]
+MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
+MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
 
 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
         -> List[str]:
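
The body of _get_supported_attention_backends is not shown in this diff. A minimal sketch, assuming the helper simply selects between the two module-level lists above depending on whether chunked prefill is enabled:

from typing import List

# Sketch only: assumes the helper picks the appropriate list of
# (newly renamed) backend names for the multi-step model runner.
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
        -> List[str]:
    if chunked_prefill_enabled:
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS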