Fix random state for dropout_layer_norm (#315)

This commit is contained in:
Joel Lamy-Poirier 2023-07-23 18:05:13 -04:00 committed by GitHub
parent d38357dd2f
commit 767b71ccf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -229,11 +229,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
@@ -252,6 +247,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
params.rowscale_const = rowscale_const;
params.is_rms_norm = is_rms_norm;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
if (dropout_p > 0.f) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
@@ -594,11 +594,6 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
// Request the kernel launcher.
auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
@@ -621,6 +616,11 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
params.inverse_cols = 1.f / float(params.cols);
params.is_rms_norm = is_rms_norm;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
if (dropout_p > 0.f) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state