test multi dimension matrix multiply

parent 2285b8b6f2
commit a43baa8b7f
.vscode/settings.json (vendored, new file, +3)
@@ -0,0 +1,3 @@
+{
+    "git.ignoreLimitWarning": true
+}
@@ -16,4 +16,5 @@ void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c)
 void print_idx();
 void reducemax(const torch::Tensor &src, torch::Tensor &dest);
 void test_cute_tensor();
+void md_mm(const torch::Tensor &src);
 #endif
@@ -15,4 +15,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("print_idx", &print_idx, "just_printidx");
     m.def("reducemax", &reducemax, "reduce max");
     m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor");
+    m.def("md_mm", &md_mm, "just a test of multi dimension mm");
 }
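Once the extension is rebuilt, this binding should expose the new kernel to Python through the torch_cuda_ext.core module declared in setup.py, taking a single 4-D tensor argument; the exact call site is not part of this commit.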
csrc/max.cu (19 lines changed)
@@ -81,17 +81,17 @@ __global__ void test_cute_tensor_kernel()
                            Stride<_32, _2>{});
     Layout smem_layout = make_layout(make_shape(Int<4>{}, Int<8>{}));
     __shared__ float smem[decltype(cosize(smem_layout))::value]; // (static-only allocation)
-    printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
+    // printf("smem size is :%d\n", decltype(cosize(smem_layout))::value);
     Tensor stensor = make_tensor(make_smem_ptr(smem), smem_layout);
     printf("tensor size is: %d, ind size is: %d, rmem size is: %d , rmem4x8 is: %d, smem size is: %d\n",
            bool_tensor.size(),
            ind_tensor.size(), rmem_4x8_col.size(),
            rmem_4x8_pad.size(),
            stensor.size());
     auto TA = make_layout(make_shape(Int<32>{}, Int<8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major
-    TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},    // Atom: Copy TAs as if they were uint128_t
+    TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{}, // Atom: Copy TAs as if they were uint128_t
                                       Layout<Shape<_32, _8>>{},  // Thr layout 32x8 m-major
                                       Layout<Shape<_4, _1>>{});  // Val layout 4x1 m-major
     print_latex(copyA);
 }
 
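The functional change in this hunk replaces the second template argument of Copy_Atom, which previously received the thread layout TA, with the element type float. In CuTe that parameter names the value type the copy atom operates on, so this reads as a type-argument fix rather than a change to the tiling itself.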
@@ -108,4 +108,13 @@ void test_cute_tensor()
     dim3 thread_block(16, 16);
     dim3 block(16);
     test_cute_tensor_kernel<<<block, thread_block>>>();
+}
+
+__global__ void md_op(const float *a)
+{
+    int tidx = threadIdx.x;
+    int bid = blockIdx.x;
+    int hid = blockIdx.y;
+    int offset = blockDim.x * blockDim.y;
+    // bind to its own ...
 }
csrc/md.cu (new file, 29 lines)
@@ -0,0 +1,29 @@
+#include "core.h"
+
+#include <cute/tensor.hpp>
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+
+__global__ void md_mm_kernel(const float *src, int stride_a, int stride_b, int stride_c, int thread_num)
+{
+    int batch_idx = blockIdx.x;
+    int head_idx = blockIdx.y;
+    int sequence_idx = blockIdx.z;
+    int tidx = threadIdx.x;
+}
+
+void md_mm(const torch::Tensor &src)
+{
+    int batch_size = src.size(0);
+    int head_size = src.size(1);
+    int sequence_size = src.size(2);
+    int head_dim = src.size(3);
+    int data_block = sequence_size * head_dim;
+    int thread_num = 256;
+    dim3 grid(batch_size, head_size, (data_block + thread_num - 1) / thread_num);
+    dim3 block(thread_num);
+    md_mm_kernel<<<grid, block>>>(reinterpret_cast<float *>(src.data_ptr()),
+                                  src.stride(0), src.stride(1), src.stride(2),
+                                  thread_num);
+}
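md_mm_kernel is committed as a stub: it derives batch/head/z-block/thread indices but never touches src. A minimal sketch of how that same grid, grid = (batch, heads, ceil(seq * head_dim / 256)) with 256 threads per block, could address a (batch, head, sequence, head_dim) float tensor is below; the dst pointer, the elems_per_head parameter, and the identity copy are illustrative assumptions, not part of the commit.

    // Sketch only: mirrors the launch in md_mm() above.
    // Assumes stride_batch / stride_head come from src.stride(0) / src.stride(1),
    // and that elements within one head are contiguous.
    __global__ void md_mm_kernel_sketch(const float *src, float *dst,
                                        int stride_batch, int stride_head,
                                        int elems_per_head /* = seq * head_dim */)
    {
        int batch_idx = blockIdx.x;                           // one grid row per batch entry
        int head_idx = blockIdx.y;                            // one grid column per head
        int elem_idx = blockIdx.z * blockDim.x + threadIdx.x; // flat index inside (seq, head_dim)

        if (elem_idx >= elems_per_head)                       // last z-block may be partial
            return;

        long offset = (long)batch_idx * stride_batch
                    + (long)head_idx * stride_head
                    + elem_idx;

        dst[offset] = src[offset];                            // placeholder op; real mm logic TBD
    }

The bounds check matters because sequence_size * head_dim is not generally a multiple of 256, so the last blockIdx.z slice is only partially populated.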
@@ -1,14 +1,44 @@
 # coding=utf-8
+import os
 import torch
 import transformers
 import torch.nn as nn
 
-from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from transformers.models.llama import LlamaConfig, LlamaForCausalLM
-from transformers import AutoModel, AutoConfig
+from transformers import AutoModelForCausalLM, AutoConfig
+from transformers.models import qwen2, gemma2, llama, gemma
+
+decode_layers = {
+    "gemma": gemma.modeling_gemma.GemmaDecoderLayer,
+    "gemma2": gemma2.modeling_gemma2.Gemma2DecoderLayer,
+    "qwen2": qwen2.modeling_qwen2.Qwen2DecoderLayer,
+}
+
+MODELS = {
+    "gemma": gemma.GemmaForCausalLM,
+    "gemma2": gemma2.Gemma2ForCausalLM,
+    "llama": llama.LlamaForCausalLM,
+    "qwen2": qwen2.Qwen2ForCausalLM,
+}
 
 
 class ModelLoader:
 
-    def __init__(self, model_path: str):
-        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+    def __init__(self, model_path: str, pipeline_num: int = 1):
+        self.config_path = os.path.join(model_path, "config.json")
+        self.model_config = AutoConfig.from_pretrained(self.config_path)
+        hidden_layers = self.model_config.get("num_hidden_layers", -1)
+        if hidden_layers == -1:
+            raise ValueError("do not has such parameter")
+        self.hidden_layers = hidden_layers
+        self.pipeline_num = pipeline_num
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path, trust_remote_code=True
+        )
+        self.model_type = self.model_config["model_type"]
+        self.per_pipeline_layers = self.hidden_layers // self.pipeline_num
+        module_list = None
+        for x in self.model.modules():
+            if isinstance(x, torch.nn.modules.container.ModuleList):
+                module_list = x
+        if module_list is None:
+            raise ValueError("do not have module list.")
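The loader now reads config.json directly, keeps num_hidden_layers and a pipeline_num, derives per_pipeline_layers = hidden_layers // pipeline_num, and scans the loaded model for its ModuleList of decoder layers; together with the decode_layers and MODELS tables this looks like groundwork for splitting decoder layers across pipeline stages, though that intent is inferred from the diff rather than stated in it.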
setup.py (1 line changed)
@@ -10,6 +10,7 @@ files = [
     "csrc/matrix.cu",
     "csrc/core_bind.cpp",
     "csrc/max.cu",
+    "csrc/md.cu",
 ]
 extension = CUDAExtension(
     name="torch_cuda_ext.core",
test.cc (new file, 22 lines)
@@ -0,0 +1,22 @@
+#include <iostream>
+using namespace std;
+
+template <auto v>
+struct C
+{
+    using type = C<v>;
+    static constexpr auto value = v;
+    using value_type = decltype(v);
+    inline constexpr operator value_type() const noexcept { return value; }
+    inline constexpr value_type operator()() const noexcept { return value; }
+};
+
+int main()
+{
+    using _1 = C<10>;
+    auto x = _1{};
+    cout << _1::value << endl;
+    cout << _1::value_type() << endl;
+    cout << x.value << endl;
+    return 0;
+}
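The C&lt;v&gt; struct here appears to be a hand-rolled copy of the compile-time integral constant that CuTe builds its Int&lt;N&gt; and _32-style tokens from, so the file reads as a standalone check of how such constants behave (value, value_type, implicit conversion) outside the CUDA sources.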