Some local modifications.

long0x0 2025-03-28 22:19:03 +08:00
parent 920ebe0f88
commit 89e3b9d190
8 changed files with 555 additions and 0 deletions

csrc/softmax.ptx (new file, 257 lines)

@@ -0,0 +1,257 @@
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-34097967
// Cuda compilation tools, release 12.4, V12.4.131
// Based on NVVM 7.0.1
//
.version 8.4
.target sm_52
.address_size 64
// .globl _Z7findMaxPKfPfi
.extern .shared .align 16 .b8 sharedMax[];
.extern .shared .align 16 .b8 sharedSum[];
.visible .entry _Z7findMaxPKfPfi(
.param .u64 _Z7findMaxPKfPfi_param_0,
.param .u64 _Z7findMaxPKfPfi_param_1,
.param .u32 _Z7findMaxPKfPfi_param_2
)
{
.reg .pred %p<6>;
.reg .f32 %f<9>;
.reg .b32 %r<15>;
.reg .b64 %rd<9>;
ld.param.u64 %rd1, [_Z7findMaxPKfPfi_param_0];
ld.param.u64 %rd2, [_Z7findMaxPKfPfi_param_1];
ld.param.u32 %r9, [_Z7findMaxPKfPfi_param_2];
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r2, %r1, %r3;
setp.ge.u32 %p1, %r4, %r9;
mov.f32 %f8, 0fFF800000;
@%p1 bra $L__BB0_2;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.u32 %rd4, %r4, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f8, [%rd5];
$L__BB0_2:
shl.b32 %r10, %r3, 2;
mov.u32 %r11, sharedMax;
add.s32 %r5, %r11, %r10;
st.shared.f32 [%r5], %f8;
bar.sync 0;
shr.u32 %r14, %r1, 1;
setp.eq.s32 %p2, %r14, 0;
@%p2 bra $L__BB0_6;
$L__BB0_3:
setp.ge.u32 %p3, %r3, %r14;
@%p3 bra $L__BB0_5;
ld.shared.f32 %f4, [%r5];
shl.b32 %r12, %r14, 2;
add.s32 %r13, %r5, %r12;
ld.shared.f32 %f5, [%r13];
max.f32 %f6, %f4, %f5;
st.shared.f32 [%r5], %f6;
$L__BB0_5:
bar.sync 0;
shr.u32 %r14, %r14, 1;
setp.ne.s32 %p4, %r14, 0;
@%p4 bra $L__BB0_3;
$L__BB0_6:
setp.ne.s32 %p5, %r3, 0;
@%p5 bra $L__BB0_8;
ld.shared.f32 %f7, [sharedMax];
cvta.to.global.u64 %rd6, %rd2;
mul.wide.u32 %rd7, %r2, 4;
add.s64 %rd8, %rd6, %rd7;
st.global.f32 [%rd8], %f7;
$L__BB0_8:
ret;
}
// .globl _Z10computeExpPKffPfi
.visible .entry _Z10computeExpPKffPfi(
.param .u64 _Z10computeExpPKffPfi_param_0,
.param .f32 _Z10computeExpPKffPfi_param_1,
.param .u64 _Z10computeExpPKffPfi_param_2,
.param .u32 _Z10computeExpPKffPfi_param_3
)
{
.reg .pred %p<2>;
.reg .f32 %f<20>;
.reg .b32 %r<8>;
.reg .b64 %rd<8>;
ld.param.u64 %rd1, [_Z10computeExpPKffPfi_param_0];
ld.param.f32 %f1, [_Z10computeExpPKffPfi_param_1];
ld.param.u64 %rd2, [_Z10computeExpPKffPfi_param_2];
ld.param.u32 %r2, [_Z10computeExpPKffPfi_param_3];
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %tid.x;
mad.lo.s32 %r1, %r3, %r4, %r5;
setp.ge.u32 %p1, %r1, %r2;
@%p1 bra $L__BB1_2;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.u32 %rd4, %r1, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f2, [%rd5];
sub.f32 %f3, %f2, %f1;
mov.f32 %f4, 0f3F000000;
mov.f32 %f5, 0f3BBB989D;
fma.rn.f32 %f6, %f3, %f5, %f4;
cvt.sat.f32.f32 %f7, %f6;
mov.f32 %f8, 0f4B400001;
mov.f32 %f9, 0f437C0000;
fma.rm.f32 %f10, %f7, %f9, %f8;
add.f32 %f11, %f10, 0fCB40007F;
neg.f32 %f12, %f11;
mov.f32 %f13, 0f3FB8AA3B;
fma.rn.f32 %f14, %f3, %f13, %f12;
mov.f32 %f15, 0f32A57060;
fma.rn.f32 %f16, %f3, %f15, %f14;
mov.b32 %r6, %f10;
shl.b32 %r7, %r6, 23;
mov.b32 %f17, %r7;
ex2.approx.ftz.f32 %f18, %f16;
mul.f32 %f19, %f18, %f17;
cvta.to.global.u64 %rd6, %rd2;
add.s64 %rd7, %rd6, %rd4;
st.global.f32 [%rd7], %f19;
$L__BB1_2:
ret;
}
// .globl _Z13block_softmaxPKf
.visible .entry _Z13block_softmaxPKf(
.param .u64 _Z13block_softmaxPKf_param_0
)
{
ret;
}
// .globl _Z10computeSumPKfPfi
.visible .entry _Z10computeSumPKfPfi(
.param .u64 _Z10computeSumPKfPfi_param_0,
.param .u64 _Z10computeSumPKfPfi_param_1,
.param .u32 _Z10computeSumPKfPfi_param_2
)
{
.reg .pred %p<6>;
.reg .f32 %f<9>;
.reg .b32 %r<15>;
.reg .b64 %rd<9>;
ld.param.u64 %rd1, [_Z10computeSumPKfPfi_param_0];
ld.param.u64 %rd2, [_Z10computeSumPKfPfi_param_1];
ld.param.u32 %r9, [_Z10computeSumPKfPfi_param_2];
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r2, %r1, %r3;
setp.ge.u32 %p1, %r4, %r9;
mov.f32 %f8, 0f00000000;
@%p1 bra $L__BB3_2;
cvta.to.global.u64 %rd3, %rd1;
mul.wide.u32 %rd4, %r4, 4;
add.s64 %rd5, %rd3, %rd4;
ld.global.f32 %f8, [%rd5];
$L__BB3_2:
shl.b32 %r10, %r3, 2;
mov.u32 %r11, sharedSum;
add.s32 %r5, %r11, %r10;
st.shared.f32 [%r5], %f8;
bar.sync 0;
shr.u32 %r14, %r1, 1;
setp.eq.s32 %p2, %r14, 0;
@%p2 bra $L__BB3_6;
$L__BB3_3:
setp.ge.u32 %p3, %r3, %r14;
@%p3 bra $L__BB3_5;
shl.b32 %r12, %r14, 2;
add.s32 %r13, %r5, %r12;
ld.shared.f32 %f4, [%r5];
ld.shared.f32 %f5, [%r13];
add.f32 %f6, %f5, %f4;
st.shared.f32 [%r5], %f6;
$L__BB3_5:
bar.sync 0;
shr.u32 %r14, %r14, 1;
setp.ne.s32 %p4, %r14, 0;
@%p4 bra $L__BB3_3;
$L__BB3_6:
setp.ne.s32 %p5, %r3, 0;
@%p5 bra $L__BB3_8;
ld.shared.f32 %f7, [sharedSum];
cvta.to.global.u64 %rd6, %rd2;
mul.wide.u32 %rd7, %r2, 4;
add.s64 %rd8, %rd6, %rd7;
st.global.f32 [%rd8], %f7;
$L__BB3_8:
ret;
}
// .globl _Z14computeSoftmaxPffi
.visible .entry _Z14computeSoftmaxPffi(
.param .u64 _Z14computeSoftmaxPffi_param_0,
.param .f32 _Z14computeSoftmaxPffi_param_1,
.param .u32 _Z14computeSoftmaxPffi_param_2
)
{
.reg .pred %p<2>;
.reg .f32 %f<4>;
.reg .b32 %r<6>;
.reg .b64 %rd<5>;
ld.param.u64 %rd1, [_Z14computeSoftmaxPffi_param_0];
ld.param.f32 %f1, [_Z14computeSoftmaxPffi_param_1];
ld.param.u32 %r2, [_Z14computeSoftmaxPffi_param_2];
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %ntid.x;
mov.u32 %r5, %tid.x;
mad.lo.s32 %r1, %r3, %r4, %r5;
setp.ge.u32 %p1, %r1, %r2;
@%p1 bra $L__BB4_2;
cvta.to.global.u64 %rd2, %rd1;
mul.wide.u32 %rd3, %r1, 4;
add.s64 %rd4, %rd2, %rd3;
ld.global.f32 %f2, [%rd4];
div.rn.f32 %f3, %f2, %f1;
st.global.f32 [%rd4], %f3;
$L__BB4_2:
ret;
}
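
The kernels above demangle to findMax(const float*, float*, int), computeExp(const float*, float, float*, int), computeSum(const float*, float*, int), and computeSoftmax(float*, float, int): a multi-pass softmax that takes a per-block max, exponentiates with the max subtracted, takes a per-block sum, then normalizes. Purely as a hedged illustration (PyCUDA and the host-side reduction of the per-block partials are assumptions, not part of this commit), a minimal driver that loads softmax.ptx and chains the passes could look like this:

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context on import)
import pycuda.driver as drv

mod = drv.module_from_file("csrc/softmax.ptx")
# Mangled kernel names taken from the PTX above.
find_max = mod.get_function("_Z7findMaxPKfPfi")
compute_exp = mod.get_function("_Z10computeExpPKffPfi")
compute_sum = mod.get_function("_Z10computeSumPKfPfi")
normalize = mod.get_function("_Z14computeSoftmaxPffi")

n, block = 1024, 256  # block must be a power of two for the halving reduction
grid = (n + block - 1) // block
x = np.random.randn(n).astype(np.float32)
x_gpu = drv.to_device(x)
partials = drv.mem_alloc(grid * 4)  # one float per block

# Pass 1: per-block max; the per-block partials are reduced on the host.
find_max(x_gpu, partials, np.int32(n),
         block=(block, 1, 1), grid=(grid, 1), shared=block * 4)
row_max = drv.from_device(partials, (grid,), np.float32).max()

# Pass 2: exp(x - max), written back in place.
compute_exp(x_gpu, np.float32(row_max), x_gpu, np.int32(n),
            block=(block, 1, 1), grid=(grid, 1))

# Pass 3: per-block sum, reduced on the host, then divide through.
compute_sum(x_gpu, partials, np.int32(n),
            block=(block, 1, 1), grid=(grid, 1), shared=block * 4)
total = drv.from_device(partials, (grid,), np.float32).sum()
normalize(x_gpu, np.float32(total), np.int32(n),
          block=(block, 1, 1), grid=(grid, 1))

softmax = drv.from_device(x_gpu, (n,), np.float32)
print(softmax.sum())  # should be close to 1.0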

fi/models/__init__.py (new file, 0 lines)

fi/models/deepseek_v3.py (new file, 0 lines)

fi/models/llama.py (new file, 153 lines)

@@ -0,0 +1,153 @@
# coding=utf-8
import torch
import torch.nn as nn
from dataclasses import dataclass
from transformers.models.llama.configuration_llama import LlamaConfig
from einops import rearrange
import torch.nn.functional as F
from flash_attn_2_cuda import varlen_fwd


class MLP(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up = nn.Linear(self.hidden_size, self.intermediate_size)
        self.gate = nn.Linear(self.hidden_size, self.intermediate_size)
        self.down = nn.Linear(self.intermediate_size, self.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        # LLaMA-style gated MLP: down(silu(gate(x)) * up(x)).
        hidden_states = self.down(F.silu(self.gate(hidden_states)) * self.up(hidden_states))
        return hidden_states


class LLAMAAttention(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_head = config.num_attention_heads
        self.kv_head = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.q = nn.Linear(config.hidden_size, self.num_head * self.head_dim)
        # k/v project to kv_head * head_dim so the view(-1, kv_head, head_dim)
        # below stays valid when num_key_value_heads < num_attention_heads (GQA).
        self.k = nn.Linear(config.hidden_size, self.kv_head * self.head_dim)
        self.v = nn.Linear(config.hidden_size, self.kv_head * self.head_dim)
        self.o = nn.Linear(self.num_head * self.head_dim, config.hidden_size)
        assert (
            self.num_head * self.head_dim == config.hidden_size
        ), "make sure num_head * head_dim == hidden_size"
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        seqused_k: torch.Tensor = None,
    ):
        q = self.q(hidden_states)
        k = self.k(hidden_states)
        v = self.v(hidden_states)
        q = q.view(-1, self.num_head, self.head_dim)
        k = k.view(-1, self.kv_head, self.head_dim)
        v = v.view(-1, self.kv_head, self.head_dim)
        # TODO: apply rotary embeddings with position_ids before the attention call.
        # varlen_fwd (the raw flash-attn 2 binding) is the core of this forward.
        out = varlen_fwd(
            q,
            k,
            v,
            None,  # output
            cu_seqlens_q,  # cumulative query lengths, shape [batch + 1]
            cu_seqlens_k,  # cumulative key/value lengths, shape [batch + 1]
            seqused_k,  # optional per-batch count of used keys
            None,  # block table
            None,  # alibi slopes
            max_seqlen_q,  # int
            max_seqlen_k,  # int
            0.0,  # dropout
            self.head_dim**-0.5,  # softmax_scale
            False,  # zero_tensors
            True,  # is causal
            -1,  # window size left
            -1,  # window size right
            False,  # return softmax
            None,  # gen
        )
        # mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        # const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        # const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        # c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        # const at::Tensor &cu_seqlens_q, // b+1
        # const at::Tensor &cu_seqlens_k, // b+1
        # c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
        # c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
        # c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
        # int max_seqlen_q,
        # const int max_seqlen_k,
        # const float p_dropout,
        # const float softmax_scale,
        # const bool zero_tensors,
        # bool is_causal,
        # int window_size_left,
        # int window_size_right,
        # const bool return_softmax,
        # c10::optional<at::Generator> gen_) {
        # varlen_fwd returns a list of tensors; the attention output is the first
        # element (the exact contents depend on the flash-attn build).
        attn_out = out[0] if isinstance(out, (list, tuple)) else out
        attn_out = attn_out.reshape(-1, self.num_head * self.head_dim)
        return self.o(attn_out)


class LLAMADecodeLayer(nn.Module):
    def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)


class LLAMAModel(nn.Module):
    def __init__(self, config: LlamaConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_layer = config.num_hidden_layers
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)]
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, hidden_states, position_ids):
        if input_ids is not None:
            hidden_states = self.token_embed(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        output = self.lm_head(hidden_states)
        return output
    def unpad_input(self, hidden_states, attention_mask):
        hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
        valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)  # valid positions are marked with 1
        seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
        )
        hidden_states = hidden_states[indices].unsqueeze(0)
        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
    def pad_input(self, hidden_states, indices, batch, seqlen):
        """
        :param hidden_states: shape [L, H] (unpadded), not [B, L, H]
        :param indices: the indices returned by unpad_input
        :param batch: batch size
        :param seqlen: the max_seqlen_in_batch returned by unpad_input
        :return: re-padded tensor of shape [B, S, ...]
        """
        output = torch.zeros(
            batch * seqlen,
            *hidden_states.shape[1:],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return rearrange(output, "(b s) ... -> b s ...", b=batch)
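
For orientation, here is a hypothetical end-to-end driver (not part of this commit) showing how unpad_input, the varlen attention forward above, and pad_input are meant to fit together. The config values, the fi.models.llama import path, and the availability of a head_dim field on LlamaConfig are assumptions of the sketch:

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from fi.models.llama import LLAMAAttention, LLAMAModel

config = LlamaConfig(
    hidden_size=512,
    intermediate_size=1024,
    num_attention_heads=8,
    num_key_value_heads=8,
    head_dim=64,
    num_hidden_layers=2,
    vocab_size=1000,
)
model = LLAMAModel(config).half().cuda()
attn = LLAMAAttention(config).half().cuda()

batch, seqlen = 2, 16
hidden = torch.randn(batch, seqlen, config.hidden_size, device="cuda", dtype=torch.float16)
# Right-padded batch: the second sequence only has 10 valid tokens.
mask = torch.ones(batch, 1, 1, seqlen, dtype=torch.long, device="cuda")
mask[1, :, :, 10:] = 0

packed, indices, cu_seqlens, max_seqlen = model.unpad_input(hidden, mask)
packed = packed.squeeze(0)  # [total_tokens, hidden_size]
out = attn(
    packed,
    None,  # position_ids (RoPE is still a TODO in the forward above)
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
)
repadded = model.pad_input(out, indices, batch, seqlen)  # back to [B, S, H]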

test_fa.py (new file, 112 lines)

@@ -0,0 +1,112 @@
import torch
from flash_attn import flash_attn_varlen_func
from torch.nn import functional as F
import os
os.environ["TORCH_USE_CUDA_DSA"] = "1"


def generate_data():
    nhead = 32
    head_dim = 128
    batch_size = 3
    k_past_len = [16, 32, 64]
    k_data = []
    v_data = []
    q_data = []
    for i in k_past_len:
        k_data.append(
            torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        v_data.append(
            torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
        q_data.append(
            torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda")
        )
    k_data = torch.concat(k_data, dim=0)
    v_data = torch.concat(v_data, dim=0)
    q_data = torch.concat(q_data, dim=0)
    torch.save(k_data, "k.pkl")
    torch.save(v_data, "v.pkl")
    torch.save(q_data, "q.pkl")


def test_fa():
    nhead = 32
    head_dim = 128
    batch_size = 3
    k_past_len = [16, 32, 64]
    k_data = torch.load("./k.pkl")
    q_data = torch.load("./q.pkl")
    v_data = torch.load("./v.pkl")
    print(k_data.size(), v_data.size(), q_data.size())
    # cu_seqlens must be exclusive prefix sums starting at 0:
    # [0, 16, 48, 112] for k/v and [0, 1, 2, 3] for the single-token queries.
    k_len = torch.tensor(k_past_len, dtype=torch.int32, device="cuda")
    cu_seqlens_k = F.pad(torch.cumsum(k_len, dim=0, dtype=torch.int32), (1, 0))
    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")
    torch.cuda.synchronize()
    fa_res = flash_attn_varlen_func(
        q_data,
        k_data,
        v_data,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=max(k_past_len),
        # with seqlen_q == 1, flash-attn's bottom-right causal mask lets the
        # single query attend to every past key
        causal=True,
    )
    # torch.cuda.synchronize()
    torch.save(fa_res, "./fa.pkl")


def test_torch_sdpa():
    nhead = 32
    head_dim = 128
    batch_size = 3
    k_past_len = [16, 32, 64]
    q_data = torch.load("./q.pkl")
    k_data = torch.load("./k.pkl")
    v_data = torch.load("./v.pkl")
    q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    k = k_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    v = v_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
    # torch.cuda.synchronize()
    # For a single query token attending to its full past, the reference is plain
    # (non-causal) attention over all 16 keys; torch's is_causal mask is top-left
    # aligned, so with q_len == 1 it would only let the query see the first key.
    naive_res = F.scaled_dot_product_attention(q, k, v)
    print(naive_res.size())
    naive_res = naive_res.view(1, nhead, head_dim)
    fa_res = torch.load("./fa.pkl")
    if torch.allclose(naive_res, fa_res[0, :, :], atol=1e-3, rtol=1e-3):
        print("flash-attn matches the SDPA reference")
    else:
        print("max diff is", torch.abs(naive_res - fa_res[0, :, :]).max())
    # torch.cuda.synchronize()
    # Open question: is this per-sequence slicing the right way to validate
    # varlen attention? Chunked prefill appears to use a similar view.


if __name__ == "__main__":
    generate_data()
    test_fa()
    test_torch_sdpa()
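
As a companion sanity check (a sketch, not part of the original file), the packed flash-attn result can also be compared sequence by sequence against plain SDPA over each sequence's own keys, which sidesteps the causal-mask alignment question entirely:

def reference_loop_check():
    # Assumes q.pkl / k.pkl / v.pkl / fa.pkl were produced by the functions above.
    nhead, head_dim = 32, 128
    k_past_len = [16, 32, 64]
    q = torch.load("./q.pkl")
    k = torch.load("./k.pkl")
    v = torch.load("./v.pkl")
    fa_res = torch.load("./fa.pkl")
    start = 0
    for i, length in enumerate(k_past_len):
        qi = q[i].reshape(1, 1, nhead, head_dim).transpose(1, 2)  # [1, H, 1, D]
        ki = k[start:start + length].reshape(1, length, nhead, head_dim).transpose(1, 2)
        vi = v[start:start + length].reshape(1, length, nhead, head_dim).transpose(1, 2)
        ref = F.scaled_dot_product_attention(qi, ki, vi)  # full attention over the past
        diff = (ref.reshape(nhead, head_dim) - fa_res[i]).abs().max()
        print(f"sequence {i}: max diff {diff.item():.6f}")
        start += length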

training/__init__.py (new file, 0 lines)

training/dp.py (new file, 24 lines)

@@ -0,0 +1,24 @@
# coding=utf-8
import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp


class DataParallel(object):
    def __init__(self, dp_num, world_size: int, rank: int, module=None):
        self.world_size = world_size
        self.rank = rank
        self.dp_num = dp_num
        self.module = module

    def forward_pipeline(self):
        pass

    def backward_pipeline(self):
        pass

    def train(self):
        pass

    def sync_grad(self, param):
        pass
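
As a loose sketch of what sync_grad could eventually do, here is a standalone gradient all-reduce built on torch.distributed; the function name is hypothetical and a process group initialized elsewhere (e.g. via dist.init_process_group) is assumed:

import torch.nn as nn
import torch.distributed as dist


def sync_grad_allreduce(module: nn.Module, dp_group=None):
    """Average every parameter gradient across the data-parallel group."""
    world = dist.get_world_size(group=dp_group)
    for param in module.parameters():
        if param.grad is not None:
            # Sum the gradient over all ranks, then divide to get the mean.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
            param.grad.div_(world)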

training/pp.py (new file, 9 lines)

@@ -0,0 +1,9 @@
# coding=utf-8
import torch
import torch.nn as nn


class PipelineParallel(object):
    def __init__(self):
        pass
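
PipelineParallel is still an empty shell. Purely as a hypothetical illustration of the hand-off it will need, a single forward micro-batch step between adjacent ranks could be sketched with torch.distributed point-to-point ops; the function name, arguments, and shapes are all assumptions:

import torch
import torch.nn as nn
import torch.distributed as dist


def pipeline_forward_step(stage: nn.Module, rank: int, world_size: int,
                          micro_batch=None, recv_shape=None, device="cuda"):
    """Run one micro-batch through this rank's stage and hand it downstream.

    Rank 0 consumes micro_batch; later ranks receive an activation of
    recv_shape from the previous rank instead.
    """
    if rank > 0:
        micro_batch = torch.empty(recv_shape, device=device)
        dist.recv(micro_batch, src=rank - 1)
    out = stage(micro_batch)
    if rank < world_size - 1:
        dist.send(out, dst=rank + 1)
    return out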