[Misc] Compute query_start_loc/seq_start_loc on CPU (#9447)
Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
Commit: 4dbcbbeb09 (parent: b67feb1274)
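This change builds `query_start_loc` and `seq_start_loc` as plain Python prefix sums with `itertools.accumulate` and ships them to the GPU with the same `async_tensor_h2d` helper already used for `seq_lens` and `slot_mapping`, instead of allocating zero tensors on the device and filling them with `torch.cumsum`. A minimal sketch of the two equivalent computations (the `query_lens` values are illustrative):

```python
from itertools import accumulate

import torch

# Per-sequence query lengths as the metadata builder collects them on the host
# (values made up for illustration).
query_lens = [5, 3, 7]

# New path in this commit: prefix sums computed on the CPU.
query_start_loc = list(accumulate(query_lens, initial=0))
# -> [0, 5, 8, 15]; entry i is the offset of sequence i's first query token.

# Old path (reproduced on the CPU here; the builder previously did this on the
# device): a zero-initialized tensor filled in by a cumsum kernel.
query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
old_query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens_tensor,
             dim=0,
             dtype=old_query_start_loc.dtype,
             out=old_query_start_loc[1:])

assert old_query_start_loc.tolist() == query_start_loc
```

The two paths produce the same offsets; the new one simply moves the prefix-sum work off the GPU and drops the separate `query_lens_tensor` transfer and the two cumsum kernel launches per batch.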
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py

@@ -1,6 +1,7 @@
 """Attention layer with FlashAttention."""
 from collections import defaultdict
 from dataclasses import dataclass
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
@@ -503,6 +504,8 @@ class FlashAttentionMetadataBuilder(
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))
 
         num_seqs = len(seq_lens)
         if use_captured_graph:
@@ -525,29 +528,18 @@ class FlashAttentionMetadataBuilder(
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
 
         return FlashAttentionMetadata(
             num_prefills=self.num_prefills,
@@ -561,8 +553,8 @@ class FlashAttentionMetadataBuilder(
             max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
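The same pattern is applied to `CommonMetadataBuilder` in the second file below. `async_tensor_h2d` itself is not part of this diff; a rough sketch of the pinned-memory, non-blocking host-to-device copy it presumably wraps (a hypothetical stand-in, not vLLM's actual helper):

```python
from typing import List, Union

import torch


def async_tensor_h2d_sketch(data: List[int], dtype: torch.dtype,
                            target_device: Union[str, torch.device],
                            pin_memory: bool) -> torch.Tensor:
    """Hypothetical stand-in for vLLM's async_tensor_h2d helper.

    Builds the tensor on the CPU (optionally in pinned memory) and issues a
    non-blocking copy so the transfer can overlap with later CPU work.
    """
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


# Mirrors the new code path in the diff (illustrative values).
query_start_loc = [0, 5, 8, 15]
device = "cuda" if torch.cuda.is_available() else "cpu"
query_start_loc_tensor = async_tensor_h2d_sketch(query_start_loc, torch.int32,
                                                 device, pin_memory=False)
```

Pinned host memory is what allows a `non_blocking=True` copy to overlap with later CPU work, which is why the builders pass `self.runner.pin_memory` into every `async_tensor_h2d` call.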
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py

@@ -1,6 +1,7 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 
 import numpy as np
@@ -216,6 +217,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        query_start_loc = list(accumulate(query_lens, initial=0))
+        seq_start_loc = list(accumulate(seq_lens, initial=0))
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
@@ -244,29 +247,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
-                                             self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                                device, self.runner.pin_memory)
-        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=device)
-        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
-                                    dtype=torch.int32,
-                                    device=device)
+        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
+                                                  device,
+                                                  self.runner.pin_memory)
+        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
             self.multimodal_placeholder_maps.items()
         }
-        torch.cumsum(seq_lens_tensor,
-                     dim=0,
-                     dtype=seq_start_loc.dtype,
-                     out=seq_start_loc[1:])
-        torch.cumsum(query_lens_tensor,
-                     dim=0,
-                     dtype=query_start_loc.dtype,
-                     out=query_start_loc[1:])
 
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
@@ -279,8 +271,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             max_query_len=max_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc_tensor,
+            seq_start_loc=seq_start_loc_tensor,
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,