[TPU] Call torch._sync(param) during weight loading (#9437)
This commit is contained in:
parent
5e443b594f
commit
8e1cddcd44
@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import seed_everything
|
from vllm.utils import seed_everything
|
||||||
|
|
||||||
|
|
||||||
@ -28,4 +29,25 @@ def set_weight_attrs(
|
|||||||
for key, value in weight_attrs.items():
|
for key, value in weight_attrs.items():
|
||||||
assert not hasattr(
|
assert not hasattr(
|
||||||
weight, key), (f"Overwriting existing tensor attribute: {key}")
|
weight, key), (f"Overwriting existing tensor attribute: {key}")
|
||||||
|
|
||||||
|
# NOTE(woosuk): During weight loading, we often do something like:
|
||||||
|
# narrowed_tensor = param.data.narrow(0, offset, len)
|
||||||
|
# narrowed_tensor.copy_(real_weight)
|
||||||
|
# expecting narrowed_tensor and param.data to share the same storage.
|
||||||
|
# However, on TPUs, narrowed_tensor will lazily propagate to the base
|
||||||
|
# tensor, which is param.data, leading to redundant memory usage.
|
||||||
|
# This sometimes causes OOM errors during model loading. To avoid this,
|
||||||
|
# we sync the param tensor after its weight loader is called.
|
||||||
|
# TODO(woosuk): Remove this hack once we have a better solution.
|
||||||
|
if current_platform.is_tpu() and key == "weight_loader":
|
||||||
|
value = _make_synced_weight_loader(value)
|
||||||
setattr(weight, key, value)
|
setattr(weight, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_synced_weight_loader(original_weight_loader):
|
||||||
|
|
||||||
|
def _synced_weight_loader(param, *args, **kwargs):
|
||||||
|
original_weight_loader(param, *args, **kwargs)
|
||||||
|
torch._sync(param)
|
||||||
|
|
||||||
|
return _synced_weight_loader
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user