[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models (#6606)

commit 9364f74eee (parent 06d6c5fe9f)
Author: Robert Shaw
Date: 2024-07-20 14:50:10 -04:00 (committed by GitHub)
4 changed files with 44 additions and 3 deletions

.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml (new file)

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.905
  - name: "exact_match,flexible-extract"
    value: 0.905
limit: 1000
num_fewshot: 5
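For context, a minimal sketch of how a baseline config like the one above can be compared against freshly measured accuracies. This is not the actual lm-eval-harness runner; the `check_results` helper, the inline copy of the YAML, and the 5% tolerance are assumptions for illustration only.

```python
import yaml

# Baseline copied from the config above (pinned GSM8K scores for the checkpoint).
BASELINE_YAML = """
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.905
  - name: "exact_match,flexible-extract"
    value: 0.905
limit: 1000
num_fewshot: 5
"""

def check_results(baseline: dict, measured: dict, rtol: float = 0.05) -> bool:
    """Return True if every measured metric is within rtol of its baseline.
    `measured` maps (task_name, metric_name) -> accuracy."""
    for task in baseline["tasks"]:
        for metric in task["metrics"]:
            got = measured[(task["name"], metric["name"])]
            if abs(got - metric["value"]) > rtol * metric["value"]:
                return False
    return True

# Hypothetical numbers from a fresh eval run of the quantized model.
measured = {
    ("gsm8k", "exact_match,strict-match"): 0.903,
    ("gsm8k", "exact_match,flexible-extract"): 0.907,
}
print(check_results(yaml.safe_load(BASELINE_YAML), measured))  # True
```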

.buildkite/lm-eval-harness/configs/models-large.txt

@@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml

vllm/model_executor/layers/quantization/fbgemm_fp8.py

@@ -9,9 +9,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                 UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -31,6 +34,12 @@ class FBGEMMFp8Config(QuantizationConfig):
         self.ignore_list = ignore_list
         self.input_scale_ub = input_scale_ub
 
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        capability = current_platform.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        self.use_marlin = capability < 89
+
     @classmethod
     def get_name(cls) -> str:
         return "fbgemm_fp8"
@@ -41,7 +50,7 @@
     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -143,11 +152,26 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale_ub
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
         return apply_fp8_linear(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
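As a rough illustration of this change, the gating that FBGEMMFp8Config now applies boils down to packing the GPU compute capability into an integer and comparing it against 89, the first generation with native FP8 tensor cores. The sketch below is standalone (the helper name and return strings are made up, not vLLM's API); the real code routes to apply_fp8_marlin_linear or apply_fp8_linear as shown in the diff above.

```python
def fp8_backend_for(capability: tuple) -> str:
    """Sketch of the use_marlin decision: pack (major, minor) into one integer
    (e.g. A100 (8, 0) -> 80) and fall back to the Marlin weight-only FP8 kernel
    below 89; Ada/Hopper keep the native FP8 GEMM path."""
    cap = capability[0] * 10 + capability[1]
    if cap < 89:
        return "fp8-marlin (weight-only, activations stay fp16/bf16)"
    return "native fp8 GEMM"

print(fp8_backend_for((8, 0)))  # A100 -> fp8-marlin (weight-only, ...)
print(fp8_backend_for((8, 9)))  # Ada  -> native fp8 GEMM
print(fp8_backend_for((9, 0)))  # H100 -> native fp8 GEMM
```

Lowering get_min_capability from 89 to 80 is what lets Ampere GPUs load fbgemm_fp8 checkpoints at all, and since the Marlin path never quantizes activations, layer.input_scale_ub has no consumer there and is deleted after weight loading.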

vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py

@@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
     # WEIGHT SCALES
     # Currently Marlin doesn't support per-tensor scales, so we
     # expand it to channelwise
-    scales = layer.weight_scale.repeat(1, part_size_n).to(
-        layer.orig_dtype).to(device)
+    is_channelwise = layer.weight_scale.shape[0] == part_size_n
+    if is_channelwise:
+        scales = layer.weight_scale
+    else:
+        scales = layer.weight_scale.repeat(1, part_size_n)
+    scales = scales.to(layer.orig_dtype).to(device)
+
     # Permute scales
     marlin_scales = marlin_permute_scales(s=scales,
                                           size_k=part_size_k,
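The scale handling this hunk generalizes can be sketched in isolation (shapes here are assumptions for illustration, and the real code goes on to call marlin_permute_scales): a per-tensor FP8 checkpoint carries a single scale that is broadcast to one scale per output channel, while an fbgemm-fp8 checkpoint is already channelwise and is passed through.

```python
import torch

def to_channelwise(weight_scale: torch.Tensor, part_size_n: int) -> torch.Tensor:
    """Return one scale per output channel, shape (1, N), for this sketch.
    Channelwise checkpoints already have N scales; per-tensor checkpoints
    repeat their single scale N times."""
    if weight_scale.numel() == part_size_n:
        return weight_scale.reshape(1, part_size_n)            # already channelwise
    return weight_scale.reshape(1, 1).repeat(1, part_size_n)   # per-tensor -> channelwise

per_tensor = torch.tensor([0.02])   # e.g. a static per-tensor fp8 scale
channelwise = torch.rand(4096)      # e.g. an fbgemm-fp8 per-channel scale
print(to_channelwise(per_tensor, 4096).shape)   # torch.Size([1, 4096])
print(to_channelwise(channelwise, 4096).shape)  # torch.Size([1, 4096])
```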