From 89e3b9d19032e7a7cabeca269011c5e5f0ac744a Mon Sep 17 00:00:00 2001 From: long0x0 Date: Fri, 28 Mar 2025 22:19:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=AC=E5=9C=B0=E4=BF=AE=E6=94=B9=E4=B8=80?= =?UTF-8?q?=E4=B8=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- csrc/softmax.ptx | 257 +++++++++++++++++++++++++++++++++++++++ fi/models/__init__.py | 0 fi/models/deepseek_v3.py | 0 fi/models/llama.py | 153 +++++++++++++++++++++++ test_fa.py | 112 +++++++++++++++++ training/__init__.py | 0 training/dp.py | 24 ++++ training/pp.py | 9 ++ 8 files changed, 555 insertions(+) create mode 100644 csrc/softmax.ptx create mode 100644 fi/models/__init__.py create mode 100644 fi/models/deepseek_v3.py create mode 100644 fi/models/llama.py create mode 100644 test_fa.py create mode 100644 training/__init__.py create mode 100644 training/dp.py create mode 100644 training/pp.py diff --git a/csrc/softmax.ptx b/csrc/softmax.ptx new file mode 100644 index 0000000..2aeed96 --- /dev/null +++ b/csrc/softmax.ptx @@ -0,0 +1,257 @@ +// +// Generated by NVIDIA NVVM Compiler +// +// Compiler Build ID: CL-34097967 +// Cuda compilation tools, release 12.4, V12.4.131 +// Based on NVVM 7.0.1 +// + +.version 8.4 +.target sm_52 +.address_size 64 + + // .globl _Z7findMaxPKfPfi +.extern .shared .align 16 .b8 sharedMax[]; +.extern .shared .align 16 .b8 sharedSum[]; + +.visible .entry _Z7findMaxPKfPfi( + .param .u64 _Z7findMaxPKfPfi_param_0, + .param .u64 _Z7findMaxPKfPfi_param_1, + .param .u32 _Z7findMaxPKfPfi_param_2 +) +{ + .reg .pred %p<6>; + .reg .f32 %f<9>; + .reg .b32 %r<15>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [_Z7findMaxPKfPfi_param_0]; + ld.param.u64 %rd2, [_Z7findMaxPKfPfi_param_1]; + ld.param.u32 %r9, [_Z7findMaxPKfPfi_param_2]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + setp.ge.u32 %p1, %r4, %r9; + mov.f32 %f8, 0fFF800000; + @%p1 bra $L__BB0_2; + + cvta.to.global.u64 %rd3, %rd1; + mul.wide.u32 %rd4, %r4, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f8, [%rd5]; + +$L__BB0_2: + shl.b32 %r10, %r3, 2; + mov.u32 %r11, sharedMax; + add.s32 %r5, %r11, %r10; + st.shared.f32 [%r5], %f8; + bar.sync 0; + shr.u32 %r14, %r1, 1; + setp.eq.s32 %p2, %r14, 0; + @%p2 bra $L__BB0_6; + +$L__BB0_3: + setp.ge.u32 %p3, %r3, %r14; + @%p3 bra $L__BB0_5; + + ld.shared.f32 %f4, [%r5]; + shl.b32 %r12, %r14, 2; + add.s32 %r13, %r5, %r12; + ld.shared.f32 %f5, [%r13]; + max.f32 %f6, %f4, %f5; + st.shared.f32 [%r5], %f6; + +$L__BB0_5: + bar.sync 0; + shr.u32 %r14, %r14, 1; + setp.ne.s32 %p4, %r14, 0; + @%p4 bra $L__BB0_3; + +$L__BB0_6: + setp.ne.s32 %p5, %r3, 0; + @%p5 bra $L__BB0_8; + + ld.shared.f32 %f7, [sharedMax]; + cvta.to.global.u64 %rd6, %rd2; + mul.wide.u32 %rd7, %r2, 4; + add.s64 %rd8, %rd6, %rd7; + st.global.f32 [%rd8], %f7; + +$L__BB0_8: + ret; + +} + // .globl _Z10computeExpPKffPfi +.visible .entry _Z10computeExpPKffPfi( + .param .u64 _Z10computeExpPKffPfi_param_0, + .param .f32 _Z10computeExpPKffPfi_param_1, + .param .u64 _Z10computeExpPKffPfi_param_2, + .param .u32 _Z10computeExpPKffPfi_param_3 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<20>; + .reg .b32 %r<8>; + .reg .b64 %rd<8>; + + + ld.param.u64 %rd1, [_Z10computeExpPKffPfi_param_0]; + ld.param.f32 %f1, [_Z10computeExpPKffPfi_param_1]; + ld.param.u64 %rd2, [_Z10computeExpPKffPfi_param_2]; + ld.param.u32 %r2, [_Z10computeExpPKffPfi_param_3]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r3, %r4, %r5; + 
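+	// Editorial note: %r1 = ctaid.x * ntid.x + tid.x is the global element
+	// index; it is range-checked against n (param_3) before the load below.
+	// The arithmetic that follows is the compiler's fast expf(x - max)
+	// lowering: x*log2(e) is split into an integer part, applied through an
+	// exponent-bit shift (shl.b32 ... 23), and a fractional part evaluated
+	// with ex2.approx.ftz.f32.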
setp.ge.u32 %p1, %r1, %r2; + @%p1 bra $L__BB1_2; + + cvta.to.global.u64 %rd3, %rd1; + mul.wide.u32 %rd4, %r1, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f2, [%rd5]; + sub.f32 %f3, %f2, %f1; + mov.f32 %f4, 0f3F000000; + mov.f32 %f5, 0f3BBB989D; + fma.rn.f32 %f6, %f3, %f5, %f4; + cvt.sat.f32.f32 %f7, %f6; + mov.f32 %f8, 0f4B400001; + mov.f32 %f9, 0f437C0000; + fma.rm.f32 %f10, %f7, %f9, %f8; + add.f32 %f11, %f10, 0fCB40007F; + neg.f32 %f12, %f11; + mov.f32 %f13, 0f3FB8AA3B; + fma.rn.f32 %f14, %f3, %f13, %f12; + mov.f32 %f15, 0f32A57060; + fma.rn.f32 %f16, %f3, %f15, %f14; + mov.b32 %r6, %f10; + shl.b32 %r7, %r6, 23; + mov.b32 %f17, %r7; + ex2.approx.ftz.f32 %f18, %f16; + mul.f32 %f19, %f18, %f17; + cvta.to.global.u64 %rd6, %rd2; + add.s64 %rd7, %rd6, %rd4; + st.global.f32 [%rd7], %f19; + +$L__BB1_2: + ret; + +} + // .globl _Z13block_softmaxPKf +.visible .entry _Z13block_softmaxPKf( + .param .u64 _Z13block_softmaxPKf_param_0 +) +{ + + + + ret; + +} + // .globl _Z10computeSumPKfPfi +.visible .entry _Z10computeSumPKfPfi( + .param .u64 _Z10computeSumPKfPfi_param_0, + .param .u64 _Z10computeSumPKfPfi_param_1, + .param .u32 _Z10computeSumPKfPfi_param_2 +) +{ + .reg .pred %p<6>; + .reg .f32 %f<9>; + .reg .b32 %r<15>; + .reg .b64 %rd<9>; + + + ld.param.u64 %rd1, [_Z10computeSumPKfPfi_param_0]; + ld.param.u64 %rd2, [_Z10computeSumPKfPfi_param_1]; + ld.param.u32 %r9, [_Z10computeSumPKfPfi_param_2]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + setp.ge.u32 %p1, %r4, %r9; + mov.f32 %f8, 0f00000000; + @%p1 bra $L__BB3_2; + + cvta.to.global.u64 %rd3, %rd1; + mul.wide.u32 %rd4, %r4, 4; + add.s64 %rd5, %rd3, %rd4; + ld.global.f32 %f8, [%rd5]; + +$L__BB3_2: + shl.b32 %r10, %r3, 2; + mov.u32 %r11, sharedSum; + add.s32 %r5, %r11, %r10; + st.shared.f32 [%r5], %f8; + bar.sync 0; + shr.u32 %r14, %r1, 1; + setp.eq.s32 %p2, %r14, 0; + @%p2 bra $L__BB3_6; + +$L__BB3_3: + setp.ge.u32 %p3, %r3, %r14; + @%p3 bra $L__BB3_5; + + shl.b32 %r12, %r14, 2; + add.s32 %r13, %r5, %r12; + ld.shared.f32 %f4, [%r5]; + ld.shared.f32 %f5, [%r13]; + add.f32 %f6, %f5, %f4; + st.shared.f32 [%r5], %f6; + +$L__BB3_5: + bar.sync 0; + shr.u32 %r14, %r14, 1; + setp.ne.s32 %p4, %r14, 0; + @%p4 bra $L__BB3_3; + +$L__BB3_6: + setp.ne.s32 %p5, %r3, 0; + @%p5 bra $L__BB3_8; + + ld.shared.f32 %f7, [sharedSum]; + cvta.to.global.u64 %rd6, %rd2; + mul.wide.u32 %rd7, %r2, 4; + add.s64 %rd8, %rd6, %rd7; + st.global.f32 [%rd8], %f7; + +$L__BB3_8: + ret; + +} + // .globl _Z14computeSoftmaxPffi +.visible .entry _Z14computeSoftmaxPffi( + .param .u64 _Z14computeSoftmaxPffi_param_0, + .param .f32 _Z14computeSoftmaxPffi_param_1, + .param .u32 _Z14computeSoftmaxPffi_param_2 +) +{ + .reg .pred %p<2>; + .reg .f32 %f<4>; + .reg .b32 %r<6>; + .reg .b64 %rd<5>; + + + ld.param.u64 %rd1, [_Z14computeSoftmaxPffi_param_0]; + ld.param.f32 %f1, [_Z14computeSoftmaxPffi_param_1]; + ld.param.u32 %r2, [_Z14computeSoftmaxPffi_param_2]; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r1, %r3, %r4, %r5; + setp.ge.u32 %p1, %r1, %r2; + @%p1 bra $L__BB4_2; + + cvta.to.global.u64 %rd2, %rd1; + mul.wide.u32 %rd3, %r1, 4; + add.s64 %rd4, %rd2, %rd3; + ld.global.f32 %f2, [%rd4]; + div.rn.f32 %f3, %f2, %f1; + st.global.f32 [%rd4], %f3; + +$L__BB4_2: + ret; + +} + diff --git a/fi/models/__init__.py b/fi/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fi/models/deepseek_v3.py b/fi/models/deepseek_v3.py new file mode 100644 index 0000000..e69de29 diff --git 
a/fi/models/llama.py b/fi/models/llama.py new file mode 100644 index 0000000..b486441 --- /dev/null +++ b/fi/models/llama.py @@ -0,0 +1,153 @@ +# coding=utf-8 + +import torch +import torch.nn as nn + +from dataclasses import dataclass +from transformers.models.llama.configuration_llama import LlamaConfig + +from einops import rearrange +import torch.nn.functional as F + +from flash_attn_2_cuda import varlen_fwd + + +class MLP(nn.Module): + def __init__( + self, + config: LlamaConfig, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up = nn.Linear(self.hidden_size, self.intermediate_size) + self.gate = nn.Linear(self.hidden_size, self.intermediate_size) + self.down = nn.Linear(self.intermediate_size, self.hidden_size) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.down(self.gate(hidden_states) * self.up(hidden_states)) + return hidden_states + + +class LLAMAAttention(nn.Module): + def __init__(self, config: LlamaConfig, *args, **kwargs): + super().__init__(*args, **kwargs) + self.q = nn.Linear(config.hidden_size, config.hidden_size) + self.k = nn.Linear(config.hidden_size, config.hidden_size) + self.v = nn.Linear(config.hidden_size, config.hidden_size) + self.o = nn.Linear(config.hidden_size, config.hidden_size) + self.num_head = config.num_attention_heads + self.kv_head = config.num_key_value_heads + self.head_dim = config.head_dim + assert ( + self.num_head * self.head_dim == config.hidden_size + ), "make sure the num_head*head_dim == hidden_size" + + def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor): + q = self.q(hidden_states) + k = self.k(hidden_states) + v = self.v(hidden_states) + q = q.view(-1, self.num_head, self.head_dim) + k = k.view(-1, self.kv_head, self.head_dim) + v = v.view(-1, self.kv_head, self.head_dim) + # process positionids + # varlen_fwd # this function is the core of the process. + hidden_states = varlen_fwd( + q, + k, + v, + None, # output + cu_seqlens_q, # seqlen_q + cu_seqlens_k, # seqlen_k, v + seqused_k, # seqused_k + None, # block table + None, # alibi slopse + max_seqlen_q, # int + max_seqlen_k, # int + 0.0, # dropout + 0.0, # softmax_scale + False, # zero_tensors + True, # is causal + -1, # window size left + -1, # window size right + False, # return softmax + None, # gen + ) + # mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + # const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + # const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + # c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + # const at::Tensor &cu_seqlens_q, // b+1 + # const at::Tensor &cu_seqlens_k, // b+1 + # c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ # c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + # c10::optional &alibi_slopes_, // num_heads or b x num_heads + # int max_seqlen_q, + # const int max_seqlen_k, + # const float p_dropout, + # const float softmax_scale, + # const bool zero_tensors, + # bool is_causal, + # int window_size_left, + # int window_size_right, + # const bool return_softmax, + # c10::optional gen_) { + return hidden_states + + +class LLAMADecodeLayer(nn.Module): + def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class LLAMAModel(nn.Module): + + def __init__(self, config: LlamaConfig, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_layer = config.num_hidden_layers + + self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)] + ) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + def forward(self, input_ids, hidden_states, position_ids): + if input_ids is not None: + hidden_states = self.token_embed(input_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, position_ids) + output = self.lm_head(hidden_states) + return output + + def unpad_input(self, hidden_states, attention_mask): + hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...") + valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1) # some time is eq(1) + seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + hidden_states = hidden_states[indices].unsqueeze(0) + return hidden_states, indices, cu_seqlens, max_seqlen_in_batch + + def pad_input(self, hidden_states, indices, batch, seqlen): + """ + :param hidden_states: Shape is [L,H] not [B,L,H] + :param indices: from unpad_input return indices + :param batch: + :param seqlen: from unpad_input return max_seqlen_in_batch + :return: + """ + output = torch.zeros( + batch * seqlen, + *hidden_states.shape[1:], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + output[indices] = hidden_states + return rearrange(output, "(b s) ... 
-> b s ...", b=batch) diff --git a/test_fa.py b/test_fa.py new file mode 100644 index 0000000..785a5ed --- /dev/null +++ b/test_fa.py @@ -0,0 +1,112 @@ +import torch +from flash_attn import flash_attn_varlen_func +from torch.nn import functional as F + +import os + +os.environ["TORCH_USE_CUDA_DSA"] = "1" + + +def generate_data(): + nhead = 32 + head_dim = 128 + + batch_size = 3 + k_past_len = [16, 32, 64] + + k_data = [] + v_data = [] + q_data = [] + for i in k_past_len: + k_data.append( + torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda") + ) + v_data.append( + torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda") + ) + q_data.append( + torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda") + ) + + k_data = torch.concat(k_data, dim=0) + v_data = torch.concat(v_data, dim=0) + q_data = torch.concat(q_data, dim=0) + torch.save(k_data, "k.pkl") + torch.save(v_data, "v.pkl") + torch.save(q_data, "q.pkl") + + +def test_fa(): + nhead = 32 + head_dim = 128 + + batch_size = 3 + k_past_len = [16, 32, 64] + + k_data = [] + v_data = [] + q_data = [] + k_data = torch.load("./k.pkl") + q_data = torch.load("./q.pkl") + v_data = torch.load("./v.pkl") + print(k_data.size(), v_data.size(), q_data.size()) + k_past_len = torch.tensor(k_past_len, dtype=torch.int32, device="cuda") + res = torch.cumsum(k_past_len, dim=-1) + k_past_len = torch.cat( + (res, torch.tensor([k_data.size(0)], device=res.device, dtype=torch.int32)) + ) + k_past_len = k_past_len.to(torch.int32) + + # cu_k_len = torch.cumsum(torch.tensor(k_past_len)) + # print(cu_k_len, cu_k_len.size()) + torch.cuda.synchronize() + fa_res = flash_attn_varlen_func( + q_data, + k_data, + v_data, + cu_seqlens_k=k_past_len, + cu_seqlens_q=k_past_len, + max_seqlen_k=3, + max_seqlen_q=1, + causal=True, + ) + # torch.cuda.synchronize() + torch.save(fa_res, "./fa.pkl") + + +def test_torch_sdqa(): + nhead = 32 + head_dim = 128 + + batch_size = 3 + k_past_len = [16, 32, 64] + + k_data = [] + v_data = [] + q_data = [] + q_data = torch.load("./q.pkl") + k_data = torch.load("./k.pkl") + v_data = torch.load("./v.pkl") + q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous() + k = k_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous() + v = v_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous() + # torch.cuda.synchronize() + naive_res = F.scaled_dot_product_attention(q, k, v, is_causal=True) + print(naive_res.size()) + naive_res = naive_res.view(1, nhead, head_dim) + fa_res = torch.load("./fa.pkl") + if torch.allclose(naive_res, fa_res[0, :, :]): + print("this is dam right") + else: + print("max diff is ", torch.abs(naive_res - fa_res[0, :, :]).max()) + # torch.cuda.synchronize() + + +# do not know if such method is right. +# chunk prefill did use such method? 
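+# Editor's sketch (not part of the original experiment above): flash_attn's
+# varlen interface documents cu_seqlens_{q,k} as cumulative lengths of shape
+# [batch+1] with a leading 0, and max_seqlen_{q,k} as the largest per-sequence
+# length. The call above passes [16, 48, 112, 112] (no leading 0, total
+# duplicated) for both q and k, with max_seqlen_k=3, so the helper below uses
+# the documented layout instead. Note also that flash-attn (2.1+) aligns its
+# causal mask to the bottom-right corner while
+# F.scaled_dot_product_attention(is_causal=True) aligns it top-left, so with a
+# single query token per sequence the two are only comparable without the
+# torch-side causal mask (a decode step attends to the whole KV cache anyway).
+def test_fa_reference():
+    nhead = 32
+    head_dim = 128
+    k_lens = [16, 32, 64]
+
+    q_data = torch.load("./q.pkl")
+    k_data = torch.load("./k.pkl")
+    v_data = torch.load("./v.pkl")
+
+    # cu_seqlens: [0, 16, 48, 112] for K/V, [0, 1, 2, 3] for the 1-token queries.
+    cu_k = F.pad(
+        torch.cumsum(torch.tensor(k_lens, device="cuda"), dim=0), (1, 0)
+    ).to(torch.int32)
+    cu_q = torch.arange(len(k_lens) + 1, dtype=torch.int32, device="cuda")
+
+    fa_res = flash_attn_varlen_func(
+        q_data,
+        k_data,
+        v_data,
+        cu_seqlens_q=cu_q,
+        cu_seqlens_k=cu_k,
+        max_seqlen_q=1,
+        max_seqlen_k=max(k_lens),
+        causal=True,  # bottom-right aligned: the lone query sees every cached key
+    )
+
+    # Plain-attention reference, one sequence at a time.
+    offs = cu_k.tolist()  # [0, 16, 48, 112]
+    for i in range(len(k_lens)):
+        q = q_data[i].view(1, nhead, 1, head_dim)                      # [1, h, 1, d]
+        k = k_data[offs[i]:offs[i + 1]].permute(1, 0, 2).unsqueeze(0)  # [1, h, L, d]
+        v = v_data[offs[i]:offs[i + 1]].permute(1, 0, 2).unsqueeze(0)
+        ref = F.scaled_dot_product_attention(q, k, v)  # 1 query token: no causal mask needed
+        diff = (ref.reshape(nhead, head_dim) - fa_res[i]).abs().max()
+        print(f"seq {i}: max diff vs naive attention = {diff.item()}")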
+ + +if __name__ == "__main__": + generate_data() + test_fa() + test_torch_sdqa() diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/dp.py b/training/dp.py new file mode 100644 index 0000000..8d212e9 --- /dev/null +++ b/training/dp.py @@ -0,0 +1,24 @@ +# coding=utf-8 +import os +import torch +import torch.nn as nn +import torch.multiprocessing as mp + + +class DataParallel(object): + def __init__(self, dp_num, world_size: int, rank: int, module=None): + self.world_size = world_size + self.rank = rank + self.dp_num = dp_num + + def forward_pipeline(self): + pass + + def backward_pipeline(self): + pass + + def train(self): + pass + + def sync_grad(self, param): + pass diff --git a/training/pp.py b/training/pp.py new file mode 100644 index 0000000..6935ec5 --- /dev/null +++ b/training/pp.py @@ -0,0 +1,9 @@ +# coding=utf-8 + +import torch +import torch.nn as nn + + +class PipelineParallel(object): + def __init__(self): + pass
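+
+
+# Editor's sketch (not in the original patch): a minimal pipeline stage built
+# on torch.distributed point-to-point ops, only to outline where the
+# pipeline-parallel logic would go. It assumes the default process group is
+# already initialised, one stage per rank ordered by stage_id, and that
+# activation shapes/dtypes are agreed on out of band.
+import torch.distributed as dist
+
+
+class NaivePipelineStage(object):
+    def __init__(self, module: nn.Module, stage_id: int, num_stages: int):
+        self.module = module
+        self.stage_id = stage_id
+        self.num_stages = num_stages
+
+    def forward_step(self, inputs: torch.Tensor) -> torch.Tensor:
+        # Non-first stages receive activations from the previous rank.
+        if self.stage_id > 0:
+            dist.recv(inputs, src=self.stage_id - 1)
+        outputs = self.module(inputs)
+        # Non-last stages hand their activations to the next rank.
+        if self.stage_id < self.num_stages - 1:
+            dist.send(outputs, dst=self.stage_id + 1)
+        return outputs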