Move verify_marlin_supported to GPTQMarlinLinearMethod (#8165)
This commit is contained in:
parent
9da25a88aa
commit
2ee45281a5
@ -51,10 +51,6 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||||
|
|
||||||
# Verify supported on platform.
|
|
||||||
verify_marlin_supported(quant_type=self.quant_type,
|
|
||||||
group_size=self.group_size)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||||
f"group_size={self.group_size}, "
|
f"group_size={self.group_size}, "
|
||||||
@ -153,6 +149,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
|||||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
# Verify supported on platform.
|
||||||
|
verify_marlin_supported(quant_type=self.quant_config.quant_type,
|
||||||
|
group_size=self.quant_config.group_size)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user