diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 4b9653de..aac86351 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -32,14 +32,17 @@ if TYPE_CHECKING:
 def _get_lora_device(base_layer: nn.Module) -> torch.device:
     # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
     """Returns the device for where to place the LoRA tensors."""
+    # UnquantizedLinear
     if hasattr(base_layer, "weight"):
         return base_layer.weight.device
-    if hasattr(base_layer, "linear_weights") and isinstance(
-            base_layer.linear_weights, dict):
-        values = list(base_layer.linear_weights.values())
-        if len(values) and isinstance(values[0], torch.Tensor):
-            return values[0].device
-    raise ValueError(f"Unsupported base layer: {base_layer}")
+    # GPTQ/AWQ/SqueezeLLM
+    elif hasattr(base_layer, "qweight"):
+        return base_layer.qweight.device
+    # Marlin
+    elif hasattr(base_layer, "B"):
+        return base_layer.B.device
+    else:
+        raise ValueError(f"Unsupported base layer: {base_layer}")


 def _apply_lora(
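
For reviewers, a minimal, self-contained sketch of the new attribute-based dispatch in action. `DummyGPTQLinear` is a hypothetical stand-in, not a real vLLM class; only the attribute name `qweight` matches what the corresponding branch above checks, and the function body is inlined from the patch for demonstration.

```python
import torch
import torch.nn as nn


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Same dispatch as the patch above, inlined for demonstration."""
    if hasattr(base_layer, "weight"):        # UnquantizedLinear
        return base_layer.weight.device
    elif hasattr(base_layer, "qweight"):     # GPTQ/AWQ/SqueezeLLM
        return base_layer.qweight.device
    elif hasattr(base_layer, "B"):           # Marlin
        return base_layer.B.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


class DummyGPTQLinear(nn.Module):
    """Hypothetical stand-in for a GPTQ layer: packed weights live in `qweight`."""

    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute; only the attribute name matters for the dispatch.
        self.qweight = torch.zeros(16, 16, dtype=torch.int32)


print(_get_lora_device(nn.Linear(4, 4)))    # cpu -- found via `weight`
print(_get_lora_device(DummyGPTQLinear()))  # cpu -- found via `qweight`

try:
    _get_lora_device(nn.ReLU())             # no weight/qweight/B attribute
except ValueError as e:
    print(e)                                # "Unsupported base layer: ReLU()"
```

Attribute probing keeps the helper decoupled from the concrete quantized-linear classes, at the cost of relying on attribute-naming conventions: a future layer that happens to define `B` for an unrelated purpose would be misdetected as Marlin.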