[Bugfix] Fix LoRA bug (#4032)
parent d04973ad54
commit b8aacac31a
@@ -32,14 +32,17 @@ if TYPE_CHECKING:
 def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
     """Returns the device for where to place the LoRA tensors."""
+    # unquantizedLinear
     if hasattr(base_layer, "weight"):
         return base_layer.weight.device
-    if hasattr(base_layer, "linear_weights") and isinstance(
-            base_layer.linear_weights, dict):
-        values = list(base_layer.linear_weights.values())
-        if len(values) and isinstance(values[0], torch.Tensor):
-            return values[0].device
-    raise ValueError(f"Unsupported base layer: {base_layer}")
+    # GPTQ/AWQ/SqueezeLLM
+    elif hasattr(base_layer, "qweight"):
+        return base_layer.qweight.device
+    # marlin
+    elif hasattr(base_layer, "B"):
+        return base_layer.B.device
+    else:
+        raise ValueError(f"Unsupported base layer: {base_layer}")
 
 
 def _apply_lora(
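For context, the change stops probing the old linear_weights dict and instead dispatches on whichever weight attribute the base layer actually exposes: weight for unquantized linears, qweight for GPTQ/AWQ/SqueezeLLM layers, and B for Marlin. The following is a minimal, self-contained sketch of that dispatch; FakeQuantLinear is a hypothetical stand-in for a quantized layer and is not part of vLLM.

import torch
import torch.nn as nn

def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Return the device on which the LoRA tensors should be placed."""
    # Unquantized linear layers keep their parameters in .weight.
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # GPTQ/AWQ/SqueezeLLM layers store their packed weights in .qweight.
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # Marlin layers store the packed weight matrix in .B.
    elif hasattr(base_layer, "B"):
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")

class FakeQuantLinear(nn.Module):
    # Hypothetical stand-in for a quantized layer (not a vLLM class):
    # it has no .weight, only a packed integer .qweight buffer.
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("qweight", torch.zeros(8, 8, dtype=torch.int32))

print(_get_lora_device(nn.Linear(4, 4)))    # device of .weight, e.g. cpu
print(_get_lora_device(FakeQuantLinear()))  # device of .qweight, e.g. cpu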