Fix type typo in rmsnorm (#1119)

Initially the variable `h4` is `half4`, but its last two fields are not used. Based on the semantics and the context, I believe it should be `half2`.
This commit is contained in:
Lequn Chen 2023-10-02 17:40:04 -07:00 committed by GitHub
parent 7d8317a63e
commit 26986bbc60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -98,7 +98,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
half2 *h1 = (half2 *)&tmp.x;
half2 *h2 = (half2 *)&tmp.y;
half2 *h3 = (half2 *)&tmp.z;
half4 *h4 = (half4 *)&tmp.w;
half2 *h4 = (half2 *)&tmp.w;
h1->x = half(static_cast<float>(l1->x) * s_mean * static_cast<float>(g1->x));
h1->y = half(static_cast<float>(l1->y) * s_mean * static_cast<float>(g1->y));