[Doc] Consistent naming of attention backends (#9498)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
696b01af8f
commit
496e991da8
@ -32,7 +32,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "flash-attn"
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashAttentionImpl"]:
|
||||
|
||||
@ -40,7 +40,7 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "flashinfer"
|
||||
return "FLASHINFER"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashInferImpl"]:
|
||||
|
||||
@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ipex-attn"
|
||||
return "IPEX"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
|
||||
|
||||
@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "openvino"
|
||||
return "OPENVINO"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls():
|
||||
|
||||
@ -11,6 +11,10 @@ from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "placeholder-attn"
|
||||
return "NO_ATTENTION"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
|
||||
|
||||
@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "rocm-flash-attn"
|
||||
return "ROCM_FLASH"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
|
||||
|
||||
@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "torch-sdpa"
|
||||
return "TORCH_SDPA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
|
||||
|
||||
@ -317,8 +317,8 @@ class CommonAttentionState(AttentionState):
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "xformers", \
|
||||
f"Expected attn_backend name to be 'xformers', but "\
|
||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
self._update_captured_metadata_for_enc_dec_model(
|
||||
batch_size=batch_size, attn_metadata=attn_metadata)
|
||||
@ -337,8 +337,8 @@ class CommonAttentionState(AttentionState):
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "xformers", \
|
||||
f"Expected attn_backend name to be 'xformers', but "\
|
||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
self._add_additonal_input_buffers_for_enc_dec_model(
|
||||
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
||||
@ -356,8 +356,8 @@ class CommonAttentionState(AttentionState):
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "xformers", \
|
||||
f"Expected attn_backend name to be 'xformers', but "\
|
||||
assert self.runner.attn_backend.get_name() == "XFORMERS", \
|
||||
f"Expected attn_backend name to be 'XFORMERS', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
self._prepare_input_buffers_for_enc_dec_model(
|
||||
attn_metadata, input_buffers)
|
||||
|
||||
@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "xformers"
|
||||
return "XFORMERS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["XFormersImpl"]:
|
||||
|
||||
@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
return False
|
||||
|
||||
# TODO: Add support for other attn backends
|
||||
if self.attn_backend.get_name() != "flash-attn":
|
||||
if self.attn_backend.get_name() != "FLASH_ATTN":
|
||||
return False
|
||||
|
||||
# TODO: Add support for LORA
|
||||
|
||||
@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
if not disable_mqa_scorer:
|
||||
if scorer_worker.model_runner.attn_backend.get_name(
|
||||
) != "flash-attn":
|
||||
) != "FLASH_ATTN":
|
||||
disable_mqa_scorer = True
|
||||
logger.info(
|
||||
"[Speculative Decoding] Disabling MQA scorer as the "
|
||||
|
||||
@ -1855,7 +1855,7 @@ class CUDAGraphRunner(nn.Module):
|
||||
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
||||
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
||||
|
||||
if self.backend_name != "placeholder-attn":
|
||||
if self.backend_name != "NO_ATTENTION":
|
||||
self.input_buffers["slot_mapping"].copy_(
|
||||
attn_metadata.slot_mapping, non_blocking=True)
|
||||
|
||||
|
||||
@ -29,8 +29,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
|
||||
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"]
|
||||
MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
|
||||
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
|
||||
|
||||
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
|
||||
-> List[str]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user