// torch_ext/csrc/softmax.cu

#include <cuda_runtime.h>
#include <cmath>
#include <iostream>
#include <vector>
// Kernel function to compute the per-block maximum of the array
__global__ void findMax(const float *input, float *maxValue, int size)
{
    extern __shared__ float sharedMax[];
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    // Out-of-range threads contribute -INFINITY, the identity for max.
    sharedMax[tid] = (i < size) ? input[i] : -INFINITY;
    __syncthreads();
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sharedMax[tid] = fmaxf(sharedMax[tid], sharedMax[tid + s]);
        }
        __syncthreads();
    }
    if (tid == 0)
    {
        maxValue[blockIdx.x] = sharedMax[0];
    }
}
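// (The reduction loop above halves the number of active threads each step,
//  e.g. 256 -> 128 -> ... -> 1, leaving the block's result in slot 0. This
//  assumes blockDim.x is a power of two, which holds for the 256 threads per
//  block used below; computeSum reuses the same tree-reduction pattern.)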
// Kernel function to compute the exponential of the input values minus the max value
__global__ void computeExp(const float *input, float maxValue, float *output, int size)
{
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < size)
    {
        output[i] = expf(input[i] - maxValue);
    }
}
// Placeholder for a fused single-block softmax kernel (currently unused).
__global__ void block_softmax(const float *input)
{
    // Not implemented yet.
}
// Kernel function to compute the per-block sum of the exponential values
__global__ void computeSum(const float *expValues, float *sumValue, int size)
{
    extern __shared__ float sharedSum[];
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    // Out-of-range threads contribute 0, the identity for addition.
    sharedSum[tid] = (i < size) ? expValues[i] : 0.0f;
    __syncthreads();
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sharedSum[tid] += sharedSum[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0)
    {
        sumValue[blockIdx.x] = sharedSum[0];
    }
}
// Kernel function to compute the final softmax values
__global__ void computeSoftmax(float *output, float sumValue, int size)
{
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < size)
    {
        output[i] /= sumValue;
    }
}
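// Taken together, the kernels above implement the numerically stable softmax:
//
//   softmax(x_i) = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
//
// Subtracting max(x) does not change the result (the common factor
// exp(-max(x)) cancels in the quotient) but keeps expf() from overflowing
// for large inputs.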
void softmax(const float *input, float *output, int size)
{
    const int threadsPerBlock = 256;
    const int blocksPerGrid = (size + threadsPerBlock - 1) / threadsPerBlock;

    float *d_input;
    float *d_output;
    float *d_maxValue;
    float *d_sumValue;
    // blocksPerGrid is not a compile-time constant, so use std::vector
    // instead of a (non-standard in C++) variable-length array.
    std::vector<float> h_maxValue(blocksPerGrid);
    std::vector<float> h_sumValue(blocksPerGrid);

    cudaMalloc((void **)&d_input, size * sizeof(float));
    cudaMalloc((void **)&d_output, size * sizeof(float));
    cudaMalloc((void **)&d_maxValue, blocksPerGrid * sizeof(float));
    cudaMalloc((void **)&d_sumValue, blocksPerGrid * sizeof(float));
    cudaMemcpy(d_input, input, size * sizeof(float), cudaMemcpyHostToDevice);

    // Pass 1: per-block maxima on the device, final max reduction on the host.
    findMax<<<blocksPerGrid, threadsPerBlock, threadsPerBlock * sizeof(float)>>>(d_input, d_maxValue, size);
    cudaMemcpy(h_maxValue.data(), d_maxValue, blocksPerGrid * sizeof(float), cudaMemcpyDeviceToHost);
    float maxValue = -INFINITY;
    for (int i = 0; i < blocksPerGrid; ++i)
    {
        maxValue = fmaxf(maxValue, h_maxValue[i]);
    }

    // Pass 2: exponentiate with the max subtracted for numerical stability.
    computeExp<<<blocksPerGrid, threadsPerBlock>>>(d_input, maxValue, d_output, size);

    // Pass 3: per-block partial sums on the device, final sum on the host.
    computeSum<<<blocksPerGrid, threadsPerBlock, threadsPerBlock * sizeof(float)>>>(d_output, d_sumValue, size);
    cudaMemcpy(h_sumValue.data(), d_sumValue, blocksPerGrid * sizeof(float), cudaMemcpyDeviceToHost);
    float sumValue = 0.0f;
    for (int i = 0; i < blocksPerGrid; ++i)
    {
        sumValue += h_sumValue[i];
    }

    // Pass 4: normalize by the global sum.
    computeSoftmax<<<blocksPerGrid, threadsPerBlock>>>(d_output, sumValue, size);
    cudaMemcpy(output, d_output, size * sizeof(float), cudaMemcpyDeviceToHost);

    cudaFree(d_input);
    cudaFree(d_output);
    cudaFree(d_maxValue);
    cudaFree(d_sumValue);
}
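// Note: the CUDA calls above are left unchecked for brevity. A minimal
// sketch of the usual error-handling pattern, assuming you want to fail
// fast on the first error:
//
//   cudaError_t err = cudaGetLastError(); // e.g. after a kernel launch
//   if (err != cudaSuccess)
//   {
//       std::cerr << cudaGetErrorString(err) << std::endl;
//       std::exit(EXIT_FAILURE);
//   }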
int main()
{
    const int size = 1024;
    float input[size];
    float output[size];
    // Initialize the input array with the values 0, 1, ..., size - 1
    for (int i = 0; i < size; ++i)
    {
        input[i] = static_cast<float>(i);
    }
    softmax(input, output, size);
    // Print the output values. With this input, entries more than roughly
    // 100 below the maximum underflow to 0 in float, so only the last few
    // values are visibly nonzero.
    for (int i = 0; i < size; ++i)
    {
        std::cout << output[i] << " ";
    }
    std::cout << std::endl;
    return 0;
}
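// Example build and run (assumed; pick -arch to match your GPU):
//   nvcc -O2 -arch=sm_70 -o softmax softmax.cu
//   ./softmax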