Optimize data movement (#20)
This commit is contained in:
parent
1f01a18d39
commit
897cb2ae28
20
cacheflow/models/activation.py
Normal file
20
cacheflow/models/activation.py
Normal file
@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from cacheflow import activation_ops
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor, # (num_tokens, 2 * d)
|
||||
) -> torch.Tensor: # (num_tokens, d)
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1] // 2
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
return out
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
|
||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
prompt_lens: List[int],
|
||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
|
||||
max_prompt_len: int,
|
||||
) -> None:
|
||||
if query.dtype == torch.float:
|
||||
raise ValueError('The float data type is not supported by '
|
||||
'FlashAttention. Use the half data type instead.')
|
||||
head_size = query.shape[2]
|
||||
head_size = query.shape[-1]
|
||||
if head_size > 128:
|
||||
raise ValueError('FlashAttention does not support head_size > 128.')
|
||||
|
||||
device = query.device
|
||||
prefix_sum = [0]
|
||||
for prompt_len in prompt_lens:
|
||||
prefix_sum.append(prefix_sum[-1] + prompt_len)
|
||||
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
|
||||
max_prompt_len = max(prompt_lens)
|
||||
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
out = self.flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=prefix_sum,
|
||||
max_s=max_prompt_len,
|
||||
# Directly call FlashAttention's internal function to avoid allocating
|
||||
# a new tensor for the output.
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cumulative_prompt_lens,
|
||||
cumulative_prompt_lens,
|
||||
max_prompt_len,
|
||||
max_prompt_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)[0]
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
output.copy_(out, non_blocking=True)
|
||||
return_softmax=False,
|
||||
)
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Pre-allocate the output tensor.
|
||||
output = torch.empty_like(query)
|
||||
# NOTE: The query, key, and value tensors must be sliced from a qkv
|
||||
# tensor of shape [num_tokens, 3 * num_heads * head_size].
|
||||
|
||||
# Prune out paddings if any.
|
||||
query = query[:input_metadata.num_valid_tokens]
|
||||
key = key[:input_metadata.num_valid_tokens]
|
||||
value = value[:input_metadata.num_valid_tokens]
|
||||
|
||||
# Reshape the input tensors.
|
||||
# Reshape the query, key, and value tensors.
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_heads, head_size)
|
||||
value = value.view(-1, num_heads, head_size)
|
||||
output = output.view(-1, num_heads, head_size)
|
||||
|
||||
# Pre-allocate the output tensor.
|
||||
output = torch.empty_like(query)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||
@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
query[:num_prompt_tokens],
|
||||
key[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
input_metadata.prompt_lens,
|
||||
input_metadata.cumulative_prompt_lens,
|
||||
input_metadata.max_prompt_len,
|
||||
)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
cache_event.wait()
|
||||
|
||||
# Reshape the keys and values and store them in the cache.
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, input_metadata.slot_mapping)
|
||||
num_valid_tokens = input_metadata.num_valid_tokens
|
||||
if num_valid_tokens > 0:
|
||||
# The stride is 3 because the key and value are sliced from qkv.
|
||||
cache_ops.reshape_and_cache(
|
||||
key[:num_valid_tokens],
|
||||
value[:num_valid_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata.slot_mapping,
|
||||
)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Compute the attention op for generation tokens.
|
||||
self.single_query_cached_kv_attention(
|
||||
output[num_prompt_tokens:],
|
||||
query[num_prompt_tokens:],
|
||||
output[num_prompt_tokens:num_valid_tokens],
|
||||
query[num_prompt_tokens:num_valid_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata)
|
||||
@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.cos_sin_cache,
|
||||
)
|
||||
return super().forward(
|
||||
out_query,
|
||||
out_key,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
|
||||
@ -12,6 +12,7 @@ class InputMetadata:
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
|
||||
prompt_lens: List[int],
|
||||
cumulative_prompt_lens: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
@ -20,6 +21,7 @@ class InputMetadata:
|
||||
self.seq_groups = seq_groups
|
||||
self.seq_logprobs = seq_logprobs
|
||||
self.prompt_lens = prompt_lens
|
||||
self.cumulative_prompt_lens = cumulative_prompt_lens
|
||||
self.slot_mapping = slot_mapping
|
||||
self.context_lens = context_lens
|
||||
self.max_context_len = max_context_len
|
||||
@ -27,6 +29,7 @@ class InputMetadata:
|
||||
|
||||
self.num_prompts = len(prompt_lens)
|
||||
self.num_prompt_tokens = sum(prompt_lens)
|
||||
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
|
||||
self.num_generation_tokens = context_lens.shape[0]
|
||||
self.num_valid_tokens = slot_mapping.shape[0]
|
||||
if block_tables.numel() > 0:
|
||||
@ -40,11 +43,13 @@ class InputMetadata:
|
||||
return (f'InputMetadata('
|
||||
f'num_prompts={self.num_prompts}, '
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
f'max_prompt_len={self.max_prompt_len}, '
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'num_valid_tokens={self.num_valid_tokens}, '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'max_context_len={self.max_context_len}), '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
|
||||
f'slot_mapping={self.slot_mapping}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'block_tables={self.block_tables})')
|
||||
|
||||
@ -11,6 +11,7 @@ from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.models.activation import SiluAndMul
|
||||
from cacheflow.models.attention import LlamaCacheFlowAttention
|
||||
from cacheflow.models.layernorm import RMSNorm
|
||||
from cacheflow.models.sample import Sampler
|
||||
@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
|
||||
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=False, input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
assert hidden_act == 'silu'
|
||||
self.act_fn = nn.SiLU()
|
||||
if hidden_act != 'silu':
|
||||
raise ValueError(f'Unsupported activation: {hidden_act}. '
|
||||
'Only silu is supported for now.')
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
|
||||
gate, up = torch.split(gate_up, 1, dim=-2)
|
||||
gate = gate.squeeze(dim=-2).contiguous()
|
||||
up = up.squeeze(dim=-2).contiguous()
|
||||
x = self.act_fn(gate) * up
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
||||
q = q.squeeze(dim=-2).contiguous()
|
||||
k = k.squeeze(dim=-2).contiguous()
|
||||
v = v.squeeze(dim=-2).contiguous()
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||
|
||||
@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
||||
q = q.squeeze(dim=-2).contiguous()
|
||||
k = k.squeeze(dim=-2).contiguous()
|
||||
v = v.squeeze(dim=-2).contiguous()
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
|
||||
@ -128,6 +128,11 @@ class Worker:
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
cumulative_prompt_lens: List[int] = [0]
|
||||
for prompt_len in prompt_lens:
|
||||
cumulative_prompt_lens.append(
|
||||
cumulative_prompt_lens[-1] + prompt_len)
|
||||
|
||||
# Add generation tokens.
|
||||
max_context_len = 0
|
||||
max_num_blocks_per_seq = 0
|
||||
@ -183,11 +188,14 @@ class Worker:
|
||||
for block_table in generation_block_tables]
|
||||
block_tables_tensor = torch.tensor(
|
||||
padded_block_tables, dtype=torch.int, device='cuda')
|
||||
cumulative_prompt_lens_tensor = torch.tensor(
|
||||
cumulative_prompt_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_logprobs=seq_logprobs,
|
||||
prompt_lens=prompt_lens,
|
||||
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
context_lens=context_lens_tensor,
|
||||
max_context_len=max_context_len,
|
||||
|
||||
12
csrc/activation.cpp
Normal file
12
csrc/activation.cpp
Normal file
@ -0,0 +1,12 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
}
|
||||
46
csrc/activation_kernels.cu
Normal file
46
csrc/activation_kernels.cu
Normal file
@ -0,0 +1,46 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
namespace cacheflow {
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void silu_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||
const scalar_t* __restrict__ input, // [num_tokens, 2, d]
|
||||
const int d) {
|
||||
const int token_idx = blockIdx.x;
|
||||
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = silu(x) * y;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out, // [num_tokens, d]
|
||||
torch::Tensor& input) // [num_tokens, 2 * d]
|
||||
{
|
||||
int num_tokens = input.size(0);
|
||||
int d = input.size(1) / 2;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(d, 1024));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
input.scalar_type(),
|
||||
"silu_and_mul_kernel",
|
||||
[&] {
|
||||
cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
d);
|
||||
});
|
||||
}
|
||||
@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq) {
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||
// th vectors of the query, and so on.
|
||||
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||
@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq);
|
||||
max_num_blocks_per_seq, \
|
||||
query_stride);
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher(
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int query_stride = query.stride(0);
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher(
|
||||
}
|
||||
|
||||
void single_query_cached_kv_attention(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len) {
|
||||
// TODO(woosuk): Support BF16.
|
||||
|
||||
@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int src_idx = token_idx * n + i;
|
||||
const int src_key_idx = token_idx * key_stride + i;
|
||||
const int src_value_idx = token_idx * value_stride + i;
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel(
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
|
||||
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
|
||||
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
|
||||
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping) {
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [num_tokens]
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -140,6 +147,8 @@ void reshape_and_cache(
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& out_query,
|
||||
torch::Tensor& out_key,
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
|
||||
@ -5,12 +5,11 @@ namespace cacheflow {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void rotary_embedding_neox_kernel(
|
||||
scalar_t* __restrict__ out_query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ out_key, // [num_tokens, num_heads, head_size]
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
const scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
||||
const int stride,
|
||||
const int num_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
@ -19,41 +18,36 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
||||
|
||||
const int embed_dim = head_size / 2;
|
||||
const int n = num_heads * head_size;
|
||||
const int n = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int idx = token_idx * n + i;
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * stride + head_idx * head_size;
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int token_head = token_idx * n + head_idx * head_size;
|
||||
|
||||
const bool is_first_half = head_offset < embed_dim;
|
||||
const int rot_offset = head_offset % embed_dim;
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t q_x = __ldg(query + token_head + x_index);
|
||||
const scalar_t q_y = __ldg(query + token_head + y_index);
|
||||
const scalar_t q_cos = is_first_half ? q_x : q_y;
|
||||
const scalar_t q_sin = is_first_half ? -q_y : q_x;
|
||||
out_query[idx] = q_cos * cos + q_sin * sin;
|
||||
const scalar_t q_x = query[token_head + x_index];
|
||||
const scalar_t q_y = query[token_head + y_index];
|
||||
query[out_x] = q_x * cos - q_y * sin;
|
||||
query[out_y] = q_y * cos + q_x * sin;
|
||||
|
||||
const scalar_t k_x = __ldg(key + token_head + x_index);
|
||||
const scalar_t k_y = __ldg(key + token_head + y_index);
|
||||
const scalar_t k_cos = is_first_half ? k_x : k_y;
|
||||
const scalar_t k_sin = is_first_half ? -k_y : k_x;
|
||||
out_key[idx] = k_cos * cos + k_sin * sin;
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& out_query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& out_key, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||
@ -62,21 +56,22 @@ void rotary_embedding_neox(
|
||||
int num_tokens = query.size(0);
|
||||
int head_size = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
int stride = query.stride(0);
|
||||
TORCH_CHECK(stride == key.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
dim3 block(std::min(num_heads * head_size / 2, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_neox",
|
||||
[&] {
|
||||
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out_query.data_ptr<scalar_t>(),
|
||||
out_key.data_ptr<scalar_t>(),
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
stride,
|
||||
num_heads,
|
||||
head_size);
|
||||
});
|
||||
|
||||
7
setup.py
7
setup.py
@ -39,6 +39,13 @@ layernorm_extension = cpp_extension.CUDAExtension(
|
||||
)
|
||||
ext_modules.append(layernorm_extension)
|
||||
|
||||
activation_extension = cpp_extension.CUDAExtension(
|
||||
name='cacheflow.activation_ops',
|
||||
sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
|
||||
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||
)
|
||||
ext_modules.append(activation_extension)
|
||||
|
||||
setuptools.setup(
|
||||
name='cacheflow',
|
||||
ext_modules=ext_modules,
|
||||
|
||||
30
tests/kernels/activation.py
Normal file
30
tests/kernels/activation.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cacheflow import activation_ops
|
||||
|
||||
|
||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||
return F.silu(x1) * x2
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
ref_out = ref_silu_and_mul(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for num_tokens in [7, 83, 2048]:
|
||||
for d in [512, 4096, 13824]:
|
||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
||||
test_silu_and_mul(num_tokens, d, dtype)
|
||||
@ -1,7 +1,7 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
|
||||
from cacheflow import attention_ops
|
||||
@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
query = torch.randn(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
|
||||
value_cache = torch.randn(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
|
||||
# Adjust the range of the values to reduce precision errors.
|
||||
query = query / (head_size ** 0.5)
|
||||
key_cache = key_cache / (head_size ** 0.5)
|
||||
value_cache = value_cache / (head_size ** 0.5)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty_like(query)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
query = torch.randn(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
key = torch.rand_like(query)
|
||||
value = torch.rand_like(query)
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
# Adjust the range of the values to reduce precision errors.
|
||||
qkv = qkv / (head_size ** 0.5)
|
||||
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
flash_attn = FlashAttention(softmax_scale=scale)
|
||||
output = flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=cu_seq_lens,
|
||||
max_s=max_seq_len,
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cu_seq_lens,
|
||||
cu_seq_lens,
|
||||
max_seq_len,
|
||||
max_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
)[0]
|
||||
return_softmax=False,
|
||||
)
|
||||
|
||||
cu_seq_lens = cu_seq_lens.cpu().tolist()
|
||||
ref_output = ref_multi_query_kv_attention(
|
||||
|
||||
@ -17,10 +17,10 @@ def test_reshape_and_cache(
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
kv_shape = (num_tokens, num_heads, head_size)
|
||||
key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
|
||||
value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
|
||||
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
@ -35,7 +35,7 @@ def test_reshape_and_cache(
|
||||
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = slot_mapping[i] // block_size
|
||||
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
||||
|
||||
# Run the kernel.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
out_key = key.clone()
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
cos_sin_cache,
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user