[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models (#6606)
parent 06d6c5fe9f
commit 9364f74eee
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
+model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.905
+  - name: "exact_match,flexible-extract"
+    value: 0.905
+limit: 1000
+num_fewshot: 5
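This new lm-eval-harness config records the GSM8K baseline for the FBGEMM-quantized Llama-3-70B checkpoint. A minimal sketch of how such a config could be checked against a measured run follows; the tolerance, helper name, and layout of the results dict are illustrative assumptions, not the repository's actual test code.

import yaml

# Hypothetical checker: compare measured lm-eval metrics against the
# ground-truth values recorded in a YAML config like the one added above.
RTOL = 0.05  # illustrative tolerance, not the project's setting

def check_results(config_path: str, measured: dict) -> None:
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for task in config["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[task["name"]][metric["name"]]
            assert abs(got - expected) <= RTOL, (
                f"{task['name']}/{metric['name']}: "
                f"got {got}, expected {expected}")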
@@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
@@ -9,9 +9,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -31,6 +34,12 @@ class FBGEMMFp8Config(QuantizationConfig):
         self.ignore_list = ignore_list
         self.input_scale_ub = input_scale_ub
 
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        capability = current_platform.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        self.use_marlin = capability < 89
+
     @classmethod
     def get_name(cls) -> str:
         return "fbgemm_fp8"
@@ -41,7 +50,7 @@ class FBGEMMFp8Config(QuantizationConfig):
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
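The two hunks above gate the new path on compute capability: devices below 8.9 have no native FP8 tensor cores, so the Marlin weight-only kernel is used there, and the config's minimum supported capability drops from 8.9 to 8.0 accordingly. A standalone illustration of the same arithmetic, using torch.cuda directly rather than vLLM's current_platform helper, is:

import torch

# Compute capability is packed as major * 10 + minor; anything below 89
# (Ada Lovelace / Hopper) lacks native FP8 support, so Marlin is used.
major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) on an A100
capability = major * 10 + minor                     # -> 80
use_marlin = capability < 89                        # True on A100, False on H100
print(f"capability={capability}, use_marlin={use_marlin}")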
@@ -143,11 +152,26 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale_ub
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
         return apply_fp8_linear(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
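On the Marlin path only the weight is FP8; activations stay in the model's dtype, which is why input_scale_ub is deleted in process_weights_after_loading. A rough pure-PyTorch reference for what the weight-only path computes, not the fused Marlin kernel, and with shape assumptions taken from the diff (weight stored transposed as (size_k, size_n), weight_scale per output channel with shape (size_n, 1)):

import torch

def fp8_weight_only_reference(x: torch.Tensor,
                              weight_fp8: torch.Tensor,
                              weight_scale: torch.Tensor,
                              bias: torch.Tensor = None) -> torch.Tensor:
    # Dequantize the FP8 weight to the activation dtype, then run a plain
    # matmul; activations are never quantized on this path.
    w = weight_fp8.to(x.dtype) * weight_scale.t().to(x.dtype)  # (size_k, size_n)
    out = x @ w                                                # (batch, size_n)
    return out if bias is None else out + bias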
@@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
     # WEIGHT SCALES
     # Currently Marlin doesn't support per-tensor scales, so we
     # expand it to channelwise
-    scales = layer.weight_scale.repeat(1, part_size_n).to(
-        layer.orig_dtype).to(device)
+    is_channelwise = layer.weight_scale.shape[0] == part_size_n
+    if is_channelwise:
+        scales = layer.weight_scale
+    else:
+        scales = layer.weight_scale.repeat(1, part_size_n)
+    scales = scales.to(layer.orig_dtype).to(device)
 
     # Permute scales
     marlin_scales = marlin_permute_scales(s=scales,
                                           size_k=part_size_k,
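fbgemm-fp8 checkpoints carry channelwise weight scales (hence the imported create_per_channel_scale_param), while other FP8 layers routed through this helper may carry a single per-tensor scale; the new branch expands only the per-tensor case before handing the scales to the Marlin permutation. A toy reproduction of just that branch, with part_size_n and the scale values made up:

import torch

part_size_n = 4  # illustrative number of output channels

per_tensor_scale = torch.tensor([[0.02]])                            # one scale for the whole weight
channelwise_scale = torch.tensor([[0.02], [0.03], [0.01], [0.05]])   # one scale per output channel

for weight_scale in (per_tensor_scale, channelwise_scale):
    # Same decision as in prepare_fp8_layer_for_marlin above.
    is_channelwise = weight_scale.shape[0] == part_size_n
    if is_channelwise:
        scales = weight_scale
    else:
        scales = weight_scale.repeat(1, part_size_n)
    print(f"channelwise={is_channelwise}, scales shape={tuple(scales.shape)}")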