Fix random state for dropout_layer_norm (#315)
This commit is contained in:
parent
d38357dd2f
commit
767b71ccf0
@ -229,11 +229,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
// Request the kernel launcher.
|
||||
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
||||
|
||||
// Query the kernel-specific launch parameters.
|
||||
launcher(launch_params, true);
|
||||
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
// Set the kernel runtime parameters.
|
||||
layer_norm::FwdParams &params = launch_params.params;
|
||||
params.rows = rows;
|
||||
@ -252,6 +247,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
|
||||
params.rowscale_const = rowscale_const;
|
||||
params.is_rms_norm = is_rms_norm;
|
||||
|
||||
// Query the kernel-specific launch parameters.
|
||||
launcher(launch_params, true);
|
||||
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
if (dropout_p > 0.f) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
@ -594,11 +594,6 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
|
||||
// Request the kernel launcher.
|
||||
auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
|
||||
|
||||
// Query the kernel-specific launch parameters.
|
||||
launcher(launch_params, true);
|
||||
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
// Set the kernel runtime parameters.
|
||||
layer_norm::FwdParams &params = launch_params.params;
|
||||
params.rows = rows;
|
||||
@ -621,6 +616,11 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
|
||||
params.inverse_cols = 1.f / float(params.cols);
|
||||
params.is_rms_norm = is_rms_norm;
|
||||
|
||||
// Query the kernel-specific launch parameters.
|
||||
launcher(launch_params, true);
|
||||
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
if (dropout_p > 0.f) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user