[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models (#6606)
parent 06d6c5fe9f
commit 9364f74eee
@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
+model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.905
+  - name: "exact_match,flexible-extract"
+    value: 0.905
+limit: 1000
+num_fewshot: 5
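The YAML above adds a GSM8K accuracy baseline for the FBGEMM-FP8 checkpoint (5-shot, 1000 samples, expected exact-match of 0.905 for both the strict and flexible extractors). As a rough illustration of how such a config can be consumed, here is a minimal sketch that loads the file and compares it against measured lm-eval results; the helper name, file path, and tolerance are assumptions for illustration and are not part of this commit:

import yaml

RTOL = 0.05  # assumed relative tolerance

def check_baseline(config_path: str, measured: dict) -> None:
    # Compare measured task metrics against the expected values in the config.
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[task["name"]][metric["name"]]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: got {got}, expected {expected}")

# Hypothetical usage with made-up measured numbers:
# check_baseline("Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml",
#                {"gsm8k": {"exact_match,strict-match": 0.91,
#                           "exact_match,flexible-extract": 0.90}})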
@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
@ -9,9 +9,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@ -31,6 +34,12 @@ class FBGEMMFp8Config(QuantizationConfig):
         self.ignore_list = ignore_list
         self.input_scale_ub = input_scale_ub
 
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        capability = current_platform.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        self.use_marlin = capability < 89
+
     @classmethod
     def get_name(cls) -> str:
         return "fbgemm_fp8"
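The block above packs the (major, minor) compute capability into a two-digit integer and enables the Marlin path whenever it is below 89, i.e. on GPUs without native FP8 tensor cores. A standalone sketch of that decision (the GPU names and capabilities are illustrative examples; only the < 89 threshold comes from this diff):

def pack_capability(major: int, minor: int) -> int:
    # Same packing as above: (8, 6) -> 86, (9, 0) -> 90.
    return major * 10 + minor

for name, (major, minor) in [("A100", (8, 0)), ("RTX 3090 / A10", (8, 6)),
                             ("L4 / RTX 4090", (8, 9)), ("H100", (9, 0))]:
    use_marlin = pack_capability(major, minor) < 89
    print(f"{name}: {'fp8-marlin (weight-only)' if use_marlin else 'native fp8 (W8A8)'}")

Because the Marlin fallback covers Ampere, the config's minimum supported capability can drop from 89 to 80, which is exactly the get_min_capability change in the next hunk.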
@ -41,7 +50,7 @@
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
@ -143,11 +152,26 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale_ub
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
         return apply_fp8_linear(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
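In apply above, the Marlin branch runs a weight-only FP8 GEMM: the FP8 weight is dequantized inside the kernel while activations stay in the model dtype, which is why input_scale_ub is deleted during weight processing. The W8A8 path via apply_fp8_linear is kept for GPUs with native FP8 support. The numerics of the weight-only path can be emulated in plain PyTorch as a rough sketch (requires a PyTorch build with float8_e4m3fn; this is not the Marlin kernel, which fuses the dequantization into the matmul):

import torch

def weight_only_fp8_emulated(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Per-output-channel scale so the largest weight maps to the FP8 max (448 for e4m3).
    scale = w.abs().amax(dim=1, keepdim=True).float() / 448.0
    w_fp8 = (w.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    # Dequantize back to the activation dtype and run a normal GEMM.
    w_dq = (w_fp8.to(torch.float32) * scale).to(x.dtype)
    return x @ w_dq.t()

x = torch.randn(4, 128, dtype=torch.float16)
w = torch.randn(256, 128, dtype=torch.float16)
out = weight_only_fp8_emulated(x, w)  # shape (4, 256), close to x @ w.t()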
@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
     # WEIGHT SCALES
     # Currently Marlin doesn't support per-tensor scales, so we
     # expand it to channelwise
-    scales = layer.weight_scale.repeat(1, part_size_n).to(
-        layer.orig_dtype).to(device)
+    is_channelwise = layer.weight_scale.shape[0] == part_size_n
+    if is_channelwise:
+        scales = layer.weight_scale
+    else:
+        scales = layer.weight_scale.repeat(1, part_size_n)
+    scales = scales.to(layer.orig_dtype).to(device)
+
     # Permute scales
     marlin_scales = marlin_permute_scales(s=scales,
                                           size_k=part_size_k,
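The change above lets prepare_fp8_layer_for_marlin accept both scale layouts: a scale tensor whose leading dimension already matches the partition's output size is treated as channelwise and passed through, while anything else (e.g. a per-tensor scale) is still broadcast to channelwise form with repeat, since Marlin has no per-tensor scale path. A shape-level sketch of the same logic with made-up sizes:

import torch

part_size_n = 4  # illustrative output size of the partition

per_tensor = torch.tensor([[0.02]])        # shape (1, 1): one scale for the whole weight
channelwise = torch.rand(part_size_n, 1)   # shape (4, 1): one scale per output channel

for weight_scale in (per_tensor, channelwise):
    is_channelwise = weight_scale.shape[0] == part_size_n
    if is_channelwise:
        scales = weight_scale
    else:
        scales = weight_scale.repeat(1, part_size_n)  # broadcast to one scale per channel
    print(tuple(weight_scale.shape), "->", tuple(scales.shape))
# (1, 1) -> (1, 4)
# (4, 1) -> (4, 1)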