Compare commits
2 Commits
a1aa7fd0d6
...
e33d87b0aa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e33d87b0aa | ||
|
|
89e3b9d190 |
257
csrc/softmax.ptx
Normal file
257
csrc/softmax.ptx
Normal file
@ -0,0 +1,257 @@
|
||||
//
|
||||
// Generated by NVIDIA NVVM Compiler
|
||||
//
|
||||
// Compiler Build ID: CL-34097967
|
||||
// Cuda compilation tools, release 12.4, V12.4.131
|
||||
// Based on NVVM 7.0.1
|
||||
//
|
||||
|
||||
.version 8.4
|
||||
.target sm_52
|
||||
.address_size 64
|
||||
|
||||
// .globl _Z7findMaxPKfPfi
|
||||
.extern .shared .align 16 .b8 sharedMax[];
|
||||
.extern .shared .align 16 .b8 sharedSum[];
|
||||
|
||||
.visible .entry _Z7findMaxPKfPfi(
|
||||
.param .u64 _Z7findMaxPKfPfi_param_0,
|
||||
.param .u64 _Z7findMaxPKfPfi_param_1,
|
||||
.param .u32 _Z7findMaxPKfPfi_param_2
|
||||
)
|
||||
{
|
||||
.reg .pred %p<6>;
|
||||
.reg .f32 %f<9>;
|
||||
.reg .b32 %r<15>;
|
||||
.reg .b64 %rd<9>;
|
||||
|
||||
|
||||
ld.param.u64 %rd1, [_Z7findMaxPKfPfi_param_0];
|
||||
ld.param.u64 %rd2, [_Z7findMaxPKfPfi_param_1];
|
||||
ld.param.u32 %r9, [_Z7findMaxPKfPfi_param_2];
|
||||
mov.u32 %r1, %ntid.x;
|
||||
mov.u32 %r2, %ctaid.x;
|
||||
mov.u32 %r3, %tid.x;
|
||||
mad.lo.s32 %r4, %r2, %r1, %r3;
|
||||
setp.ge.u32 %p1, %r4, %r9;
|
||||
mov.f32 %f8, 0fFF800000;
|
||||
@%p1 bra $L__BB0_2;
|
||||
|
||||
cvta.to.global.u64 %rd3, %rd1;
|
||||
mul.wide.u32 %rd4, %r4, 4;
|
||||
add.s64 %rd5, %rd3, %rd4;
|
||||
ld.global.f32 %f8, [%rd5];
|
||||
|
||||
$L__BB0_2:
|
||||
shl.b32 %r10, %r3, 2;
|
||||
mov.u32 %r11, sharedMax;
|
||||
add.s32 %r5, %r11, %r10;
|
||||
st.shared.f32 [%r5], %f8;
|
||||
bar.sync 0;
|
||||
shr.u32 %r14, %r1, 1;
|
||||
setp.eq.s32 %p2, %r14, 0;
|
||||
@%p2 bra $L__BB0_6;
|
||||
|
||||
$L__BB0_3:
|
||||
setp.ge.u32 %p3, %r3, %r14;
|
||||
@%p3 bra $L__BB0_5;
|
||||
|
||||
ld.shared.f32 %f4, [%r5];
|
||||
shl.b32 %r12, %r14, 2;
|
||||
add.s32 %r13, %r5, %r12;
|
||||
ld.shared.f32 %f5, [%r13];
|
||||
max.f32 %f6, %f4, %f5;
|
||||
st.shared.f32 [%r5], %f6;
|
||||
|
||||
$L__BB0_5:
|
||||
bar.sync 0;
|
||||
shr.u32 %r14, %r14, 1;
|
||||
setp.ne.s32 %p4, %r14, 0;
|
||||
@%p4 bra $L__BB0_3;
|
||||
|
||||
$L__BB0_6:
|
||||
setp.ne.s32 %p5, %r3, 0;
|
||||
@%p5 bra $L__BB0_8;
|
||||
|
||||
ld.shared.f32 %f7, [sharedMax];
|
||||
cvta.to.global.u64 %rd6, %rd2;
|
||||
mul.wide.u32 %rd7, %r2, 4;
|
||||
add.s64 %rd8, %rd6, %rd7;
|
||||
st.global.f32 [%rd8], %f7;
|
||||
|
||||
$L__BB0_8:
|
||||
ret;
|
||||
|
||||
}
|
||||
// .globl _Z10computeExpPKffPfi
|
||||
.visible .entry _Z10computeExpPKffPfi(
|
||||
.param .u64 _Z10computeExpPKffPfi_param_0,
|
||||
.param .f32 _Z10computeExpPKffPfi_param_1,
|
||||
.param .u64 _Z10computeExpPKffPfi_param_2,
|
||||
.param .u32 _Z10computeExpPKffPfi_param_3
|
||||
)
|
||||
{
|
||||
.reg .pred %p<2>;
|
||||
.reg .f32 %f<20>;
|
||||
.reg .b32 %r<8>;
|
||||
.reg .b64 %rd<8>;
|
||||
|
||||
|
||||
ld.param.u64 %rd1, [_Z10computeExpPKffPfi_param_0];
|
||||
ld.param.f32 %f1, [_Z10computeExpPKffPfi_param_1];
|
||||
ld.param.u64 %rd2, [_Z10computeExpPKffPfi_param_2];
|
||||
ld.param.u32 %r2, [_Z10computeExpPKffPfi_param_3];
|
||||
mov.u32 %r3, %ctaid.x;
|
||||
mov.u32 %r4, %ntid.x;
|
||||
mov.u32 %r5, %tid.x;
|
||||
mad.lo.s32 %r1, %r3, %r4, %r5;
|
||||
setp.ge.u32 %p1, %r1, %r2;
|
||||
@%p1 bra $L__BB1_2;
|
||||
|
||||
cvta.to.global.u64 %rd3, %rd1;
|
||||
mul.wide.u32 %rd4, %r1, 4;
|
||||
add.s64 %rd5, %rd3, %rd4;
|
||||
ld.global.f32 %f2, [%rd5];
|
||||
sub.f32 %f3, %f2, %f1;
|
||||
mov.f32 %f4, 0f3F000000;
|
||||
mov.f32 %f5, 0f3BBB989D;
|
||||
fma.rn.f32 %f6, %f3, %f5, %f4;
|
||||
cvt.sat.f32.f32 %f7, %f6;
|
||||
mov.f32 %f8, 0f4B400001;
|
||||
mov.f32 %f9, 0f437C0000;
|
||||
fma.rm.f32 %f10, %f7, %f9, %f8;
|
||||
add.f32 %f11, %f10, 0fCB40007F;
|
||||
neg.f32 %f12, %f11;
|
||||
mov.f32 %f13, 0f3FB8AA3B;
|
||||
fma.rn.f32 %f14, %f3, %f13, %f12;
|
||||
mov.f32 %f15, 0f32A57060;
|
||||
fma.rn.f32 %f16, %f3, %f15, %f14;
|
||||
mov.b32 %r6, %f10;
|
||||
shl.b32 %r7, %r6, 23;
|
||||
mov.b32 %f17, %r7;
|
||||
ex2.approx.ftz.f32 %f18, %f16;
|
||||
mul.f32 %f19, %f18, %f17;
|
||||
cvta.to.global.u64 %rd6, %rd2;
|
||||
add.s64 %rd7, %rd6, %rd4;
|
||||
st.global.f32 [%rd7], %f19;
|
||||
|
||||
$L__BB1_2:
|
||||
ret;
|
||||
|
||||
}
|
||||
// .globl _Z13block_softmaxPKf
|
||||
.visible .entry _Z13block_softmaxPKf(
|
||||
.param .u64 _Z13block_softmaxPKf_param_0
|
||||
)
|
||||
{
|
||||
|
||||
|
||||
|
||||
ret;
|
||||
|
||||
}
|
||||
// .globl _Z10computeSumPKfPfi
|
||||
.visible .entry _Z10computeSumPKfPfi(
|
||||
.param .u64 _Z10computeSumPKfPfi_param_0,
|
||||
.param .u64 _Z10computeSumPKfPfi_param_1,
|
||||
.param .u32 _Z10computeSumPKfPfi_param_2
|
||||
)
|
||||
{
|
||||
.reg .pred %p<6>;
|
||||
.reg .f32 %f<9>;
|
||||
.reg .b32 %r<15>;
|
||||
.reg .b64 %rd<9>;
|
||||
|
||||
|
||||
ld.param.u64 %rd1, [_Z10computeSumPKfPfi_param_0];
|
||||
ld.param.u64 %rd2, [_Z10computeSumPKfPfi_param_1];
|
||||
ld.param.u32 %r9, [_Z10computeSumPKfPfi_param_2];
|
||||
mov.u32 %r1, %ntid.x;
|
||||
mov.u32 %r2, %ctaid.x;
|
||||
mov.u32 %r3, %tid.x;
|
||||
mad.lo.s32 %r4, %r2, %r1, %r3;
|
||||
setp.ge.u32 %p1, %r4, %r9;
|
||||
mov.f32 %f8, 0f00000000;
|
||||
@%p1 bra $L__BB3_2;
|
||||
|
||||
cvta.to.global.u64 %rd3, %rd1;
|
||||
mul.wide.u32 %rd4, %r4, 4;
|
||||
add.s64 %rd5, %rd3, %rd4;
|
||||
ld.global.f32 %f8, [%rd5];
|
||||
|
||||
$L__BB3_2:
|
||||
shl.b32 %r10, %r3, 2;
|
||||
mov.u32 %r11, sharedSum;
|
||||
add.s32 %r5, %r11, %r10;
|
||||
st.shared.f32 [%r5], %f8;
|
||||
bar.sync 0;
|
||||
shr.u32 %r14, %r1, 1;
|
||||
setp.eq.s32 %p2, %r14, 0;
|
||||
@%p2 bra $L__BB3_6;
|
||||
|
||||
$L__BB3_3:
|
||||
setp.ge.u32 %p3, %r3, %r14;
|
||||
@%p3 bra $L__BB3_5;
|
||||
|
||||
shl.b32 %r12, %r14, 2;
|
||||
add.s32 %r13, %r5, %r12;
|
||||
ld.shared.f32 %f4, [%r5];
|
||||
ld.shared.f32 %f5, [%r13];
|
||||
add.f32 %f6, %f5, %f4;
|
||||
st.shared.f32 [%r5], %f6;
|
||||
|
||||
$L__BB3_5:
|
||||
bar.sync 0;
|
||||
shr.u32 %r14, %r14, 1;
|
||||
setp.ne.s32 %p4, %r14, 0;
|
||||
@%p4 bra $L__BB3_3;
|
||||
|
||||
$L__BB3_6:
|
||||
setp.ne.s32 %p5, %r3, 0;
|
||||
@%p5 bra $L__BB3_8;
|
||||
|
||||
ld.shared.f32 %f7, [sharedSum];
|
||||
cvta.to.global.u64 %rd6, %rd2;
|
||||
mul.wide.u32 %rd7, %r2, 4;
|
||||
add.s64 %rd8, %rd6, %rd7;
|
||||
st.global.f32 [%rd8], %f7;
|
||||
|
||||
$L__BB3_8:
|
||||
ret;
|
||||
|
||||
}
|
||||
// .globl _Z14computeSoftmaxPffi
|
||||
.visible .entry _Z14computeSoftmaxPffi(
|
||||
.param .u64 _Z14computeSoftmaxPffi_param_0,
|
||||
.param .f32 _Z14computeSoftmaxPffi_param_1,
|
||||
.param .u32 _Z14computeSoftmaxPffi_param_2
|
||||
)
|
||||
{
|
||||
.reg .pred %p<2>;
|
||||
.reg .f32 %f<4>;
|
||||
.reg .b32 %r<6>;
|
||||
.reg .b64 %rd<5>;
|
||||
|
||||
|
||||
ld.param.u64 %rd1, [_Z14computeSoftmaxPffi_param_0];
|
||||
ld.param.f32 %f1, [_Z14computeSoftmaxPffi_param_1];
|
||||
ld.param.u32 %r2, [_Z14computeSoftmaxPffi_param_2];
|
||||
mov.u32 %r3, %ctaid.x;
|
||||
mov.u32 %r4, %ntid.x;
|
||||
mov.u32 %r5, %tid.x;
|
||||
mad.lo.s32 %r1, %r3, %r4, %r5;
|
||||
setp.ge.u32 %p1, %r1, %r2;
|
||||
@%p1 bra $L__BB4_2;
|
||||
|
||||
cvta.to.global.u64 %rd2, %rd1;
|
||||
mul.wide.u32 %rd3, %r1, 4;
|
||||
add.s64 %rd4, %rd2, %rd3;
|
||||
ld.global.f32 %f2, [%rd4];
|
||||
div.rn.f32 %f3, %f2, %f1;
|
||||
st.global.f32 [%rd4], %f3;
|
||||
|
||||
$L__BB4_2:
|
||||
ret;
|
||||
|
||||
}
|
||||
|
||||
0
fi/models/__init__.py
Normal file
0
fi/models/__init__.py
Normal file
0
fi/models/deepseek_v3.py
Normal file
0
fi/models/deepseek_v3.py
Normal file
153
fi/models/llama.py
Normal file
153
fi/models/llama.py
Normal file
@ -0,0 +1,153 @@
|
||||
# coding=utf-8
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from dataclasses import dataclass
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
|
||||
from flash_attn_2_cuda import varlen_fwd
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.up = nn.Linear(self.hidden_size, self.intermediate_size)
|
||||
self.gate = nn.Linear(self.hidden_size, self.intermediate_size)
|
||||
self.down = nn.Linear(self.intermediate_size, self.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
hidden_states = self.down(self.gate(hidden_states) * self.up(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LLAMAAttention(nn.Module):
|
||||
def __init__(self, config: LlamaConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.q = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.k = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.v = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.o = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.num_head = config.num_attention_heads
|
||||
self.kv_head = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
assert (
|
||||
self.num_head * self.head_dim == config.hidden_size
|
||||
), "make sure the num_head*head_dim == hidden_size"
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, position_ids: torch.Tensor):
|
||||
q = self.q(hidden_states)
|
||||
k = self.k(hidden_states)
|
||||
v = self.v(hidden_states)
|
||||
q = q.view(-1, self.num_head, self.head_dim)
|
||||
k = k.view(-1, self.kv_head, self.head_dim)
|
||||
v = v.view(-1, self.kv_head, self.head_dim)
|
||||
# process positionids
|
||||
# varlen_fwd # this function is the core of the process.
|
||||
hidden_states = varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None, # output
|
||||
cu_seqlens_q, # seqlen_q
|
||||
cu_seqlens_k, # seqlen_k, v
|
||||
seqused_k, # seqused_k
|
||||
None, # block table
|
||||
None, # alibi slopse
|
||||
max_seqlen_q, # int
|
||||
max_seqlen_k, # int
|
||||
0.0, # dropout
|
||||
0.0, # softmax_scale
|
||||
False, # zero_tensors
|
||||
True, # is causal
|
||||
-1, # window size left
|
||||
-1, # window size right
|
||||
False, # return softmax
|
||||
None, # gen
|
||||
)
|
||||
# mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
# const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
# const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
# c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
# const at::Tensor &cu_seqlens_q, // b+1
|
||||
# const at::Tensor &cu_seqlens_k, // b+1
|
||||
# c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
||||
# c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
||||
# c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||
# int max_seqlen_q,
|
||||
# const int max_seqlen_k,
|
||||
# const float p_dropout,
|
||||
# const float softmax_scale,
|
||||
# const bool zero_tensors,
|
||||
# bool is_causal,
|
||||
# int window_size_left,
|
||||
# int window_size_right,
|
||||
# const bool return_softmax,
|
||||
# c10::optional<at::Generator> gen_) {
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LLAMADecodeLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig, idx: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class LLAMAModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_layer = config.num_hidden_layers
|
||||
|
||||
self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList(
|
||||
[LLAMADecodeLayer(config=config, idx=i) for i in range(self.num_layer)]
|
||||
)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
|
||||
def forward(self, input_ids, hidden_states, position_ids):
|
||||
if input_ids is not None:
|
||||
hidden_states = self.token_embed(input_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, position_ids)
|
||||
output = self.lm_head(hidden_states)
|
||||
return output
|
||||
|
||||
def unpad_input(self, hidden_states, attention_mask):
|
||||
hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
|
||||
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1) # some time is eq(1)
|
||||
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
||||
)
|
||||
hidden_states = hidden_states[indices].unsqueeze(0)
|
||||
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
def pad_input(self, hidden_states, indices, batch, seqlen):
|
||||
"""
|
||||
:param hidden_states: Shape is [L,H] not [B,L,H]
|
||||
:param indices: from unpad_input return indices
|
||||
:param batch:
|
||||
:param seqlen: from unpad_input return max_seqlen_in_batch
|
||||
:return:
|
||||
"""
|
||||
output = torch.zeros(
|
||||
batch * seqlen,
|
||||
*hidden_states.shape[1:],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
output[indices] = hidden_states
|
||||
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
||||
112
test_fa.py
Normal file
112
test_fa.py
Normal file
@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from torch.nn import functional as F
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TORCH_USE_CUDA_DSA"] = "1"
|
||||
|
||||
|
||||
def generate_data():
|
||||
nhead = 32
|
||||
head_dim = 128
|
||||
|
||||
batch_size = 3
|
||||
k_past_len = [16, 32, 64]
|
||||
|
||||
k_data = []
|
||||
v_data = []
|
||||
q_data = []
|
||||
for i in k_past_len:
|
||||
k_data.append(
|
||||
torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda")
|
||||
)
|
||||
v_data.append(
|
||||
torch.randn(size=(i, nhead, head_dim), dtype=torch.float16, device="cuda")
|
||||
)
|
||||
q_data.append(
|
||||
torch.randn(size=(1, nhead, head_dim), dtype=torch.float16, device="cuda")
|
||||
)
|
||||
|
||||
k_data = torch.concat(k_data, dim=0)
|
||||
v_data = torch.concat(v_data, dim=0)
|
||||
q_data = torch.concat(q_data, dim=0)
|
||||
torch.save(k_data, "k.pkl")
|
||||
torch.save(v_data, "v.pkl")
|
||||
torch.save(q_data, "q.pkl")
|
||||
|
||||
|
||||
def test_fa():
|
||||
nhead = 32
|
||||
head_dim = 128
|
||||
|
||||
batch_size = 3
|
||||
k_past_len = [16, 32, 64]
|
||||
|
||||
k_data = []
|
||||
v_data = []
|
||||
q_data = []
|
||||
k_data = torch.load("./k.pkl")
|
||||
q_data = torch.load("./q.pkl")
|
||||
v_data = torch.load("./v.pkl")
|
||||
print(k_data.size(), v_data.size(), q_data.size())
|
||||
k_past_len = torch.tensor(k_past_len, dtype=torch.int32, device="cuda")
|
||||
res = torch.cumsum(k_past_len, dim=-1)
|
||||
k_past_len = torch.cat(
|
||||
(res, torch.tensor([k_data.size(0)], device=res.device, dtype=torch.int32))
|
||||
)
|
||||
k_past_len = k_past_len.to(torch.int32)
|
||||
|
||||
# cu_k_len = torch.cumsum(torch.tensor(k_past_len))
|
||||
# print(cu_k_len, cu_k_len.size())
|
||||
torch.cuda.synchronize()
|
||||
fa_res = flash_attn_varlen_func(
|
||||
q_data,
|
||||
k_data,
|
||||
v_data,
|
||||
cu_seqlens_k=k_past_len,
|
||||
cu_seqlens_q=k_past_len,
|
||||
max_seqlen_k=3,
|
||||
max_seqlen_q=1,
|
||||
causal=True,
|
||||
)
|
||||
# torch.cuda.synchronize()
|
||||
torch.save(fa_res, "./fa.pkl")
|
||||
|
||||
|
||||
def test_torch_sdqa():
|
||||
nhead = 32
|
||||
head_dim = 128
|
||||
|
||||
batch_size = 3
|
||||
k_past_len = [16, 32, 64]
|
||||
|
||||
k_data = []
|
||||
v_data = []
|
||||
q_data = []
|
||||
q_data = torch.load("./q.pkl")
|
||||
k_data = torch.load("./k.pkl")
|
||||
v_data = torch.load("./v.pkl")
|
||||
q = q_data[0, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
|
||||
k = k_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
|
||||
v = v_data[:16, :, :].view(1, -1, nhead, head_dim).transpose(1, 2).contiguous()
|
||||
# torch.cuda.synchronize()
|
||||
naive_res = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
print(naive_res.size())
|
||||
naive_res = naive_res.view(1, nhead, head_dim)
|
||||
fa_res = torch.load("./fa.pkl")
|
||||
if torch.allclose(naive_res, fa_res[0, :, :]):
|
||||
print("this is dam right")
|
||||
else:
|
||||
print("max diff is ", torch.abs(naive_res - fa_res[0, :, :]).max())
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
|
||||
# do not know if such method is right.
|
||||
# chunk prefill did use such method?
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_data()
|
||||
test_fa()
|
||||
test_torch_sdqa()
|
||||
0
training/__init__.py
Normal file
0
training/__init__.py
Normal file
24
training/dp.py
Normal file
24
training/dp.py
Normal file
@ -0,0 +1,24 @@
|
||||
# coding=utf-8
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
class DataParallel(object):
|
||||
def __init__(self, dp_num, world_size: int, rank: int, module=None):
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.dp_num = dp_num
|
||||
|
||||
def forward_pipeline(self):
|
||||
pass
|
||||
|
||||
def backward_pipeline(self):
|
||||
pass
|
||||
|
||||
def train(self):
|
||||
pass
|
||||
|
||||
def sync_grad(self, param):
|
||||
pass
|
||||
9
training/pp.py
Normal file
9
training/pp.py
Normal file
@ -0,0 +1,9 @@
|
||||
# coding=utf-8
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class PipelineParallel(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
Loading…
Reference in New Issue
Block a user