From f0eecee6106774e1e0f9b31c7438cde77654df52 Mon Sep 17 00:00:00 2001
From: Mor Zusman
Date: Mon, 20 May 2024 21:44:25 +0300
Subject: [PATCH] [Bugfix] Fix dummy weight for fp8 (#4916)

Allow dummy load format for fp8, torch.uniform_ doesn't support FP8 at
the moment

Co-authored-by: Mor Zusman
---
 vllm/model_executor/model_loader/weight_utils.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index c1abde9a..a1642baa 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -369,4 +369,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
-            param.data.uniform_(low, high)
+            if torch.finfo(param.data.dtype).bits < 16:
+                # uniform_ doesn't support < 16-bit datatypes (FP8)
+                dtype = param.data.dtype
+                tmp_param = param.data.to(torch.float16)
+                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                param.data.copy_(tmp_param)
+            else:
+                param.uniform_(low, high)