From a43baa8b7f665f4a744076f20bdb07bf73b01c9c Mon Sep 17 00:00:00 2001
From: longfei li
Date: Mon, 18 Nov 2024 19:54:12 +0800
Subject: [PATCH] test multi dimension matrix multiply

---
 .vscode/settings.json |   3 +++
 csrc/core.h           |   1 +
 csrc/core_bind.cpp    |   1 +
 csrc/max.cu           |  19 ++++++++++++++-----
 csrc/md.cu            |  29 +++++++++++++++++++++++++++++
 fi/load_model.py      |  40 +++++++++++++++++++++++++++++++++++-----
 setup.py              |   1 +
 test                  | Bin 0 -> 16480 bytes
 test.cc               |  22 ++++++++++++++++++++++
 9 files changed, 106 insertions(+), 10 deletions(-)
 create mode 100644 .vscode/settings.json
 create mode 100644 csrc/md.cu
 create mode 100755 test
 create mode 100644 test.cc

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..23830fb
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "git.ignoreLimitWarning": true
+}
diff --git a/csrc/core.h b/csrc/core.h
index b523ebe..4c0b4f9 100644
--- a/csrc/core.h
+++ b/csrc/core.h
@@ -16,4 +16,5 @@ void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
 void print_idx();
 void reducemax(const torch::Tensor &src, torch::Tensor &dest);
 void test_cute_tensor();
+void md_mm(const torch::Tensor &src);
 #endif
\ No newline at end of file
diff --git a/csrc/core_bind.cpp b/csrc/core_bind.cpp
index ee7f396..1672e6a 100644
--- a/csrc/core_bind.cpp
+++ b/csrc/core_bind.cpp
@@ -15,4 +15,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("print_idx", &print_idx, "just_printidx");
     m.def("reducemax", &reducemax, "reduce max");
     m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor");
+    m.def("md_mm", &md_mm, "just a test of multi dimension mm");
 }
diff --git a/csrc/max.cu b/csrc/max.cu
index 6788b70..344b62a 100644
--- a/csrc/max.cu
+++ b/csrc/max.cu
@@ -81,17 +81,17 @@ __global__ void test_cute_tensor_kernel()
                      Stride<_32, _2>{});
     Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
     __shared__ float smem[decltype(cosize(smem_layout))::value]; // (static-only allocation)
-    printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
+    // printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
     Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
     printf("tensor size is: %d, ind size is: %d, rmem size is: %d , rmem4x8 is: %d, smem size is: %d\n",
           bool_tensor.size(),
          ind_tensor.size(),
          rmem_4x8_col.size(),
          rmem_4x8_pad.size(),
          stensor.size());
-    auto TA = make_layout(make_shape(Int<32>{}, Int<8>{}), LayoutRight{});          // (m,k) -> thr_idx; k-major
-    TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},    // Atom: Copy TAs as if they were uint128_t
-                                      Layout<Shape<_32, _8>>{},                     // Thr layout 32x8 m-major
-                                      Layout<Shape<_4, _1>>{});                     // Val layout 4x1 m-major
+    auto TA = make_layout(make_shape(Int<32>{}, Int<8>{}), LayoutRight{});          // (m,k) -> thr_idx; k-major
+    TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: Copy TAs as if they were uint128_t
+                                      Layout<Shape<_32, _8>>{},                     // Thr layout 32x8 m-major
+                                      Layout<Shape<_4, _1>>{});                     // Val layout 4x1 m-major
     print_latex(copyA);
 }
@@ -108,4 +108,13 @@ void test_cute_tensor()
     dim3 thread_block(16, 16);
     dim3 block(16);
     test_cute_tensor_kernel<<<block, thread_block>>>();
+}
+
+__global__ void md_op(const float *a)
+{
+    int tidx = threadIdx.x;
+    int bid = blockIdx.x;
+    int hid = blockIdx.y;
+    int offset = blockDim.x * blockDim.y;
+    // bind to its own
 }
\ No newline at end of file
diff --git a/csrc/md.cu b/csrc/md.cu
new file mode 100644
index 0000000..7f62793
--- /dev/null
+++ b/csrc/md.cu
@@ -0,0 +1,29 @@
+#include "core.h"
+
+#include
+#include
+#include
+#include
+
+__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
+{
+    int batch_idx = blockIdx.x;
+    int head_idx = blockIdx.y;
+    int sequence_idx = blockIdx.z;
+    int tidx = threadIdx.x;
+}
+
+void md_mm(const torch::Tensor &src)
+{
+    int batch_size = src.size(0);
+    int head_size = src.size(1);
+    int sequence_size = src.size(2);
+    int head_dim = src.size(3);
+    int data_block = sequence_size * head_dim;
+    int thread_num = 256;
+    dim3 grid(batch_size, head_size, (data_block + thread_num - 1) / thread_num);
+    dim3 block(thread_num);
+    md_mm_kernel<<<grid, block>>>(reinterpret_cast<const float *>(src.data_ptr()),
+                                  src.stride(0), src.stride(1), src.stride(2),
+                                  thread_num);
+}
diff --git a/fi/load_model.py b/fi/load_model.py
index 32679b1..1f9a5a4 100644
--- a/fi/load_model.py
+++ b/fi/load_model.py
@@ -1,14 +1,44 @@
 # coding=utf-8
+import os
 import torch
 import transformers
 import torch.nn as nn
-from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from transformers.models.llama import LlamaConfig, LlamaForCausalLM
-from transformers import AutoModel, AutoConfig
+from transformers import AutoModelForCausalLM, AutoConfig
+from transformers.models import qwen2, gemma2, llama, gemma
+
+decode_layers = {
+    "gemma": gemma.modeling_gemma.GemmaDecoderLayer,
+    "gemma2": gemma2.modeling_gemma2.Gemma2DecoderLayer,
+    "qwen2": qwen2.modeling_qwen2.Qwen2DecoderLayer,
+}
+
+MODELS = {
+    "gemma": gemma.GemmaForCausalLM,
+    "gemma2": gemma2.Gemma2ForCausalLM,
+    "llama": llama.LlamaForCausalLM,
+    "qwen2": qwen2.Qwen2ForCausalLM,
+}
 
 
 class ModelLoader:
-    def __init__(self, model_path: str):
-        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+    def __init__(self, model_path: str, pipeline_num: int = 1):
+        self.config_path = os.path.join(model_path, "config.json")
+        self.model_config = AutoConfig.from_pretrained(self.config_path)
+        hidden_layers = getattr(self.model_config, "num_hidden_layers", -1)
+        if hidden_layers == -1:
+            raise ValueError("model config has no num_hidden_layers attribute")
+        self.hidden_layers = hidden_layers
+        self.pipeline_num = pipeline_num
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path, trust_remote_code=True
+        )
+        self.model_type = self.model_config.model_type
+        self.per_pipeline_layers = self.hidden_layers // self.pipeline_num
+        module_list = None
+        for x in self.model.modules():
+            if isinstance(x, torch.nn.modules.container.ModuleList):
+                module_list = x
+        if module_list is None:
+            raise ValueError("model does not contain a ModuleList.")
diff --git a/setup.py b/setup.py
index fc3902f..3b9bf4d 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,7 @@ files = [
     "csrc/matrix.cu",
     "csrc/core_bind.cpp",
     "csrc/max.cu",
+    "csrc/md.cu",
 ]
 extension = CUDAExtension(
     name="torch_cuda_ext.core",
diff --git a/test b/test
new file mode 100755
index 0000000000000000000000000000000000000000..c1aeb719b1b4222107590498de4b52f086228556
GIT binary patch
literal 16480