A simple modification.
parent 80d7be70a5
commit 920ebe0f88
csrc/attention.cu (new file, 29 lines)
@@ -0,0 +1,29 @@
#include "core.h"

// calculate the vector sum (dot product) of a matrix row and a matrix column.
template <typename scalar_t>
__device__ scalar_t vecsum(scalar_t *q, scalar_t *k)
{
}

template <typename scalar_t>
__global__ void attention_kernel(const scalar_t *q,
                                 const scalar_t *k,
                                 const scalar_t *v,
                                 int head_num,
                                 int head_dim,
                                 int seq_len,
                                 int batch_size,
                                 int hidden_dim,
                                 scalar_t *output)
{
    // calculate the gemm.
    int tid = threadIdx.x;
    // calculate the offset.
    int q_offset = blockIdx.x * head_num * 1 * head_dim;
    int k_offset = blockIdx.x * head_num * seq_len * head_dim;
    int v_offset = blockIdx.x * head_num * seq_len * head_dim;
    // calculate the sum.
    // calculate the softmax.
    // calculate the weighted sum.
}
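The committed attention_kernel is still a skeleton: vecsum has an empty body, and the three trailing comments (sum, softmax, weighted sum) mark steps that are not implemented yet. Purely for reference, a minimal single-query (decode-style) sketch of those steps is given below. The name attn_decode_kernel, the one-(batch, head)-pair-per-block layout, and the 1024-entry shared score buffer are illustrative assumptions, not part of this commit.

// Hypothetical sketch, not the repository's kernel: one (batch, head) pair per block,
// a single query token, seq_len assumed <= 1024.
template <typename scalar_t>
__global__ void attn_decode_kernel(const scalar_t *q,   // [batch, head_num, 1, head_dim]
                                   const scalar_t *k,   // [batch, head_num, seq_len, head_dim]
                                   const scalar_t *v,   // [batch, head_num, seq_len, head_dim]
                                   int head_num, int head_dim, int seq_len,
                                   scalar_t *output)    // [batch, head_num, 1, head_dim]
{
    int batch = blockIdx.x;
    int head = blockIdx.y;
    int tid = threadIdx.x;

    const scalar_t *q_ptr = q + (size_t)(batch * head_num + head) * head_dim;
    const scalar_t *k_ptr = k + (size_t)(batch * head_num + head) * seq_len * head_dim;
    const scalar_t *v_ptr = v + (size_t)(batch * head_num + head) * seq_len * head_dim;
    scalar_t *out_ptr = output + (size_t)(batch * head_num + head) * head_dim;

    __shared__ float scores[1024]; // assumption: seq_len <= 1024

    // 1. score[i] = q . k_i / sqrt(head_dim); key positions are split across threads.
    for (int i = tid; i < seq_len; i += blockDim.x)
    {
        float acc = 0.f;
        for (int d = 0; d < head_dim; ++d)
            acc += (float)q_ptr[d] * (float)k_ptr[i * head_dim + d];
        scores[i] = acc / sqrtf((float)head_dim);
    }
    __syncthreads();

    // 2. softmax statistics, computed redundantly by every thread for simplicity.
    float max_val = -1e30f;
    for (int i = 0; i < seq_len; ++i)
        max_val = fmaxf(max_val, scores[i]);
    float denom = 0.f;
    for (int i = 0; i < seq_len; ++i)
        denom += expf(scores[i] - max_val);

    // 3. weighted sum over v, one output dimension per thread.
    for (int d = tid; d < head_dim; d += blockDim.x)
    {
        float acc = 0.f;
        for (int i = 0; i < seq_len; ++i)
            acc += expf(scores[i] - max_val) / denom * (float)v_ptr[i * head_dim + d];
        out_ptr[d] = (scalar_t)acc;
    }
}

Under these assumptions the launch would look roughly like attn_decode_kernel<float><<<dim3(batch_size, head_num), 128>>>(q, k, v, head_num, head_dim, seq_len, output);.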
@@ -60,7 +60,7 @@ __global__ void matmul_sigmoid_cuda(const T *in1, const T *in2, T *output, int r
#define BASE_BLOCK 256
#define CALL_ADD_FUNCTION \
    add_two_tensors_cuda<<<(input1.size(0) * input1.size(1) + BASE_BLOCK - 1) / BASE_BLOCK, BASE_BLOCK>>>(src, src1, dest, input1.size(0) * input1.size(1));
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output)
{
    // cout << input1.dtype() << " the size 1 is : " << input1.size(0) << " size 2 is " << input1.size(1) << "output dim is :" << output.size(0) << output.size(1) << endl;
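CALL_ADD_FUNCTION only wraps the kernel launch; the body of add_two_tensors_cuda lies outside this hunk. A minimal elementwise sketch that matches the launch arguments (src, src1, dest, flat element count) might look as follows; it is an assumption about the kernel, not the repository's actual implementation.

// Hypothetical elementwise add kernel matching the CALL_ADD_FUNCTION launch.
template <typename T>
__global__ void add_two_tensors_cuda(const T *src, const T *src1, T *dest, int n)
{
    // One thread per element; the launch rounds the grid up so all n elements are covered.
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
        dest[idx] = src[idx] + src1[idx];
}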
@@ -1 +1,4 @@
#include "core.h"
#include <cuda_fp8.h>

#define __nv_fp8_e4m3 fp8_e4m3
@ -77,7 +77,6 @@ void rms_norm(torch::Tensor &states, float eps, float gamma)
|
|||||||
int block_size = 1024;
|
int block_size = 1024;
|
||||||
dim3 block(h);
|
dim3 block(h);
|
||||||
dim3 grid(block_size);
|
dim3 grid(block_size);
|
||||||
cout << states.scalar_type() << endl;
|
|
||||||
TYPING_DISPATCH(states.scalar_type(), [&]
|
TYPING_DISPATCH(states.scalar_type(), [&]
|
||||||
{ rms_norm_kernel<fi_type><<<block, grid>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
|
{ rms_norm_kernel<fi_type><<<block, grid>>>(reinterpret_cast<fi_type *>(states.data_ptr()), hidden_dim, eps, gamma); });
|
||||||
}
|
}
|
||||||
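The dispatch above launches rms_norm_kernel<fi_type> on the tensor's raw pointer, but the kernel body is not part of this hunk. Note that the call passes block (of size h) as the grid argument and grid (of size block_size) as the threads-per-block argument, so in effect h blocks of block_size threads are launched. The sketch below assumes each block normalizes one row of length hidden_dim in place; it is a plausible reconstruction under that assumption, not the repository's kernel.

// Hypothetical in-place RMSNorm kernel consistent with the (states, hidden_dim, eps, gamma) signature.
template <typename T>
__global__ void rms_norm_kernel(T *states, int hidden_dim, float eps, float gamma)
{
    // Each block normalizes one row of `states`.
    T *row = states + (size_t)blockIdx.x * hidden_dim;

    // Block-wide sum of squares via a shared accumulator.
    __shared__ float ssum;
    if (threadIdx.x == 0)
        ssum = 0.f;
    __syncthreads();

    float local = 0.f;
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x)
    {
        float x = (float)row[i];
        local += x * x;
    }
    atomicAdd(&ssum, local);
    __syncthreads();

    // Scale every element by 1 / rms and the scalar gamma.
    float inv_rms = rsqrtf(ssum / hidden_dim + eps);
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x)
        row[i] = (T)((float)row[i] * inv_rms * gamma);
}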
fi/test_module.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# coding=utf-8

import torch.nn as nn


class TestModule(nn.Module):
    def __init__(self, start_layer_index: int, end_layer_index: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = DecodeLayer()

    def forward(self, x):
        # iterate the layers held by DecodeLayer (nn.Module itself is not iterable)
        for module in self.model.layers:
            x = module(x)
        return x


class DecodeLayer(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList()
        for i in range(10):
            self.layers.append(nn.Linear(10, 10))


if __name__ == "__main__":
    test_module = TestModule(0, 3)
    for x in test_module.named_parameters():
        print(x[0])