From 4ef95b0f0677f95d8181837bbeebca7fca5a2bb2 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Mon, 15 Jul 2024 19:14:49 +0200
Subject: [PATCH] [Bugfix] use float32 precision in samplers/test_logprobs.py
 for comparing with HF (#6409)

Signed-off-by: Thomas Parnell
---
 tests/samplers/test_logprobs.py      | 3 ++-
 vllm/attention/ops/prefix_prefill.py | 6 ++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index 02a953da..f7bcd4c8 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -11,7 +11,8 @@ MODELS = ["facebook/opt-125m"]
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype",
+                         ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 70b544b6..4577d84d 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -687,6 +687,12 @@ if triton.__version__ >= "2.1.0":
 
         cap = current_platform.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
+
+        # need to reduce num. blocks when using fp32
+        # due to increased use of GPU shared memory
+        if q.dtype is torch.float32:
+            BLOCK = BLOCK // 2
+
        # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
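
For context on the kernel-side change: the patch halves the Triton tile size (BLOCK) when the inputs are fp32, since fp32 tiles occupy twice the shared memory of fp16/bf16 tiles. The following is a minimal standalone sketch of that selection logic, not vLLM code; pick_block_size and major_cc are hypothetical names introduced here for illustration, and the shared-memory reasoning restates the comment in the diff.

import torch


def pick_block_size(dtype: torch.dtype, major_cc: int) -> int:
    # Illustrative only: mirrors the BLOCK selection in the patched kernel.
    # Compute capability 8.x (Ampere) and newer gets the wider 128 tile;
    # older GPUs fall back to 64.
    block = 128 if major_cc >= 8 else 64
    # fp32 elements take twice the shared memory of fp16/bf16 elements,
    # so halve the tile size to stay within the shared-memory budget.
    if dtype is torch.float32:
        block //= 2
    return block


# e.g. on a compute-capability-8.0 GPU:
assert pick_block_size(torch.float16, 8) == 128
assert pick_block_size(torch.float32, 8) == 64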