Add fp8 support to reshape_and_cache_flash (#6667)
parent ee812580f7
commit 0e63494cf3
@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
-                             const std::string& kv_cache_dtype);
+                             const std::string& kv_cache_dtype,
+                             const double k_scale, const double v_scale);

 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
   }
 }

-template <typename scalar_t>
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                          // head_size]
-    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
+    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                          // head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size) {
+    const int num_heads, const int head_size, const int block_size,
+    const float k_scale, const float v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t src_value_idx = token_idx * value_stride + i;
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
-    const int64_t tgt_value_idx = block_idx * block_stride +
-                                  block_offset * num_heads * head_size +
-                                  head_idx * head_size + head_offset;
-    k_cache[tgt_value_idx] = key[src_key_idx];
-    v_cache[tgt_value_idx] = value[src_value_idx];
+    const int64_t tgt_key_value_idx = block_idx * block_stride +
+                                      block_offset * num_heads * head_size +
+                                      head_idx * head_size + head_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      key_cache[tgt_key_value_idx] = tgt_key;
+      value_cache[tgt_key_value_idx] = tgt_value;
+    } else {
+      key_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+      value_cache[tgt_key_value_idx] =
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
   }
 }
 }  // namespace vllm
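The substantive change is this write path: for kAuto the key/value elements are stored unchanged, while for fp8 cache dtypes each element goes through fp8::scaled_convert with the per-tensor scale before being written (and readers would apply the inverse on the way out). A rough PyTorch sketch of what that scaled store amounts to, assuming torch.float8_e4m3fn casts are available; the helper names are illustrative, not vLLM APIs:

    import torch

    def scaled_fp8_store(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Mirror of the "else" branch above: divide by the per-tensor scale,
        # then narrow to fp8 e4m3 for storage in the cache.
        return (x / scale).to(torch.float8_e4m3fn)

    def scaled_fp8_load(x_fp8: torch.Tensor, scale: float) -> torch.Tensor:
        # Widen back and multiply by the same scale to recover the value.
        return x_fp8.to(torch.float32) * scale

    key = torch.randn(4, 8)
    restored = scaled_fp8_load(scaled_fp8_store(key, 1.0), 1.0)
    print((key - restored).abs().max())  # small fp8 quantization error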
@@ -278,40 +288,45 @@ void reshape_and_cache(
                              CALL_RESHAPE_AND_CACHE)
 }

+// KV_T is the stored data type of kv-cache.
+// CACHE_T is the data type of key and value tensors.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)          \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>        \
+      <<<grid, block, 0, stream>>>(                                    \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                     \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                   \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),            \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),          \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,  \
+          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+
 void reshape_and_cache_flash(
-    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
-    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
+    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
+    torch::Tensor&
+        value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype) {
-  // FIXME: only support auto datatype, does not support fp8
-  if (kv_cache_dtype != "auto") {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
+    const std::string& kv_cache_dtype, const double k_scale,
+    const double v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
-  int block_size = k_cache.size(1);
+  int block_size = key_cache.size(1);

   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
-  int block_stride = k_cache.stride(0);
-  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+  int block_stride = key_cache.stride(0);
+  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_flash", [&] {
-        vllm::reshape_and_cache_flash_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
-                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
-                value_stride, num_heads, head_size, block_size);
-      });
+
+  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
+                             CALL_RESHAPE_AND_CACHE_FLASH);
 }

 namespace vllm {
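The launcher no longer rejects non-"auto" kv_cache_dtype; it dispatches on key.dtype() and kv_cache_dtype and forwards the scales. The index math is unchanged: each token's entry in slot_mapping is split into a block index and an offset within the block. A pure-PyTorch reference sketch of that scatter, with quantization omitted (this mirrors the kernel's indexing, it is not vLLM's implementation):

    import torch

    def reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                                    slot_mapping):
        """Reference scatter: key/value are [num_tokens, num_heads, head_size],
        the caches are [num_blocks, block_size, num_heads, head_size]."""
        block_size = key_cache.shape[1]
        for token_idx, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:  # padded token: the kernel skips it (slot_idx == -1)
                continue
            block_idx, block_offset = divmod(slot, block_size)
            key_cache[block_idx, block_offset] = key[token_idx]
            value_cache[block_idx, block_offset] = value[token_idx]

    # Tiny example: two tokens written into slots 0 and 3 of block 0.
    key = torch.randn(2, 1, 4)
    value = torch.randn(2, 1, 4)
    key_cache = torch.zeros(1, 4, 1, 4)
    value_cache = torch.zeros(1, 4, 1, 4)
    reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                                torch.tensor([0, 3]))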
@@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " Tensor! key_cache,"
       " Tensor! value_cache,"
       " Tensor slot_mapping,"
-      " str kv_cache_dtype) -> ()");
+      " str kv_cache_dtype,"
+      " float k_scale, float v_scale) -> ()");
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);

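In the op schema, `float` arguments surface on the C++ side as `double`, which is why the declaration above takes `const double k_scale, const double v_scale`. To confirm the updated schema in a built environment, something like the following should work, assuming a CUDA build of vLLM whose extension registers `_C_cache_ops` (`.default._schema` is a PyTorch internal, used here only for inspection):

    import torch
    from vllm import _custom_ops  # noqa: F401  # importing this loads the extension

    # The printed schema should end with
    # "... str kv_cache_dtype, float k_scale, float v_scale) -> ()".
    print(torch.ops._C_cache_ops.reshape_and_cache_flash.default._schema)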
@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
-    if kv_cache_dtype == "fp8":
-        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -248,15 +246,33 @@ def test_reshape_and_cache_flash(
         dtype,
         device=device,
     )
-    key_cache, value_cache = key_caches[0], value_caches[0]
+    key_cache, value_cache = key_caches[0].contiguous(
+    ), value_caches[0].contiguous()
+    del key_caches
+    del value_caches

     # Clone the KV caches.
-    cloned_key_cache = key_cache.clone()
-    cloned_value_cache = value_cache.clone()
+    if kv_cache_dtype == "fp8":
+        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache)
+        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache)
+    else:
+        cloned_key_cache = key_cache.clone()
+        cloned_value_cache = value_cache.clone()
+
+    # Using default kv_scale
+    k_scale = v_scale = 1.0

     # Call the reshape_and_cache kernel.
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype)
+                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+
+    if kv_cache_dtype == "fp8":
+        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        ops.convert_fp8(result_key_cache, key_cache)
+        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        ops.convert_fp8(result_value_cache, value_cache)

     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
@@ -269,8 +285,18 @@ def test_reshape_and_cache_flash(
         cloned_key_cache[block_idx, block_offset, :, :] = key[i]
         cloned_value_cache[block_idx, block_offset, :, :] = value[i]

-    assert torch.allclose(key_cache, cloned_key_cache)
-    assert torch.allclose(value_cache, cloned_value_cache)
+    if kv_cache_dtype == "fp8":
+        assert torch.allclose(result_key_cache,
+                              cloned_key_cache,
+                              atol=0.001,
+                              rtol=0.1)
+        assert torch.allclose(result_value_cache,
+                              cloned_value_cache,
+                              atol=0.001,
+                              rtol=0.1)
+    else:
+        assert torch.allclose(key_cache, cloned_key_cache)
+        assert torch.allclose(value_cache, cloned_value_cache)


 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
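The fp8 branch compares after round-tripping both caches through ops.convert_fp8 and uses loose tolerances (atol=0.001, rtol=0.1) because fp8 e4m3 keeps only 3 mantissa bits, so stored values can be off by a few percent relative to the fp16 originals. A quick way to see that rounding error, assuming float8_e4m3fn casts are available:

    import torch

    x = torch.linspace(0.1, 2.0, steps=8)
    x8 = x.to(torch.float8_e4m3fn).to(torch.float32)  # round-trip through fp8 e4m3
    rel_err = ((x - x8).abs() / x.abs()).max()
    print(rel_err)  # typically a few percent, well inside rtol=0.1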
@@ -426,10 +426,13 @@ def reshape_and_cache_flash(
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
 ) -> None:
     torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                    value_cache, slot_mapping,
-                                                   kv_cache_dtype)
+                                                   kv_cache_dtype, k_scale,
+                                                   v_scale)


 def copy_blocks(key_caches: List[torch.Tensor],
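Callers of the Python wrapper must now pass both scales even when kv_cache_dtype is "auto", in which case the kernel ignores them. A minimal usage sketch, assuming a CUDA device and a vLLM build with these ops compiled:

    import torch
    from vllm import _custom_ops as ops

    num_tokens, num_heads, head_size = 16, 8, 128
    num_blocks, block_size = 4, 16

    key = torch.randn(num_tokens, num_heads, head_size,
                      dtype=torch.float16, device="cuda")
    value = torch.randn_like(key)
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                            dtype=torch.float16, device="cuda")
    value_cache = torch.zeros_like(key_cache)
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device="cuda")

    # "auto" keeps the model dtype; k_scale/v_scale are still required arguments.
    ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                                "auto", 1.0, 1.0)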
@@ -478,6 +478,8 @@ class FlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )

         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -489,6 +489,8 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )

         query = query.contiguous(
@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    assert cache_dtype != "fp8"
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
         key_value_cache = torch.empty(size=key_value_cache_shape,
                                       dtype=torch_dtype,
                                       device=device)
-        key_value_cache.uniform_(-scale, scale)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_value_cache, -scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_value_cache[:, 0])
         value_caches.append(key_value_cache[:, 1])
     return key_caches, value_caches
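_generate_random_fp8 is an existing vLLM helper whose internals are not part of this diff. A hypothetical stand-in that captures the idea, assuming float8_e4m3fn support and a torch.uint8-typed cache tensor: fill the byte cache with fp8-encoded values drawn from a uniform distribution.

    import torch

    def generate_random_fp8_sketch(cache: torch.Tensor, low: float, high: float) -> None:
        # Hypothetical stand-in for _generate_random_fp8: draw uniform float32
        # values, narrow them to fp8 e4m3, and store the raw bytes in the
        # uint8-typed cache tensor.
        tmp = torch.empty(cache.shape, dtype=torch.float32).uniform_(low, high)
        cache.copy_(tmp.to(torch.float8_e4m3fn).view(torch.uint8))

    cache = torch.empty(2, 16, 4, 8, dtype=torch.uint8)
    generate_random_fp8_sketch(cache, -1.0, 1.0)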