[Bugfix] use float32 precision in samplers/test_logprobs.py for comparing with HF (#6409)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
commit 4ef95b0f06
parent eaec4b9153
@@ -11,7 +11,8 @@ MODELS = ["facebook/opt-125m"]
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype",
+                         ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
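Note (not part of the diff): a minimal sketch of why the test pins dtype to "float". Per-token logprobs come out of a log-softmax over the full vocabulary, and the rounding introduced by half precision is easily larger than the tight tolerance used when checking against a HuggingFace reference computed in float32. The vocabulary size and the way half precision is emulated below (rounding the logits through float16) are assumptions for illustration, not code from samplers/test_logprobs.py.

# Illustrative sketch only -- not code from the test file.
import torch

torch.manual_seed(0)
logits = torch.randn(32000) * 10  # vocab-sized logits, arbitrary scale

ref = torch.log_softmax(logits.double(), dim=-1)                   # high-precision reference
fp32 = torch.log_softmax(logits, dim=-1).double()                  # float32 pipeline
fp16 = torch.log_softmax(logits.half().float(), dim=-1).double()   # logits rounded to fp16 first

print("max |fp32 - ref|:", (fp32 - ref).abs().max().item())
print("max |fp16 - ref|:", (fp16 - ref).abs().max().item())
# The fp16-rounded logprobs typically drift from the reference by orders of
# magnitude more than the float32 ones, so comparison against HF is only
# stable when both sides run in float32.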
@@ -687,6 +687,12 @@ if triton.__version__ >= "2.1.0":
 
         cap = current_platform.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
+
+        # need to reduce num. blocks when using fp32
+        # due to increased use of GPU shared memory
+        if q.dtype is torch.float32:
+            BLOCK = BLOCK // 2
+
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
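Note on the second hunk (again a sketch, not part of the diff): the Triton kernel stages per-block tiles in GPU shared memory, so moving from 2-byte fp16 elements to 4-byte fp32 roughly doubles the footprint at a given BLOCK, and halving BLOCK brings it back near the original budget. The tile count and head size below are assumed values for illustration, not read from the kernel.

# Back-of-the-envelope illustration only; num_tiles and head_dim are assumptions.
def tile_bytes(block: int, head_dim: int, itemsize: int, num_tiles: int = 3) -> int:
    # e.g. one (block, head_dim) tile each for q, k and v staged per program
    return num_tiles * block * head_dim * itemsize

head_dim = 128
print(f"fp16, BLOCK=128: {tile_bytes(128, head_dim, 2) / 1024:.0f} KiB")  # baseline
print(f"fp32, BLOCK=128: {tile_bytes(128, head_dim, 4) / 1024:.0f} KiB")  # ~2x the baseline
print(f"fp32, BLOCK=64:  {tile_bytes(64, head_dim, 4) / 1024:.0f} KiB")   # back near the baseline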