diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h b/tools/util/include/cutlass/util/device_rmsnorm.h index 5090efa0..a401db00 100644 --- a/tools/util/include/cutlass/util/device_rmsnorm.h +++ b/tools/util/include/cutlass/util/device_rmsnorm.h @@ -98,7 +98,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input, half2 *h1 = (half2 *)&tmp.x; half2 *h2 = (half2 *)&tmp.y; half2 *h3 = (half2 *)&tmp.z; - half4 *h4 = (half4 *)&tmp.w; + half2 *h4 = (half2 *)&tmp.w; h1->x = half(static_cast(l1->x) * s_mean * static_cast(g1->x)); h1->y = half(static_cast(l1->y) * s_mean * static_cast(g1->y));