简单修改一下。
This commit is contained in:
parent
80d7be70a5
commit
920ebe0f88
29
csrc/attention.cu
Normal file
29
csrc/attention.cu
Normal file
@ -0,0 +1,29 @@
|
||||
#include "core.h"
|
||||
|
||||
// calculate the vec cum of different matrix row and col.
|
||||
template <typename scalar_t>
|
||||
__device__ scalar_t vecsum(scalar_t *q, scalar_t *k)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void attention_kernel(const scalar_t *q,
|
||||
const scalar_t *k,
|
||||
const scalar_t *v,
|
||||
int head_num,
|
||||
int head_dim,
|
||||
int seq_len,
|
||||
int batch_size,
|
||||
int hidden_dim,
|
||||
scalar_t *output)
|
||||
{
|
||||
// calculate the gemm.
|
||||
int tid = threadIdx.x;
|
||||
// caculate the offset.
|
||||
int q_offset = blockIdx.x * head_num * 1 * head_dim;
|
||||
int k_offset = blockIdx.x * head_num * seq_len * head_dim;
|
||||
int v_offset = blockIdx.x * head_num * seq_len * head_dim;
|
||||
// calculate the sum.
|
||||
// calculate the softmax
|
||||
// calculate the weighted sum.
|
||||
}
|
||||
@ -60,7 +60,7 @@ __global__ void matmul_sigmoid_cuda(const T *in1, const T *in2, T *output, int r
|
||||
|
||||
#define BASE_BLOCK 256
|
||||
#define CALL_ADD_FUNCTION \
|
||||
add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
|
||||
add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
|
||||
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
|
||||
{
|
||||
// cout << input1.dtype() << " the size 1 is : " << input1.size(0) << " size 2 is " << input1.size(1) << "output dim is :" << output.size(0) << output.size(1) << endl;
|
||||
|
||||
@ -1 +1,4 @@
|
||||
#include <cuda_fp8.h>
|
||||
#include "core.h"
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#define __nv_fp8_e4m3 fp8_e4m3
|
||||
|
||||
@ -77,7 +77,6 @@ void rms_norm(torch::Tensor &states, float eps, float gamma)
|
||||
int block_size = 1024;
|
||||
dim3 block(h);
|
||||
dim3 grid(block_size);
|
||||
cout << states.scalar_type() << endl;
|
||||
TYPING_DISPATCH(states.scalar_type(), [&]
|
||||
{ rms_norm_kernel<fi_type><<<block, grid>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
|
||||
}
|
||||
28
fi/test_module.py
Normal file
28
fi/test_module.py
Normal file
@ -0,0 +1,28 @@
|
||||
# coding=utf-8
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class TestModule(nn.Module):
|
||||
def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model = DecodeLayer()
|
||||
|
||||
def forward(self, x):
|
||||
for module in self.model:
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
|
||||
class DecodeLayer(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(10):
|
||||
self.layers.append(nn.Linear(10, 10))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_module = TestModule(0, 3)
|
||||
for x in test_module.named_parameters():
|
||||
print(x[0])
|
||||
Loading…
Reference in New Issue
Block a user