[Kernel] Support sliding window in flash attention backend (#9403)
parent 962d2c6349
commit 4fa3e33349
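The change maps vLLM's scalar `sliding_window` (a token count) onto flash-attn's `(left, right)` window convention as `(sliding_window - 1, 0)` and threads it into the flash-attn calls, instead of rejecting sliding-window models in the FlashAttention backend. The sketch below is illustrative only (the helper names are not part of this diff); it just shows the convention: with a causal mask and `window_size=(w - 1, 0)`, a query at position `i` sees at most `w` keys.

```python
# Illustrative only -- not vLLM code. Shows the window convention this commit
# adopts: a scalar sliding_window w maps to flash-attn's (left, right) tuple
# (w - 1, 0); (-1, -1) means "no window".
from typing import Optional, Tuple


def to_flash_window_size(sliding_window: Optional[int]) -> Tuple[int, int]:
    return (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)


def visible_keys(i: int, sliding_window: Optional[int]) -> range:
    # Key positions a causal query at position i may attend to.
    left, _right = to_flash_window_size(sliding_window)
    lo = 0 if left < 0 else max(0, i - left)
    return range(lo, i + 1)


assert len(visible_keys(100, sliding_window=16)) == 16
```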
@@ -20,21 +20,21 @@ def test_env(name: str, device: str, monkeypatch):
     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+        backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                     False)
         assert backend.name == name
 
@@ -46,37 +46,32 @@ def test_flash_attn(monkeypatch):
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
+        backend = which_attn_to_use(16, torch.float16, None, 16, False)
         assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
+    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
+    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
-
-    # Unsupported sliding window
-    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
+    backend = which_attn_to_use(16, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
+        backend = which_attn_to_use(16, torch.float16, None, 16, False)
         assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
+    backend = which_attn_to_use(17, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
-                                True)
+    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
     assert backend.name != STR_FLASH_ATTN_VAL
 
 
@@ -84,4 +79,4 @@ def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, None, torch.float16, None, 16, False)
+        which_attn_to_use(16, torch.float16, None, 16, False)
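The positional calls above are easy to misread, so here is an annotated form of the new call shape. The parameter names come from the updated `which_attn_to_use` signature shown further down in this diff (`sliding_window` is no longer an argument); the expected result is deliberately left unasserted since it depends on the test environment.

```python
# Annotated version of the positional test calls above. Parameter names are
# taken from the updated which_attn_to_use signature later in this diff.
import torch
from vllm.attention.selector import which_attn_to_use

backend = which_attn_to_use(
    16,             # head_size
    torch.float16,  # dtype
    None,           # kv_cache_dtype
    16,             # block_size
    False,          # is_attention_free
)
print(backend.name)
```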
@@ -78,6 +78,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("sliding_window", [None, 256])
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
     seed_everything(0)
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
+                   (-1, -1))
 
     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
     key_cache = torch.randn(num_blocks,
@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
+        window_size=window_size,
     ).squeeze(1)
 
-    ref_output = ref_paged_attn(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        query_lens=[1] * num_seqs,
-        kv_lens=kv_lens,
-        block_tables=block_tables,
-        scale=scale,
-        soft_cap=soft_cap,
-    )
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache,
+                                value_cache=value_cache,
+                                query_lens=[1] * num_seqs,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
 
     torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None])
+@pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
     assert num_query_heads % num_kv_heads == 0
     max_query_len = max(query_lens)
     max_kv_len = max(kv_lens)
-    window_size = ((sliding_window,
-                    sliding_window) if sliding_window is not None else
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                    (-1, -1))
     scale = head_size**-0.5
 
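The paged-KV decode test now parametrizes `sliding_window`, passing the derived `window_size` to `flash_attn_with_kvcache` and the scalar `sliding_window` to the `ref_paged_attn` reference. The reference's handling is not shown in this hunk; the standalone sketch below (not the test's `ref_paged_attn`) shows what sliding-window masking means for a single decode query, assuming the query sits at position `kv_len - 1` so only the last `sliding_window` cached keys remain visible.

```python
# Standalone sketch: sliding-window attention for one decode-step query over
# kv_len cached keys. Only the last `sliding_window` keys stay visible.
from typing import Optional

import torch


def decode_attn_with_window(q: torch.Tensor,   # [num_heads, head_size]
                            k: torch.Tensor,   # [kv_len, num_heads, head_size]
                            v: torch.Tensor,   # [kv_len, num_heads, head_size]
                            scale: float,
                            sliding_window: Optional[int]) -> torch.Tensor:
    logits = torch.einsum("hd,khd->hk", q, k).float() * scale
    if sliding_window is not None:
        kv_len = k.shape[0]
        visible_start = max(0, kv_len - sliding_window)
        logits[:, :visible_start] = float("-inf")  # mask keys outside the window
    probs = torch.softmax(logits, dim=-1).to(v.dtype)
    return torch.einsum("hk,khd->hd", probs, v)
```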
@@ -524,8 +524,8 @@ class FlashAttentionImpl(AttentionImpl):
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
+        self.sliding_window = ((sliding_window - 1,
+                                0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
@@ -535,12 +535,6 @@ class FlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if sliding_window is not None:
-            # NOTE(woosuk): flash-attn's sliding window does not work with
-            # paged KV cache.
-            raise ValueError(
-                "Sliding window is not supported in FlashAttention.")
-
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
@@ -704,6 +698,7 @@ def unified_flash_attention(
             max_seqlen_k=max_seq_len,
             softmax_scale=softmax_scale,
             causal=True,
+            window_size=window_size,
            alibi_slopes=alibi_slopes,
             block_table=prefill_meta.block_tables,
             softcap=logits_soft_cap,
@@ -725,6 +720,7 @@ def unified_flash_attention(
                 max_seqlen_k=decode_meta.max_decode_seq_len,
                 softmax_scale=softmax_scale,
                 causal=True,
+                window_size=window_size,
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
                 block_table=decode_meta.block_tables,
@@ -739,6 +735,7 @@ def unified_flash_attention(
                 cache_seqlens=decode_meta.seq_lens_tensor,
                 softmax_scale=softmax_scale,
                 causal=True,
+                window_size=window_size,
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
             ).squeeze(1)
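In the backend itself, the hard `ValueError` ("Sliding window is not supported in FlashAttention.") is gone, `self.sliding_window` now becomes `(sliding_window - 1, 0)`, and that tuple is forwarded as `window_size=` on all three flash-attn paths (prefill, varlen decode, and decode over the paged KV cache). The snippet below is an illustrative consistency check of that convention, not vLLM code: under a causal mask with window `w`, key `j` should be visible to query `i` exactly when `i - (w - 1) <= j <= i`.

```python
# Illustrative check of the (w - 1, 0) convention: build the boolean
# causal sliding-window mask directly and confirm no query sees more
# than w keys.
import torch


def sliding_window_causal_mask(n: int, w: int) -> torch.Tensor:
    i = torch.arange(n).unsqueeze(1)       # query positions (column)
    j = torch.arange(n).unsqueeze(0)       # key positions (row)
    return (j <= i) & (j >= i - (w - 1))   # [n, n] bool mask


mask = sliding_window_causal_mask(n=8, w=3)
assert int(mask.sum(dim=-1).max()) == 3
```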
@@ -78,10 +78,9 @@ class Attention(nn.Module):
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size, sliding_window, dtype,
-                                        kv_cache_dtype, block_size,
-                                        is_attention_free, blocksparse_params
-                                        is not None)
+        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
+                                        block_size, is_attention_free,
+                                        blocksparse_params is not None)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
@@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 @lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -105,8 +104,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, sliding_window, dtype,
-                                kv_cache_dtype, block_size, is_attention_free)
+    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
+                                is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
@@ -155,7 +154,6 @@ def get_attn_backend(
 
 def which_attn_to_use(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -243,10 +241,6 @@ def which_attn_to_use(
                 "Cannot use FlashAttention-2 backend for block size not "
                 "divisible by 16.")
             selected_backend = _Backend.XFORMERS
-        elif sliding_window is not None:
-            logger.info(
-                "Cannot use FlashAttention-2 backend due to sliding window.")
-            selected_backend = _Backend.XFORMERS
 
     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
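Note that `get_attn_backend` is wrapped in `@lru_cache(maxsize=None)`, so dropping `sliding_window` from its parameters also removes it from the cache key: backend selection no longer varies with the model's sliding-window setting, and it also no longer falls back to XFORMERS for such models. That is why every call site in the worker/runner hunks below simply deletes the `get_sliding_window()` argument. A keyword-argument sketch of the new call shape, with names taken from the signature above and example values; any trailing parameters (such as the blocksparse flag seen in the `Attention` layer call) are assumed to keep their defaults:

```python
# Keyword form of the new get_attn_backend call shape. Values are arbitrary
# examples; trailing parameters are assumed to have defaults.
import torch
from vllm.attention.selector import get_attn_backend

attn_backend = get_attn_backend(
    head_size=128,
    dtype=torch.float16,
    kv_cache_dtype=None,   # or e.g. "fp8"
    block_size=16,
    is_attention_free=False,
)
```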
@@ -53,7 +53,6 @@ class CacheEngine:
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(self.head_size,
-                                             model_config.get_sliding_window(),
                                              model_config.dtype,
                                              cache_config.cache_dtype,
                                              self.block_size,
@@ -420,7 +420,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
         self.block_size = cache_config.block_size
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
@@ -57,7 +57,6 @@ class CPUCacheEngine:
         # Get attention backend.
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             cache_config.cache_dtype,
             self.block_size,
@@ -1011,7 +1011,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
@@ -75,7 +75,6 @@ class OpenVINOModelRunner:
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
@@ -71,7 +71,6 @@ class OpenVINOCacheEngine:
         # Get attention backend.
         self.attn_backend = get_attn_backend(
             self.head_size,
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
@@ -114,7 +114,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                                          dtype=np.int32)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
@@ -374,7 +374,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,