From 767b71ccf0664ea382135f039212f087afc4c682 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sun, 23 Jul 2023 18:05:13 -0400 Subject: [PATCH] Fix random state for dropout_layer_norm (#315) --- csrc/layer_norm/ln_api.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/layer_norm/ln_api.cpp b/csrc/layer_norm/ln_api.cpp index abb2e07..3981bba 100644 --- a/csrc/layer_norm/ln_api.cpp +++ b/csrc/layer_norm/ln_api.cpp @@ -229,11 +229,6 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: // Request the kernel launcher. auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - at::Tensor workspace, barrier; - // Set the kernel runtime parameters. layer_norm::FwdParams &params = launch_params.params; params.rows = rows; @@ -252,6 +247,11 @@ std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: params.rowscale_const = rowscale_const; params.is_rms_norm = is_rms_norm; + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + + at::Tensor workspace, barrier; + if (dropout_p > 0.f) { // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -594,11 +594,6 @@ std::vector dropout_add_ln_parallel_residual_fwd( // Request the kernel launcher. auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - at::Tensor workspace, barrier; - // Set the kernel runtime parameters. layer_norm::FwdParams &params = launch_params.params; params.rows = rows; @@ -621,6 +616,11 @@ std::vector dropout_add_ln_parallel_residual_fwd( params.inverse_cols = 1.f / float(params.cols); params.is_rms_norm = is_rms_norm; + // Query the kernel-specific launch parameters. 
+ launcher(launch_params, true); + + at::Tensor workspace, barrier; + if (dropout_p > 0.f) { // number of times random will be generated per thread, to offset philox counter in thc random // state