Minor tweaks.

long0x0 2025-01-04 13:47:42 +08:00
parent 80d7be70a5
commit 920ebe0f88
5 changed files with 62 additions and 3 deletions

csrc/attention.cu (new file, 29 lines added)

@@ -0,0 +1,29 @@
#include "core.h"
// Compute the dot product of two length-`dim` vectors, e.g. a query row and a key row.
template <typename scalar_t>
__device__ scalar_t vecsum(const scalar_t *q, const scalar_t *k, int dim)
{
    scalar_t sum = scalar_t(0.0f);
    for (int i = 0; i < dim; ++i)
        sum += q[i] * k[i];
    return sum;
}
template <typename scalar_t>
__global__ void attention_kernel(const scalar_t *q,
const scalar_t *k,
const scalar_t *v,
int head_num,
int head_dim,
int seq_len,
int batch_size,
int hidden_dim,
scalar_t *output)
{
    // Compute the attention GEMM: one block per batch element.
    int tid = threadIdx.x;
    // Compute the per-batch offsets into q, k and v.
    int q_offset = blockIdx.x * head_num * 1 * head_dim;
    int k_offset = blockIdx.x * head_num * seq_len * head_dim;
    int v_offset = blockIdx.x * head_num * seq_len * head_dim;
    // TODO: compute the q·k scores,
    // apply the softmax,
    // and accumulate the weighted sum over v.
}
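The three trailing comments above are left as TODOs in the commit. Purely as an illustration, here is one way the missing steps could look for a single query token; the fp32 accumulation, the fixed MAX_SEQ bound, and the helper name attention_one_head are assumptions of this sketch, not anything taken from the repository:

// Illustration only (not from this commit): naive single-query attention for one head.
// Assumes scalar_t converts to/from float and seq_len <= MAX_SEQ.
#define MAX_SEQ 1024

template <typename scalar_t>
__device__ void attention_one_head(const scalar_t *q,  // [head_dim]
                                   const scalar_t *k,  // [seq_len, head_dim]
                                   const scalar_t *v,  // [seq_len, head_dim]
                                   int head_dim, int seq_len,
                                   scalar_t *out)      // [head_dim]
{
    float score[MAX_SEQ];
    float max_score = -INFINITY;
    // 1) q·k scores, scaled by 1/sqrt(head_dim); track the max for a stable softmax.
    for (int s = 0; s < seq_len; ++s)
    {
        float dot = 0.f;
        for (int d = 0; d < head_dim; ++d)
            dot += float(q[d]) * float(k[s * head_dim + d]);
        score[s] = dot * rsqrtf(float(head_dim));
        max_score = fmaxf(max_score, score[s]);
    }
    // 2) softmax: exponentiate and normalize.
    float denom = 0.f;
    for (int s = 0; s < seq_len; ++s)
    {
        score[s] = expf(score[s] - max_score);
        denom += score[s];
    }
    // 3) weighted sum over the value rows.
    for (int d = 0; d < head_dim; ++d)
    {
        float acc = 0.f;
        for (int s = 0; s < seq_len; ++s)
            acc += score[s] * float(v[s * head_dim + d]);
        out[d] = scalar_t(acc / denom);
    }
}

A kernel shaped like attention_kernel above could then call such a helper once per head, e.g. with tid selecting the head and q_offset/k_offset/v_offset locating the per-batch data.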


@@ -60,7 +60,7 @@ __global__ void matmul_sigmoid_cuda(const T *in1, const T *in2, T *output, int r
#define BASE_BLOCK 256
#define CALL_ADD_FUNCTION \
add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
// cout << input1.dtype() << " the size 1 is : " << input1.size(0) << " size 2 is " << input1.size(1) << "output dim is :" << output.size(0) << output.size(1) << endl;
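For context, the launch configuration inside CALL_ADD_FUNCTION is the usual ceil-division pattern: enough BASE_BLOCK-sized blocks to cover every element, with an in-kernel bounds check for the partially filled tail block. A minimal sketch with hypothetical names (the real add_two_tensors_cuda kernel is defined elsewhere in this file):

// Illustration only, with hypothetical names: the ceil-division launch pattern
// that CALL_ADD_FUNCTION relies on for an element-wise kernel.
__global__ void add_demo_kernel(const float *a, const float *b, float *out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)                 // the last block may be only partially full
        out[i] = a[i] + b[i];
}

void launch_add_demo(const float *a, const float *b, float *out, int n)
{
    const int BLOCK = 256;                  // same value as BASE_BLOCK above
    int grid = (n + BLOCK - 1) / BLOCK;     // ceil(n / BLOCK) blocks cover all n elements
    add_demo_kernel<<<grid, BLOCK>>>(a, b, out, n);
}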


@@ -1 +1,4 @@
#include <cuda_fp8.h>
#include "core.h"
#include <cuda_fp8.h>
#define __nv_fp8_e4m3 fp8_e4m3


@@ -77,7 +77,6 @@ void rms_norm(torch::Tensor &states, float eps, float gamma)
int block_size = 1024;
dim3 block(h);
dim3 grid(block_size);
cout << states.scalar_type() << endl;
TYPING_DISPATCH(states.scalar_type(), [&]
{ rms_norm_kernel<fi_type><<<block, grid>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
}
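TYPING_DISPATCH itself is defined elsewhere in the repository, so its exact shape is not visible in this diff. As an assumption-laden sketch only, a macro used this way typically switches on the runtime scalar type, binds the alias fi_type, and invokes the lambda; it is usually variadic so that the comma inside the <<<block, grid>>> launch does not split the macro arguments:

// Assumption, for illustration only (requires torch and CUDA fp16 headers):
// a dispatch macro of this shape might expand roughly like the following.
#define TYPING_DISPATCH_SKETCH(SCALAR_TYPE, ...)   \
    switch (SCALAR_TYPE)                           \
    {                                              \
    case torch::kFloat32:                          \
    {                                              \
        using fi_type = float;                     \
        __VA_ARGS__();                             \
        break;                                     \
    }                                              \
    case torch::kFloat16:                          \
    {                                              \
        using fi_type = __half;                    \
        __VA_ARGS__();                             \
        break;                                     \
    }                                              \
    default:                                       \
        TORCH_CHECK(false, "unsupported dtype");   \
    }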

fi/test_module.py (new file, 28 lines added)

@@ -0,0 +1,28 @@
# coding=utf-8
import torch.nn as nn


class TestModule(nn.Module):
    def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = DecodeLayer()

    def forward(self, x):
        # Run the input through the stacked layers of the inner DecodeLayer.
        for module in self.model.layers:
            x = module(x)
        return x


class DecodeLayer(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList()
        for i in range(10):
            self.layers.append(nn.Linear(10, 10))


if __name__ == "__main__":
    test_module = TestModule(0, 3)
    for x in test_module.named_parameters():
        print(x[0])