[Kernel] Use CUTLASS kernels for the FP8 layers with Bias (#6270)
parent 94b82e8c18
commit c8fd97f26d
@@ -112,7 +112,7 @@ def apply_fp8_linear(
     # If dynamic, layer.input_scale is None and x_scale computed from x.
     # If static, layer.input_scale is scalar and x_scale is input_scale.
 
-    if bias is None and cutlass_fp8_supported:
+    if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
 
         # Fused GEMM_DQ
@@ -120,7 +120,8 @@ def apply_fp8_linear(
             weight,
             out_dtype=input.dtype,
             scale_a=x_scale,
-            scale_b=weight_scale)
+            scale_b=weight_scale,
+            bias=bias)
 
     else:
         qinput, x_scale = ops.scaled_fp8_quant(input,
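For context, a minimal sketch of how the CUTLASS branch of apply_fp8_linear reads once this patch is applied. Only the condition, the scaled_fp8_quant call, and the keyword arguments shown in the diff are confirmed; the ops.cutlass_scaled_mm call name and the output assignment are assumptions based on the commit title and the surrounding context lines.

    # Sketch (assumed surrounding code): CUTLASS path of apply_fp8_linear
    # after this change. The bias is now passed straight into the fused
    # kernel, so layers with a bias no longer fall back to the non-CUTLASS
    # branch. The `output =` assignment is an assumption, not from the diff.
    if cutlass_fp8_supported:
        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)

        # Fused GEMM_DQ: dequantization (and now the bias add) happen inside
        # the CUTLASS kernel instead of as separate ops.
        output = ops.cutlass_scaled_mm(qinput,
                                       weight,
                                       out_dtype=input.dtype,
                                       scale_a=x_scale,
                                       scale_b=weight_scale,
                                       bias=bias)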