[Bugfix/Core] Flashinfer k_scale and v_scale (#9861)
parent aff1fd8188
commit 598b6d7b07
@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches
 
+    k_scale = key.amax().item() / 256
+    v_scale = value.amax().item() / 256
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+                        kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
 
-    # Using default kv_scale
-    k_scale = v_scale = 1.0
-
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache,
+                        key_cache,
+                        k_scale,
+                        kv_dtype=kv_cache_dtype)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache,
+                        value_cache,
+                        v_scale,
+                        kv_dtype=kv_cache_dtype)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
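The test now derives per-tensor scales from each tensor's amax and passes them, together with the cache dtype, when dequantizing the fp8 caches for comparison. The snippet below is a minimal pure-PyTorch sketch of that scale convention (amax / 256) and of the quantize/dequantize round trip the test relies on; it emulates the behaviour with torch's float8_e4m3fn cast and is not the vLLM ops.convert_fp8 kernel.

# Minimal sketch, plain PyTorch only: per-tensor fp8 scale (amax / 256)
# and the quantize/dequantize round trip. Not the vLLM CUDA kernel.
import torch

def fp8_round_trip(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize: divide by the per-tensor scale and cast to fp8 (e4m3).
    q = (x / scale).to(torch.float8_e4m3fn)
    # Dequantize: cast back to fp16 and multiply the scale back in.
    return q.to(torch.float16) * scale

key = torch.randn(16, 8, 64, dtype=torch.float16)
k_scale = key.amax().item() / 256          # same convention as the test above
recovered = fp8_round_trip(key, k_scale)
print((recovered - key).abs().max())       # small quantization error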
@@ -759,8 +759,6 @@ class FlashInferImpl(AttentionImpl):
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        assert k_scale == 1.0 and v_scale == 1.0, (
-            "key/v_scale is not supported in FlashInfer.")
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
@@ -874,7 +872,12 @@ def unified_flash_infer(
         assert prefill_meta is not None
         assert prefill_meta.prefill_wrapper is not None
         prefill_output = prefill_meta.prefill_wrapper.forward(
-            query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+            query,
+            kv_cache,
+            logits_soft_cap=logits_soft_cap,
+            causal=True,
+            k_scale=k_scale,
+            v_scale=v_scale)
     if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
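With the assert removed, unified_flash_infer forwards k_scale and v_scale to the FlashInfer prefill wrapper instead of rejecting non-default values. For intuition only, the toy PyTorch attention below shows what those scales mean for an fp8 KV cache: the cache stores K / k_scale and V / v_scale, so the scales are multiplied back in when the cache is read. This illustrates the semantics, not FlashInfer's fused kernel.

# Toy illustration of k_scale / v_scale semantics for an fp8 KV cache.
# Plain PyTorch; FlashInfer performs the dequantization inside its kernel.
import math
import torch

def attend_fp8_cache(q, k_fp8, v_fp8, k_scale, v_scale):
    # Dequantize the cached keys/values with their per-tensor scales.
    k = k_fp8.to(torch.float32) * k_scale
    v = v_fp8.to(torch.float32) * v_scale
    scores = (q.to(torch.float32) @ k.T) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v

head_dim, seq_len = 64, 32
q = torch.randn(1, head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)
k_scale = k.amax().item() / 256
v_scale = v.amax().item() / 256
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)
v_fp8 = (v / v_scale).to(torch.float8_e4m3fn)
out = attend_fp8_cache(q, k_fp8, v_fp8, k_scale, v_scale)  # (1, head_dim)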
@@ -141,8 +141,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        max_w_scale, weight = requantize_with_max_scale(
-            layer.weight, layer.weight_scale, layer.logical_widths)
+        weight = layer.weight
+        max_w_scale = layer.weight_scale.max()
+        if not (layer.weight_scale == layer.weight_scale[0]).all():
+            max_w_scale, weight = requantize_with_max_scale(
+                layer.weight, layer.weight_scale, layer.logical_widths)
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(),
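The ModelOptFp8LinearMethod change only requantizes when the per-shard weight scales actually differ; if they are all equal, their max is the shared scale and the weight can be reused as-is. A small standalone sketch of that check follows (the helper name needs_requantization is made up for illustration):

# Sketch of the "skip requantization when scales are uniform" check.
# `needs_requantization` is a hypothetical helper, not a vLLM function.
import torch

def needs_requantization(weight_scale: torch.Tensor) -> bool:
    # Requantize only if the logical shards carry different scales.
    return not bool((weight_scale == weight_scale[0]).all())

uniform = torch.tensor([0.02, 0.02, 0.02])
mixed = torch.tensor([0.02, 0.03, 0.02])
assert not needs_requantization(uniform)  # reuse the shared scale directly
assert needs_requantization(mixed)        # fall back to requantize_with_max_scale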