From 26986bbc604b7c395844f9fdbb73ba945dfd28e3 Mon Sep 17 00:00:00 2001
From: Lequn Chen
Date: Mon, 2 Oct 2023 17:40:04 -0700
Subject: [PATCH] Fix type typo in rmsnorm (#1119)

Initially the variable `h4` is `half4`, but its last two fields are not
used. Based on the semantics and the context, I believe it should be
`half2`.
---
 tools/util/include/cutlass/util/device_rmsnorm.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h b/tools/util/include/cutlass/util/device_rmsnorm.h
index 5090efa0..a401db00 100644
--- a/tools/util/include/cutlass/util/device_rmsnorm.h
+++ b/tools/util/include/cutlass/util/device_rmsnorm.h
@@ -98,7 +98,7 @@ __global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input,
     half2 *h1 = (half2 *)&tmp.x;
     half2 *h2 = (half2 *)&tmp.y;
     half2 *h3 = (half2 *)&tmp.z;
-    half4 *h4 = (half4 *)&tmp.w;
+    half2 *h4 = (half2 *)&tmp.w;
 
     h1->x = half(static_cast<float>(l1->x) * s_mean * static_cast<float>(g1->x));
     h1->y = half(static_cast<float>(l1->y) * s_mean * static_cast<float>(g1->y));