// PyTorch also has an implementation of the Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
// Philox counter-based random number generator for CUDA device code.
namespace flash {
// Two unsigned 64-bit words stored back to back (x first, then y).
// The pair occupies 128 bits, the same size as a four-word 32-bit counter.
struct ull2 {
    unsigned long long x;  // first 64-bit word
    unsigned long long y;  // second 64-bit word
};
static_assert(sizeof(ull2) == 2 * sizeof(unsigned long long),
              "ull2 must be exactly two packed 64-bit words");
// Full 32x32 -> 64-bit unsigned multiply.
// Returns the product split into two 32-bit words:
//   .x = low 32 bits of a * b
//   .y = high 32 bits of a * b
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
    unsigned long long product;
    // mul.wide.u32 yields the complete 64-bit product of two 32-bit operands.
    asm ("mul.wide.u32 %0, %1, %2;\n\t"
          : "=l"(product)
          : "r"(a), "r"(b));
    uint2 halves;
    halves.x = static_cast<unsigned int>(product);
    halves.y = static_cast<unsigned int>(product >> 32);
    return halves;
}
// Apply one Philox-4x32 round to a 128-bit counter.
//   ctr - the four 32-bit counter words going into the round
//   key - the two 32-bit key words used for this round
// ctr.x and ctr.z are each multiplied by a fixed 32-bit constant via
// mulhilo32; the .y halves of those products are XORed with ctr.y / ctr.w and
// the key words, while the .x halves are passed through unchanged.
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
    constexpr unsigned int kMultiplierA = 0xD2511F53u;
    constexpr unsigned int kMultiplierB = 0xCD9E8D57u;
    const uint2 prod_a = mulhilo32(kMultiplierA, ctr.x);
    const uint2 prod_b = mulhilo32(kMultiplierB, ctr.z);
    uint4 next;
    next.x = prod_b.y ^ ctr.y ^ key.x;
    next.y = prod_b.x;
    next.z = prod_a.y ^ ctr.w ^ key.y;
    next.w = prod_a.x;
    return next;
}
// Stateless Philox-4x32 with 7 rounds: maps (seed, subsequence, offset) to
// four 32-bit words.
//   seed        - 64-bit key; low word becomes key.x, high word becomes key.y
//   subsequence - fills counter words z (low half) and w (high half)
//   offset      - fills counter words x (low half) and y (high half)
// Returns the counter after the seventh round.
inline __device__ uint4 philox(unsigned long long seed,
                               unsigned long long subsequence,
                               unsigned long long offset) {
    constexpr unsigned int kKeyStepX = 0x9E3779B9u;
    constexpr unsigned int kKeyStepY = 0xBB67AE85u;

    uint2 round_key;
    round_key.x = static_cast<unsigned int>(seed);
    round_key.y = static_cast<unsigned int>(seed >> 32);

    uint4 state;
    state.x = static_cast<unsigned int>(offset);
    state.y = static_cast<unsigned int>(offset >> 32);
    state.z = static_cast<unsigned int>(subsequence);
    state.w = static_cast<unsigned int>(subsequence >> 32);

    // Round 1 uses the seed-derived key as is; each of the remaining six rounds
    // first bumps both key words by their fixed increments.
    state = philox_single_round(state, round_key);
    #pragma unroll
    for (int round = 1; round < 7; ++round) {
        round_key.x += kKeyStepX;
        round_key.y += kKeyStepY;
        state = philox_single_round(state, round_key);
    }
    return state;
}
} // namespace flash
2022-05-21 05:21:58 +08:00
namespace {
// Stateful 7-round Philox-4x32 generator.
//
// The 128-bit counter is held as four 32-bit words: {x, y} start at
// offset / 4 (low word first) and {z, w} hold the subsequence (low word
// first). The 64-bit seed supplies the two 32-bit key words. Every call to
// operator() encrypts the current counter and then advances it by one, so
// successive calls and distinct subsequences produce distinct outputs.
class Philox {
public:
    // seed        - 64-bit key (low word -> key.x, high word -> key.y)
    // subsequence - selects the stream; stored in counter.z / counter.w
    // offset      - starting position; divided by 4 before being stored in
    //               counter.x / counter.y, since each counter value produces
    //               four 32-bit words
    __device__ inline Philox(unsigned long long seed,
                             unsigned long long subsequence,
                             unsigned long long offset)
        : key{static_cast<unsigned int>(seed),
              static_cast<unsigned int>(seed >> 32)} {
        const unsigned long long block_index = offset / 4;
        counter.x = static_cast<unsigned int>(block_index);
        counter.y = static_cast<unsigned int>(block_index >> 32);
        counter.z = static_cast<unsigned int>(subsequence);
        counter.w = static_cast<unsigned int>(subsequence >> 32);
    }

    // Returns four 32-bit random words for the current counter value and then
    // increments the counter.
    //
    // The previous body returned flash::philox(seed_, offset_, offset_), which
    // ignored the subsequence, ignored the counter set up by the constructor,
    // and returned the same words on every call.
    __device__ inline uint4 operator()() {
        uint4 state = counter;
        uint2 round_key = key;
        // 7-round Philox: six rounds each followed by a key bump, then a final
        // round with the last key.
        #pragma unroll
        for (int i = 0; i < 6; i++) {
            state = flash::philox_single_round(state, round_key);
            round_key.x += kPhilox10A;
            round_key.y += kPhilox10B;
        }
        const uint4 output = flash::philox_single_round(state, round_key);
        incr();
        return output;
    }

private:
    // 128-bit counter, least-significant word in x.
    uint4 counter;
    // Key words derived from the seed; never modified after construction.
    const uint2 key;

    // Adds the 64-bit value n to the 128-bit counter.
    // The low 64 bits are added in one step and a carry is propagated into
    // z and then w. (The earlier word-by-word version lost the carry into z
    // when the high word of n was 0xFFFFFFFF and the low-word add overflowed.)
    __device__ inline void incr_n(unsigned long long n) {
        const unsigned long long low =
            (static_cast<unsigned long long>(counter.y) << 32) | counter.x;
        const unsigned long long sum = low + n;
        counter.x = static_cast<unsigned int>(sum);
        counter.y = static_cast<unsigned int>(sum >> 32);
        if (sum >= n)   // no wrap-around in the low 64 bits
            return;
        if (++counter.z)
            return;
        ++counter.w;
    }

    // Returns ctr + 1 as a 128-bit integer, using the PTX carry chain
    // (add.cc / addc.cc / addc) across the four 32-bit words.
    __device__ uint4 incr128 (uint4 ctr)
    {
        uint4 res;
        asm ("add.cc.u32      %0, %4, %8;\n\t"
             "addc.cc.u32     %1, %5, %9;\n\t"
             "addc.cc.u32     %2, %6, %10;\n\t"
             "addc.u32        %3, %7, %11;\n\t"
             : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
             : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
               "n"(1), "n"(0), "n"(0), "n"(0));
        return res;
    }

    // Advances the counter by one.
    __device__ inline void incr() {
        counter = incr128(counter);
    }

    // Per-round increments applied to key.x and key.y respectively.
    static constexpr unsigned int kPhilox10A = 0x9E3779B9u;
    static constexpr unsigned int kPhilox10B = 0xBB67AE85u;
};
2022-10-19 14:04:01 +08:00
2022-05-21 05:21:58 +08:00
} // namespace