torch_ext/csrc/random_env.cu
2024-12-14 13:34:30 +08:00

35 lines
1.0 KiB
Plaintext

#include <cstdio>

#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <cub/cub.cuh>
#include <cub/util_device.cuh>
// initRandom: seed one curandState per launched thread.
//
// Every thread uses the same `seed`, its flat global thread index as the
// subsequence number, and offset 0. This gives each thread its own cuRAND
// subsequence.
// Precondition: `state` points to device memory with room for
// gridDim.x * blockDim.x states. There is no bounds check, so the launch
// must not contain more threads than that.
__global__ void initRandom(curandState *state, unsigned long seed)
{
    const int tid = blockDim.x * blockIdx.x + threadIdx.x;
    curandState *slot = state + tid;
    curand_init(seed, tid, 0, slot);
}
// random_generate: advance each thread's curand state 1024 times. Each step
// mixes a new uniform sample with the block-wide sum, and the final value is
// written to `out`.
//
// Launch contract:
//   * blockDim.x must be exactly 1024 (1-D block). cub::BlockReduce below is
//     instantiated for 1024 threads.
//   * `state` holds gridDim.x * blockDim.x states already seeded (e.g. by
//     initRandom). Each thread reads, and writes back, state[global id].
//   * `out` is device memory with room for gridDim.x * blockDim.x floats.
// Shared memory (static): BlockReduce::TempStorage plus one float.
__global__ void random_generate(float *out, curandState *state)
{
    typedef cub::BlockReduce<float, 1024> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    // BlockReduce::Sum yields the aggregate only in thread 0. This slot
    // broadcasts it to the rest of the block.
    __shared__ float block_sum;

    const int id = threadIdx.x + blockIdx.x * blockDim.x;
    curandState localState = state[id];

    // Each thread only touches its own running value, so it lives in a
    // register and starts from a defined 0.
    float value = 0.0f;
    for (int i = 0; i < 1024; i++)
    {
        value += curand_uniform(&localState);
        const float aggregate = BlockReduce(temp_storage).Sum(value);
        if (threadIdx.x == 0)
            block_sum = aggregate;
        // Publishes block_sum to all threads and makes temp_storage safe to
        // reuse in the next iteration.
        __syncthreads();
        // NOTE(review): keeps the original `+=` update. If the intent was to
        // normalise by the block sum, this should be `value = value / block_sum`.
        value += value / block_sum;
        // Every thread must have read block_sum before thread 0 overwrites it
        // on the next iteration.
        __syncthreads();
    }

    out[id] = value;
    // Persist the advanced generator so a later launch continues the stream
    // instead of replaying the same numbers.
    state[id] = localState;
}
// Reports a failed CUDA runtime call on stderr. Returns true on success.
static bool randomCudaOk(cudaError_t err, const char *what)
{
    if (err == cudaSuccess)
        return true;
    fprintf(stderr, "random_invoke: %s failed: %s\n", what, cudaGetErrorString(err));
    return false;
}

// random_invoke: seed 1024 generators with seed 1234, run random_generate on
// them in a single 1024-thread block, and copy the results back to the host.
//
// All device memory is allocated and freed here. Any failing step is reported
// on stderr and the remaining steps are skipped.
// NOTE(review): the host copy in `out` is not returned or used yet. Callers
// currently only observe side effects and stderr.
void random_invoke()
{
    // Must stay 1024: random_generate's cub::BlockReduce assumes 1024 threads.
    const int thread_num = 1024;
    curandState *devStates = nullptr;
    float *devOut = nullptr;
    float out[thread_num];

    // Kernels cannot write into host stack memory, so both buffers live on the device.
    bool ok = randomCudaOk(cudaMalloc(&devStates, thread_num * sizeof(curandState)),
                           "cudaMalloc(devStates)")
           && randomCudaOk(cudaMalloc(&devOut, thread_num * sizeof(float)),
                           "cudaMalloc(devOut)");
    if (ok)
    {
        initRandom<<<1, thread_num>>>(devStates, 1234);
        ok = randomCudaOk(cudaGetLastError(), "initRandom launch");
    }
    if (ok)
    {
        random_generate<<<1, thread_num>>>(devOut, devStates);
        ok = randomCudaOk(cudaGetLastError(), "random_generate launch");
    }
    if (ok)
    {
        // The blocking copy on the default stream waits for both kernels and
        // also surfaces any fault raised while they executed.
        randomCudaOk(cudaMemcpy(out, devOut, sizeof(out), cudaMemcpyDeviceToHost),
                     "cudaMemcpy(out)");
    }

    // cudaFree(nullptr) is a no-op, so this is safe on every early-exit path.
    cudaFree(devOut);
    cudaFree(devStates);
}