add tqdm when loading checkpoint shards (#6569)
Co-authored-by: tianyi.zhao <tianyi.zhao@transwarp.io>
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
parent
7c2749a4fd
commit
e519ae097a
@ -331,7 +331,8 @@ def np_cache_weights_iterator(
|
|||||||
with get_lock(model_name_or_path, cache_dir):
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
if not os.path.exists(weight_names_file):
|
if not os.path.exists(weight_names_file):
|
||||||
weight_names: List[str] = []
|
weight_names: List[str] = []
|
||||||
for bin_file in hf_weights_files:
|
for bin_file in tqdm(hf_weights_files,
|
||||||
|
desc="Loading np_cache checkpoint shards"):
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
param_path = os.path.join(np_folder, name)
|
param_path = os.path.join(np_folder, name)
|
||||||
@ -355,7 +356,8 @@ def safetensors_weights_iterator(
|
|||||||
hf_weights_files: List[str]
|
hf_weights_files: List[str]
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
"""Iterate over the weights in the model safetensor files."""
|
"""Iterate over the weights in the model safetensor files."""
|
||||||
for st_file in hf_weights_files:
|
for st_file in tqdm(hf_weights_files,
|
||||||
|
desc="Loading safetensors checkpoint shards"):
|
||||||
with safe_open(st_file, framework="pt") as f:
|
with safe_open(st_file, framework="pt") as f:
|
||||||
for name in f.keys(): # noqa: SIM118
|
for name in f.keys(): # noqa: SIM118
|
||||||
param = f.get_tensor(name)
|
param = f.get_tensor(name)
|
||||||
def pt_weights_iterator(
    hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files.

    Args:
        hf_weights_files: Paths to the HuggingFace ``*.bin`` / ``*.pt``
            checkpoint shard files to read, in the order they should be
            consumed.

    Yields:
        ``(name, tensor)`` pairs for every parameter stored in every
        shard file, loaded onto the CPU.
    """
    # Wrap the shard list in tqdm so users see progress while large
    # checkpoint files are read from disk (matches the progress bars in
    # the other weight iterators in this module).
    for bin_file in tqdm(hf_weights_files,
                         desc="Loading pt checkpoint shards"):
        # map_location="cpu" keeps loading off the GPU; placement onto
        # devices happens later in the load path.
        state = torch.load(bin_file, map_location="cpu")
        for name, param in state.items():
            yield name, param
|
|||||||
Loading…
Reference in New Issue
Block a user