[Bugfix] Fix LoRA bug (#4032)
This commit is contained in:
parent
d04973ad54
commit
b8aacac31a
@ -32,14 +32,17 @@ if TYPE_CHECKING:
|
|||||||
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
||||||
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
||||||
"""Returns the device for where to place the LoRA tensors."""
|
"""Returns the device for where to place the LoRA tensors."""
|
||||||
|
# unquantizedLinear
|
||||||
if hasattr(base_layer, "weight"):
|
if hasattr(base_layer, "weight"):
|
||||||
return base_layer.weight.device
|
return base_layer.weight.device
|
||||||
if hasattr(base_layer, "linear_weights") and isinstance(
|
# GPTQ/AWQ/SqueezeLLM
|
||||||
base_layer.linear_weights, dict):
|
elif hasattr(base_layer, "qweight"):
|
||||||
values = list(base_layer.linear_weights.values())
|
return base_layer.qweight.device
|
||||||
if len(values) and isinstance(values[0], torch.Tensor):
|
# marlin
|
||||||
return values[0].device
|
elif hasattr(base_layer, "B"):
|
||||||
raise ValueError(f"Unsupported base layer: {base_layer}")
|
return base_layer.B.device
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported base layer: {base_layer}")
|
||||||
|
|
||||||
|
|
||||||
def _apply_lora(
|
def _apply_lora(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user