[Bugfix/Core] Flashinfer k_scale and v_scale (#9861)
parent aff1fd8188
commit 598b6d7b07
@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches
 
+    k_scale = key.amax().item() / 256
+    v_scale = value.amax().item() / 256
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+                        kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
 
-    # Using default kv_scale
-    k_scale = v_scale = 1.0
-
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
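The test previously exercised the kernel only with the default k_scale = v_scale = 1.0; it now derives non-trivial per-tensor scales from the data (amax / 256) so the fp8 path is checked with real scaling. Below is a minimal plain-PyTorch sketch of the scaled fp8 round trip being tested, assuming torch >= 2.1 with float8_e4m3fn support; it stands in for vLLM's ops.convert_fp8 CUDA kernel and is not its implementation.

import torch

def fp8_round_trip(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize: divide by the per-tensor scale, then cast to fp8 (e4m3).
    q = (x / scale).to(torch.float8_e4m3fn)
    # Dequantize: cast back to fp16 and multiply the scale back in.
    return q.to(torch.float16) * scale

key = torch.randn(16, 8, 128, dtype=torch.float16)
k_scale = key.amax().item() / 256   # same heuristic the test now uses
recovered = fp8_round_trip(key, k_scale)
# The round-trip error is small but non-zero, which is why the test compares
# dequantized caches with a tolerance rather than exact equality.
print((recovered - key).abs().max())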
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache,
+                        key_cache,
+                        k_scale,
+                        kv_dtype=kv_cache_dtype)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache,
+                        value_cache,
+                        v_scale,
+                        kv_dtype=kv_cache_dtype)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
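For context on what the dequantized caches are compared against, here is a hedged pure-PyTorch sketch of a reference scatter into a paged KV cache with fp8 scaling. The helper name, argument list, and (num_blocks, block_size, num_heads, head_size) cache layout are illustrative assumptions, not vLLM's actual reference code.

import torch

def reference_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, block_size,
                                      k_scale, v_scale, fp8: bool):
    # Map each flat slot to (block index, offset inside the block).
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size
    for i in range(key.shape[0]):
        b, off = block_indices[i], block_offsets[i]
        if fp8:
            # Scale down before storing into the fp8 cache; ops.convert_fp8
            # above scales back up before the final comparison.
            key_cache[b, off] = (key[i] / k_scale).to(key_cache.dtype)
            value_cache[b, off] = (value[i] / v_scale).to(value_cache.dtype)
        else:
            key_cache[b, off] = key[i]
            value_cache[b, off] = value[i]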
@@ -759,8 +759,6 @@ class FlashInferImpl(AttentionImpl):
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        assert k_scale == 1.0 and v_scale == 1.0, (
-            "key/v_scale is not supported in FlashInfer.")
         if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
@@ -874,7 +872,12 @@ def unified_flash_infer(
         assert prefill_meta is not None
         assert prefill_meta.prefill_wrapper is not None
         prefill_output = prefill_meta.prefill_wrapper.forward(
-            query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+            query,
+            kv_cache,
+            logits_soft_cap=logits_soft_cap,
+            causal=True,
+            k_scale=k_scale,
+            v_scale=v_scale)
    if decode_meta := attn_metadata.decode_metadata:
        assert attn_metadata.decode_metadata is not None
        assert attn_metadata.decode_metadata.decode_wrapper is not None
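With the assert removed above, the k_scale/v_scale values flow through to FlashInfer's prefill (and decode) wrappers, which handle the fp8 dequantization inside the kernel. Numerically that is equivalent to the naive reference below; this is a sketch of the math only, with made-up shapes and the causal mask omitted, not FlashInfer's kernel.

import torch

def scaled_kv_attention(q, k_fp8, v_fp8, k_scale: float, v_scale: float):
    # Dequantize the cached keys/values with their per-tensor scales,
    # then run ordinary softmax attention in higher precision.
    k = k_fp8.to(torch.float16) * k_scale
    v = v_fp8.to(torch.float16) * v_scale
    scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores.float(), dim=-1).to(q.dtype) @ v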
@@ -141,8 +141,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        max_w_scale, weight = requantize_with_max_scale(
-            layer.weight, layer.weight_scale, layer.logical_widths)
+        weight = layer.weight
+        max_w_scale = layer.weight_scale.max()
+        if not (layer.weight_scale == layer.weight_scale[0]).all():
+            max_w_scale, weight = requantize_with_max_scale(
+                layer.weight, layer.weight_scale, layer.logical_widths)
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(),
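The new check skips the requantization pass when every shard already shares the same weight scale, in which case requantizing would be a no-op and the weight can be used as loaded. As a rough illustration of what requantize_with_max_scale conceptually does, here is a hedged sketch assuming the fused weight is partitioned row-wise by logical_widths; vLLM's actual helper lives in its fp8 utilities and differs in detail.

import torch

def requantize_to_max_scale(weight_fp8, weight_scale, logical_widths):
    # Dequantize each logical shard with its own scale, then requantize all
    # shards with the single largest scale so one per-tensor scale suffices.
    max_scale = weight_scale.max()
    out = torch.empty_like(weight_fp8)
    start = 0
    for width, scale in zip(logical_widths, weight_scale):
        end = start + width
        dequant = weight_fp8[start:end].to(torch.float32) * scale
        out[start:end] = (dequant / max_scale).to(weight_fp8.dtype)
        start = end
    return max_scale, out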