本地修改一下。
This commit is contained in:
parent
920ebe0f88
commit
89e3b9d190
257
csrc/softmax.ptx
Normal file
257
csrc/softmax.ptx
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
//
|
||||||
|
// Generated by NVIDIA NVVM Compiler
|
||||||
|
//
|
||||||
|
// Compiler Build ID: CL-34097967
|
||||||
|
// Cuda compilation tools, release 12.4, V12.4.131
|
||||||
|
// Based on NVVM 7.0.1
|
||||||
|
//
|
||||||
|
|
||||||
|
.version 8.4
|
||||||
|
.target sm_52
|
||||||
|
.address_size 64
|
||||||
|
|
||||||
|
// .globl _Z7findMaxPKfPfi
|
||||||
|
.extern .shared .align 16 .b8 sharedMax[];
|
||||||
|
.extern .shared .align 16 .b8 sharedSum[];
|
||||||
|
|
||||||
|
.visible .entry _Z7findMaxPKfPfi(
|
||||||
|
.param .u64 _Z7findMaxPKfPfi_param_0,
|
||||||
|
.param .u64 _Z7findMaxPKfPfi_param_1,
|
||||||
|
.param .u32 _Z7findMaxPKfPfi_param_2
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .pred %p<6>;
|
||||||
|
.reg .f32 %f<9>;
|
||||||
|
.reg .b32 %r<15>;
|
||||||
|
.reg .b64 %rd<9>;
|
||||||
|
|
||||||
|
|
||||||
|
ld.param.u64 %rd1, [_Z7findMaxPKfPfi_param_0];
|
||||||
|
ld.param.u64 %rd2, [_Z7findMaxPKfPfi_param_1];
|
||||||
|
ld.param.u32 %r9, [_Z7findMaxPKfPfi_param_2];
|
||||||
|
mov.u32 %r1, %ntid.x;
|
||||||
|
mov.u32 %r2, %ctaid.x;
|
||||||
|
mov.u32 %r3, %tid.x;
|
||||||
|
mad.lo.s32 %r4, %r2, %r1, %r3;
|
||||||
|
setp.ge.u32 %p1, %r4, %r9;
|
||||||
|
mov.f32 %f8, 0fFF800000;
|
||||||
|
@%p1 bra $L__BB0_2;
|
||||||
|
|
||||||
|
cvta.to.global.u64 %rd3, %rd1;
|
||||||
|
mul.wide.u32 %rd4, %r4, 4;
|
||||||
|
add.s64 %rd5, %rd3, %rd4;
|
||||||
|
ld.global.f32 %f8, [%rd5];
|
||||||
|
|
||||||
|
$L__BB0_2:
|
||||||
|
shl.b32 %r10, %r3, 2;
|
||||||
|
mov.u32 %r11, sharedMax;
|
||||||
|
add.s32 %r5, %r11, %r10;
|
||||||
|
st.shared.f32 [%r5], %f8;
|
||||||
|
bar.sync 0;
|
||||||
|
shr.u32 %r14, %r1, 1;
|
||||||
|
setp.eq.s32 %p2, %r14, 0;
|
||||||
|
@%p2 bra $L__BB0_6;
|
||||||
|
|
||||||
|
$L__BB0_3:
|
||||||
|
setp.ge.u32 %p3, %r3, %r14;
|
||||||
|
@%p3 bra $L__BB0_5;
|
||||||
|
|
||||||
|
ld.shared.f32 %f4, [%r5];
|
||||||
|
shl.b32 %r12, %r14, 2;
|
||||||
|
add.s32 %r13, %r5, %r12;
|
||||||
|
ld.shared.f32 %f5, [%r13];
|
||||||
|
max.f32 %f6, %f4, %f5;
|
||||||
|
st.shared.f32 [%r5], %f6;
|
||||||
|
|
||||||
|
$L__BB0_5:
|
||||||
|
bar.sync 0;
|
||||||
|
shr.u32 %r14, %r14, 1;
|
||||||
|
setp.ne.s32 %p4, %r14, 0;
|
||||||
|
@%p4 bra $L__BB0_3;
|
||||||
|
|
||||||
|
$L__BB0_6:
|
||||||
|
setp.ne.s32 %p5, %r3, 0;
|
||||||
|
@%p5 bra $L__BB0_8;
|
||||||
|
|
||||||
|
ld.shared.f32 %f7, [sharedMax];
|
||||||
|
cvta.to.global.u64 %rd6, %rd2;
|
||||||
|
mul.wide.u32 %rd7, %r2, 4;
|
||||||
|
add.s64 %rd8, %rd6, %rd7;
|
||||||
|
st.global.f32 [%rd8], %f7;
|
||||||
|
|
||||||
|
$L__BB0_8:
|
||||||
|
ret;
|
||||||
|
|
||||||
|
}
|
||||||
|
// .globl _Z10computeExpPKffPfi
|
||||||
|
.visible .entry _Z10computeExpPKffPfi(
|
||||||
|
.param .u64 _Z10computeExpPKffPfi_param_0,
|
||||||
|
.param .f32 _Z10computeExpPKffPfi_param_1,
|
||||||
|
.param .u64 _Z10computeExpPKffPfi_param_2,
|
||||||
|
.param .u32 _Z10computeExpPKffPfi_param_3
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .pred %p<2>;
|
||||||
|
.reg .f32 %f<20>;
|
||||||
|
.reg .b32 %r<8>;
|
||||||
|
.reg .b64 %rd<8>;
|
||||||
|
|
||||||
|
|
||||||
|
ld.param.u64 %rd1, [_Z10computeExpPKffPfi_param_0];
|
||||||
|
ld.param.f32 %f1, [_Z10computeExpPKffPfi_param_1];
|
||||||
|
ld.param.u64 %rd2, [_Z10computeExpPKffPfi_param_2];
|
||||||
|
ld.param.u32 %r2, [_Z10computeExpPKffPfi_param_3];
|
||||||
|
mov.u32 %r3, %ctaid.x;
|
||||||
|
mov.u32 %r4, %ntid.x;
|
||||||
|
mov.u32 %r5, %tid.x;
|
||||||
|
mad.lo.s32 %r1, %r3, %r4, %r5;
|
||||||
|
setp.ge.u32 %p1, %r1, %r2;
|
||||||
|
@%p1 bra $L__BB1_2;
|
||||||
|
|
||||||
|
cvta.to.global.u64 %rd3, %rd1;
|
||||||
|
mul.wide.u32 %rd4, %r1, 4;
|
||||||
|
add.s64 %rd5, %rd3, %rd4;
|
||||||
|
ld.global.f32 %f2, [%rd5];
|
||||||
|
sub.f32 %f3, %f2, %f1;
|
||||||
|
mov.f32 %f4, 0f3F000000;
|
||||||
|
mov.f32 %f5, 0f3BBB989D;
|
||||||
|
fma.rn.f32 %f6, %f3, %f5, %f4;
|
||||||
|
cvt.sat.f32.f32 %f7, %f6;
|
||||||
|
mov.f32 %f8, 0f4B400001;
|
||||||
|
mov.f32 %f9, 0f437C0000;
|
||||||
|
fma.rm.f32 %f10, %f7, %f9, %f8;
|
||||||
|
add.f32 %f11, %f10, 0fCB40007F;
|
||||||
|
neg.f32 %f12, %f11;
|
||||||
|
mov.f32 %f13, 0f3FB8AA3B;
|
||||||
|
fma.rn.f32 %f14, %f3, %f13, %f12;
|
||||||
|
mov.f32 %f15, 0f32A57060;
|
||||||
|
fma.rn.f32 %f16, %f3, %f15, %f14;
|
||||||
|
mov.b32 %r6, %f10;
|
||||||
|
shl.b32 %r7, %r6, 23;
|
||||||
|
mov.b32 %f17, %r7;
|
||||||
|
ex2.approx.ftz.f32 %f18, %f16;
|
||||||
|
mul.f32 %f19, %f18, %f17;
|
||||||
|
cvta.to.global.u64 %rd6, %rd2;
|
||||||
|
add.s64 %rd7, %rd6, %rd4;
|
||||||
|
st.global.f32 [%rd7], %f19;
|
||||||
|
|
||||||
|
$L__BB1_2:
|
||||||
|
ret;
|
||||||
|
|
||||||
|
}
|
||||||
|
// .globl _Z13block_softmaxPKf
|
||||||
|
.visible .entry _Z13block_softmaxPKf(
|
||||||
|
.param .u64 _Z13block_softmaxPKf_param_0
|
||||||
|
)
|
||||||
|
{
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ret;
|
||||||
|
|
||||||
|
}
|
||||||
|
// .globl _Z10computeSumPKfPfi
|
||||||
|
.visible .entry _Z10computeSumPKfPfi(
|
||||||
|
.param .u64 _Z10computeSumPKfPfi_param_0,
|
||||||
|
.param .u64 _Z10computeSumPKfPfi_param_1,
|
||||||
|
.param .u32 _Z10computeSumPKfPfi_param_2
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .pred %p<6>;
|
||||||
|
.reg .f32 %f<9>;
|
||||||
|
.reg .b32 %r<15>;
|
||||||
|
.reg .b64 %rd<9>;
|
||||||
|
|
||||||
|
|
||||||
|
ld.param.u64 %rd1, [_Z10computeSumPKfPfi_param_0];
|
||||||
|
ld.param.u64 %rd2, [_Z10computeSumPKfPfi_param_1];
|
||||||
|
ld.param.u32 %r9, [_Z10computeSumPKfPfi_param_2];
|
||||||
|
mov.u32 %r1, %ntid.x;
|
||||||
|
mov.u32 %r2, %ctaid.x;
|
||||||
|
mov.u32 %r3, %tid.x;
|
||||||
|
mad.lo.s32 %r4, %r2, %r1, %r3;
|
||||||
|
setp.ge.u32 %p1, %r4, %r9;
|
||||||
|
mov.f32 %f8, 0f00000000;
|
||||||
|
@%p1 bra $L__BB3_2;
|
||||||
|
|
||||||
|
cvta.to.global.u64 %rd3, %rd1;
|
||||||
|
mul.wide.u32 %rd4, %r4, 4;
|
||||||
|
add.s64 %rd5, %rd3, %rd4;
|
||||||
|
ld.global.f32 %f8, [%rd5];
|
||||||
|
|
||||||
|
$L__BB3_2:
|
||||||
|
shl.b32 %r10, %r3, 2;
|
||||||
|
mov.u32 %r11, sharedSum;
|
||||||
|
add.s32 %r5, %r11, %r10;
|
||||||
|
st.shared.f32 [%r5], %f8;
|
||||||
|
bar.sync 0;
|
||||||
|
shr.u32 %r14, %r1, 1;
|
||||||
|
setp.eq.s32 %p2, %r14, 0;
|
||||||
|
@%p2 bra $L__BB3_6;
|
||||||
|
|
||||||
|
$L__BB3_3:
|
||||||
|
setp.ge.u32 %p3, %r3, %r14;
|
||||||
|
@%p3 bra $L__BB3_5;
|
||||||
|
|
||||||
|
shl.b32 %r12, %r14, 2;
|
||||||
|
add.s32 %r13, %r5, %r12;
|
||||||
|
ld.shared.f32 %f4, [%r5];
|
||||||
|
ld.shared.f32 %f5, [%r13];
|
||||||
|
add.f32 %f6, %f5, %f4;
|
||||||
|
st.shared.f32 [%r5], %f6;
|
||||||
|
|
||||||
|
$L__BB3_5:
|
||||||
|
bar.sync 0;
|
||||||
|
shr.u32 %r14, %r14, 1;
|
||||||
|
setp.ne.s32 %p4, %r14, 0;
|
||||||
|
@%p4 bra $L__BB3_3;
|
||||||
|
|
||||||
|
$L__BB3_6:
|
||||||
|
setp.ne.s32 %p5, %r3, 0;
|
||||||
|
@%p5 bra $L__BB3_8;
|
||||||
|
|
||||||
|
ld.shared.f32 %f7, [sharedSum];
|
||||||
|
cvta.to.global.u64 %rd6, %rd2;
|
||||||
|
mul.wide.u32 %rd7, %r2, 4;
|
||||||
|
add.s64 %rd8, %rd6, %rd7;
|
||||||
|
st.global.f32 [%rd8], %f7;
|
||||||
|
|
||||||
|
$L__BB3_8:
|
||||||
|
ret;
|
||||||
|
|
||||||
|
}
|
||||||
|
// .globl _Z14computeSoftmaxPffi
|
||||||
|
.visible .entry _Z14computeSoftmaxPffi(
|
||||||
|
.param .u64 _Z14computeSoftmaxPffi_param_0,
|
||||||
|
.param .f32 _Z14computeSoftmaxPffi_param_1,
|
||||||
|
.param .u32 _Z14computeSoftmaxPffi_param_2
|
||||||
|
)
|
||||||
|
{
|
||||||
|
.reg .pred %p<2>;
|
||||||
|
.reg .f32 %f<4>;
|
||||||
|
.reg .b32 %r<6>;
|
||||||
|
.reg .b64 %rd<5>;
|
||||||
|
|
||||||
|
|
||||||
|
ld.param.u64 %rd1, [_Z14computeSoftmaxPffi_param_0];
|
||||||
|
ld.param.f32 %f1, [_Z14computeSoftmaxPffi_param_1];
|
||||||
|
ld.param.u32 %r2, [_Z14computeSoftmaxPffi_param_2];
|
||||||
|
mov.u32 %r3, %ctaid.x;
|
||||||
|
mov.u32 %r4, %ntid.x;
|
||||||
|
mov.u32 %r5, %tid.x;
|
||||||
|
mad.lo.s32 %r1, %r3, %r4, %r5;
|
||||||
|
setp.ge.u32 %p1, %r1, %r2;
|
||||||
|
@%p1 bra $L__BB4_2;
|
||||||
|
|
||||||
|
cvta.to.global.u64 %rd2, %rd1;
|
||||||
|
mul.wide.u32 %rd3, %r1, 4;
|
||||||
|
add.s64 %rd4, %rd2, %rd3;
|
||||||
|
ld.global.f32 %f2, [%rd4];
|
||||||
|
div.rn.f32 %f3, %f2, %f1;
|
||||||
|
st.global.f32 [%rd4], %f3;
|
||||||
|
|
||||||
|
$L__BB4_2:
|
||||||
|
ret;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
0
fi/models/__init__.py
Normal file
0
fi/models/__init__.py
Normal file
0
fi/models/deepseek_v3.py
Normal file
0
fi/models/deepseek_v3.py
Normal file
153
fi/models/llama.py
Normal file
153
fi/models/llama.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from flash_attn_2_cuda import varlen_fwd
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlamaConfig,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.up = nn.Linear(self.hidden_size, self.intermediate_size)
|
||||||
|
self.gate = nn.Linear(self.hidden_size, self.intermediate_size)
|
||||||
|
self.down = nn.Linear(self.intermediate_size, self.hidden_size)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor):
|
||||||
|
hidden_states = self.down(self.gate(hidden_states) * self.up(hidden_states))
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class LLAMAAttention(nn.Module):
|
||||||
|
def __init__(self, config: LlamaConfig, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.q = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.k = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.v = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.o = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.num_head = config.num_attention_heads
|
||||||
|
self.kv_head = config.num_key_value_heads
|
||||||
|
self.head_dim = config.head_dim
|
||||||
|
assert (
|
||||||
|
self.num_head * self.head_dim == config.hidden_size
|
||||||
|
), "make sure the num_head*head_dim == hidden_size"
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor):
|
||||||
|
q = self.q(hidden_states)
|
||||||
|
k = self.k(hidden_states)
|
||||||
|
v = self.v(hidden_states)
|
||||||
|
q = q.view(-1, self.num_head, self.head_dim)
|
||||||
|
k = k.view(-1, self.kv_head, self.head_dim)
|
||||||
|
v = v.view(-1, self.kv_head, self.head_dim)
|
||||||
|
# process positionids
|
||||||
|
# varlen_fwd # this function is the core of the process.
|
||||||
|
hidden_states = varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
None, # output
|
||||||
|
cu_seqlens_q, # seqlen_q
|
||||||
|
cu_seqlens_k, # seqlen_k, v
|
||||||
|
seqused_k, # seqused_k
|
||||||
|
None, # block table
|
||||||
|
None, # alibi slopse
|
||||||
|
max_seqlen_q, # int
|
||||||
|
max_seqlen_k, # int
|
||||||
|
0.0, # dropout
|
||||||
|
0.0, # softmax_scale
|
||||||
|
False, # zero_tensors
|
||||||
|
True, # is causal
|
||||||
|
-1, # window size left
|
||||||
|
-1, # window size right
|
||||||
|
False, # return softmax
|
||||||
|
None, # gen
|
||||||
|
)
|
||||||
|
# mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||||
|
# const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||||
|
# const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||||
|
# c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||||
|
# const at::Tensor &cu_seqlens_q, // b+1
|
||||||
|
# const at::Tensor &cu_seqlens_k, // b+1
|
||||||
|
# c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
||||||
|
# c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
||||||
|
# c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||||
|
# int max_seqlen_q,
|
||||||
|
# const int max_seqlen_k,
|
||||||
|
# const float p_dropout,
|
||||||
|
# const float softmax_scale,
|
||||||
|
# const bool zero_tensors,
|
||||||
|
# bool is_causal,
|
||||||
|
# int window_size_left,
|
||||||
|
# int window_size_right,
|
||||||
|
# const bool return_softmax,
|
||||||
|
# c10::optional<at::Generator> gen_) {
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class LLAMADecodeLayer(nn.Module):
|
||||||
|
def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LLAMAModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaConfig, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_layer = config.num_hidden_layers
|
||||||
|
|
||||||
|
self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)]
|
||||||
|
)
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
def forward(self, input_ids, hidden_states, position_ids):
|
||||||
|
if input_ids is not None:
|
||||||
|
hidden_states = self.token_embed(input_ids)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, position_ids)
|
||||||
|
output = self.lm_head(hidden_states)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def unpad_input(self, hidden_states, attention_mask):
|
||||||
|
hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
|
||||||
|
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1) # some time is eq(1)
|
||||||
|
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
|
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states[indices].unsqueeze(0)
|
||||||
|
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
|
||||||
|
|
||||||
|
def pad_input(self, hidden_states, indices, batch, seqlen):
|
||||||
|
"""
|
||||||
|
:param hidden_states: Shape is [L,H] not [B,L,H]
|
||||||
|
:param indices: from unpad_input return indices
|
||||||
|
:param batch:
|
||||||
|
:param seqlen: from unpad_input return max_seqlen_in_batch
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
output = torch.zeros(
|
||||||
|
batch * seqlen,
|
||||||
|
*hidden_states.shape[1:],
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
output[indices] = hidden_states
|
||||||
|
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
||||||
112
test_fa.py
Normal file
112
test_fa.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["TORCH_USE_CUDA_DSA"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_data():
|
||||||
|
nhead = 32
|
||||||
|
head_dim = 128
|
||||||
|
|
||||||
|
batch_size = 3
|
||||||
|
k_past_len = [16, 32, 64]
|
||||||
|
|
||||||
|
k_data = []
|
||||||
|
v_data = []
|
||||||
|
q_data = []
|
||||||
|
for i in k_past_len:
|
||||||
|
k_data.append(
|
||||||
|
torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda")
|
||||||
|
)
|
||||||
|
v_data.append(
|
||||||
|
torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda")
|
||||||
|
)
|
||||||
|
q_data.append(
|
||||||
|
torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda")
|
||||||
|
)
|
||||||
|
|
||||||
|
k_data = torch.concat(k_data, dim=0)
|
||||||
|
v_data = torch.concat(v_data, dim=0)
|
||||||
|
q_data = torch.concat(q_data, dim=0)
|
||||||
|
torch.save(k_data, "k.pkl")
|
||||||
|
torch.save(v_data, "v.pkl")
|
||||||
|
torch.save(q_data, "q.pkl")
|
||||||
|
|
||||||
|
|
||||||
|
def test_fa():
|
||||||
|
nhead = 32
|
||||||
|
head_dim = 128
|
||||||
|
|
||||||
|
batch_size = 3
|
||||||
|
k_past_len = [16, 32, 64]
|
||||||
|
|
||||||
|
k_data = []
|
||||||
|
v_data = []
|
||||||
|
q_data = []
|
||||||
|
k_data = torch.load("./k.pkl")
|
||||||
|
q_data = torch.load("./q.pkl")
|
||||||
|
v_data = torch.load("./v.pkl")
|
||||||
|
print(k_data.size(), v_data.size(), q_data.size())
|
||||||
|
k_past_len = torch.tensor(k_past_len, dtype=torch.int32, device="cuda")
|
||||||
|
res = torch.cumsum(k_past_len, dim=-1)
|
||||||
|
k_past_len = torch.cat(
|
||||||
|
(res, torch.tensor([k_data.size(0)], device=res.device, dtype=torch.int32))
|
||||||
|
)
|
||||||
|
k_past_len = k_past_len.to(torch.int32)
|
||||||
|
|
||||||
|
# cu_k_len = torch.cumsum(torch.tensor(k_past_len))
|
||||||
|
# print(cu_k_len, cu_k_len.size())
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
fa_res = flash_attn_varlen_func(
|
||||||
|
q_data,
|
||||||
|
k_data,
|
||||||
|
v_data,
|
||||||
|
cu_seqlens_k=k_past_len,
|
||||||
|
cu_seqlens_q=k_past_len,
|
||||||
|
max_seqlen_k=3,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
torch.save(fa_res, "./fa.pkl")
|
||||||
|
|
||||||
|
|
||||||
|
def test_torch_sdqa():
|
||||||
|
nhead = 32
|
||||||
|
head_dim = 128
|
||||||
|
|
||||||
|
batch_size = 3
|
||||||
|
k_past_len = [16, 32, 64]
|
||||||
|
|
||||||
|
k_data = []
|
||||||
|
v_data = []
|
||||||
|
q_data = []
|
||||||
|
q_data = torch.load("./q.pkl")
|
||||||
|
k_data = torch.load("./k.pkl")
|
||||||
|
v_data = torch.load("./v.pkl")
|
||||||
|
q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
|
||||||
|
k = k_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
|
||||||
|
v = v_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
naive_res = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||||
|
print(naive_res.size())
|
||||||
|
naive_res = naive_res.view(1, nhead, head_dim)
|
||||||
|
fa_res = torch.load("./fa.pkl")
|
||||||
|
if torch.allclose(naive_res, fa_res[0, :, :]):
|
||||||
|
print("this is dam right")
|
||||||
|
else:
|
||||||
|
print("max diff is ", torch.abs(naive_res - fa_res[0, :, :]).max())
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
# do not know if such method is right.
|
||||||
|
# chunk prefill did use such method?
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
generate_data()
|
||||||
|
test_fa()
|
||||||
|
test_torch_sdqa()
|
||||||
0
training/__init__.py
Normal file
0
training/__init__.py
Normal file
24
training/dp.py
Normal file
24
training/dp.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
|
||||||
|
class DataParallel(object):
|
||||||
|
def __init__(self, dp_num, world_size: int, rank: int, module=None):
|
||||||
|
self.world_size = world_size
|
||||||
|
self.rank = rank
|
||||||
|
self.dp_num = dp_num
|
||||||
|
|
||||||
|
def forward_pipeline(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def backward_pipeline(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync_grad(self, param):
|
||||||
|
pass
|
||||||
9
training/pp.py
Normal file
9
training/pp.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineParallel(object):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
Loading…
Reference in New Issue
Block a user