[Kernel] Use CUTLASS kernels for the FP8 layers with Bias (#6270)

Tyler Michael Smith, 2024-07-15 13:05:52 -04:00 (committed by GitHub)
parent 94b82e8c18
commit c8fd97f26d


@@ -112,7 +112,7 @@ def apply_fp8_linear(
     #   If dynamic, layer.input_scale is None and x_scale computed from x.
     #   If static, layer.input_scale is scalar and x_scale is input_scale.
 
-    if bias is None and cutlass_fp8_supported:
+    if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
 
         # Fused GEMM_DQ
@@ -120,7 +120,8 @@ def apply_fp8_linear(
                                        weight,
                                        out_dtype=input.dtype,
                                        scale_a=x_scale,
-                                       scale_b=weight_scale)
+                                       scale_b=weight_scale,
+                                       bias=bias)
 
     else:
         qinput, x_scale = ops.scaled_fp8_quant(input,
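
For context: the effect of the patch is that the CUTLASS path now serves FP8 layers both with and without bias, since ops.cutlass_scaled_mm fuses the optional bias add into the GEMM epilogue; previously, biased layers fell through to the unfused fallback branch. Below is a minimal sketch of the patched control flow. The import and the simplified parameter list are assumptions based on the surrounding file, and the fallback branch, which the diff truncates, is elided:

from vllm import _custom_ops as ops  # assumed import, as in w8a8_utils.py

def apply_fp8_linear(input, weight, weight_scale, input_scale=None,
                     bias=None, cutlass_fp8_supported=True):
    # Sketch only: the real function's signature differs.
    if cutlass_fp8_supported:
        # Quantize activations to FP8: dynamic scale if input_scale is
        # None, static (per-tensor) scale otherwise.
        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
        # Fused GEMM_DQ: one CUTLASS kernel performs the FP8 matmul,
        # dequantization, and the optional bias add. bias=None is
        # accepted, so the old "bias is None" guard is no longer needed.
        return ops.cutlass_scaled_mm(qinput,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale,
                                     bias=bias)
    # Non-CUTLASS fallback path omitted; see the rest of the file.
    ...

The upshot is that, on hardware where CUTLASS FP8 is supported, linear layers with a bias term no longer lose the fused-kernel fast path.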