Fix type typo in rmsnorm (#1119)
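Originally, the variable `h4` was declared as a `half4 *`, but the last two fields of the `half4` are never used. Based on the semantics and the surrounding context (`h1`, `h2`, and `h3` are all declared as `half2 *`), I believe it should be `half2 *` as well.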
This commit is contained in:
parent
7d8317a63e
commit
26986bbc60
@@ -98,7 +98,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
|
||||
half2 *h1 = (half2 *)&tmp.x;
|
||||
half2 *h2 = (half2 *)&tmp.y;
|
||||
half2 *h3 = (half2 *)&tmp.z;
|
||||
- half4 *h4 = (half4 *)&tmp.w;
|
||||
+ half2 *h4 = (half2 *)&tmp.w;
|
||||
|
||||
h1->x = half(static_cast<float>(l1->x) * s_mean * static_cast<float>(g1->x));
|
||||
h1->y = half(static_cast<float>(l1->y) * s_mean * static_cast<float>(g1->y));
|
||||
|
Loading…
Reference in New Issue
Block a user