[Bugfix] Fix LoRA bug (#4032)

Jee Li 2024-04-13 07:56:37 +08:00 committed by GitHub
parent d04973ad54
commit b8aacac31a


@@ -32,14 +32,17 @@ if TYPE_CHECKING:
 def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
     """Returns the device for where to place the LoRA tensors."""
+    # unquantizedLinear
     if hasattr(base_layer, "weight"):
         return base_layer.weight.device
-    if hasattr(base_layer, "linear_weights") and isinstance(
-            base_layer.linear_weights, dict):
-        values = list(base_layer.linear_weights.values())
-        if len(values) and isinstance(values[0], torch.Tensor):
-            return values[0].device
-    raise ValueError(f"Unsupported base layer: {base_layer}")
+    # GPTQ/AWQ/SqueezeLLM
+    elif hasattr(base_layer, "qweight"):
+        return base_layer.qweight.device
+    # marlin
+    elif hasattr(base_layer, "B"):
+        return base_layer.B.device
+    else:
+        raise ValueError(f"Unsupported base layer: {base_layer}")
 
 
 def _apply_lora(
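
The old lookup expected a `linear_weights` dict that quantized layers no longer expose, so placing LoRA tensors on a quantized model fell through to the `ValueError`. The fix dispatches on the attribute each layer type actually registers: `weight` for unquantized linears, `qweight` for GPTQ/AWQ/SqueezeLLM, and `B` for Marlin. A minimal sketch of that dispatch follows; the `GPTQStub` and `MarlinStub` classes are hypothetical stand-ins for illustration, not vLLM classes:

import torch
import torch.nn as nn


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Returns the device for where to place the LoRA tensors."""
    # unquantized linear layers expose a plain .weight tensor
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM register packed weights as .qweight
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # marlin registers its packed weight matrix as .B
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


class GPTQStub(nn.Module):
    """Hypothetical stand-in for a GPTQ-style quantized linear layer."""

    def __init__(self):
        super().__init__()
        # packed integer weights instead of a float .weight
        self.qweight = torch.zeros(8, 8, dtype=torch.int32)


class MarlinStub(nn.Module):
    """Hypothetical stand-in for a Marlin-style quantized linear layer."""

    def __init__(self):
        super().__init__()
        self.B = torch.zeros(8, 8, dtype=torch.int32)


print(_get_lora_device(nn.Linear(8, 8)))  # cpu, resolved via .weight
print(_get_lora_device(GPTQStub()))       # cpu, resolved via .qweight
print(_get_lora_device(MarlinStub()))     # cpu, resolved via .B

With the pre-fix version, the two stub layers above would have raised `ValueError` because they have neither `weight` nor a `linear_weights` dict.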