diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp
index 41a8485..03b7199 100644
--- a/csrc/ft_attention/ft_attention.cpp
+++ b/csrc/ft_attention/ft_attention.cpp
@@ -23,17 +23,6 @@
         AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
     }
 
-// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
-//     if (TYPE == at::ScalarType::Half) { \
-//         using scalar_t = at::Half; \
-//         __VA_ARGS__(); \
-//     } else if (TYPE == at::ScalarType::Float) { \
-//         using scalar_t = float; \
-//         __VA_ARGS__(); \
-//     } else { \
-//         AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
-//     }
-
 template<typename T>
 void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                 const cudaStream_t& stream);
@@ -66,6 +55,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const int timestep,
                 const int rotary_embedding_dim,
                 const bool neox_rotary_style,
+                const int qkv_batch_stride,
                 T *q_ptr,
                 T *k_ptr,
                 T *v_ptr,
@@ -85,7 +75,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.v_cache = v_cache_ptr;
     params.out = out_ptr;
     params.cache_indir = nullptr;
-    params.stride = 0;
+    params.stride = qkv_batch_stride;
     params.batch_size = batch_size;
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
@@ -98,8 +88,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.total_padding_tokens = nullptr;
     params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
-    // params.max_prefix_prompt_length = memory_max_seqlen;  // TODO: waht should this be?
-    params.max_prefix_prompt_length = 0;  // TODO: waht should this be?
+    params.max_prefix_prompt_length = 0;
     params.relative_attention_bias = nullptr;
     params.relative_attention_bias_stride = 0;
     params.cross_attention_out = nullptr;
@@ -127,10 +116,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     CHECK_SHAPE(q, batch_size, nheads, headdim);
     CHECK_SHAPE(k, batch_size, nheads, headdim);
     CHECK_SHAPE(v, batch_size, nheads, headdim);
-    // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
-    // TODO: avoid contiguous requirment by storing the stride
-    CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
-    CHECK_CONTIGUOUS(v_cache);
+    CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
+    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
+    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
+    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
+    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
+    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
+    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
+    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
 
     if (length_per_sample_.has_value()) {
         auto length_per_sample = length_per_sample_.value();
@@ -146,11 +140,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
 
     torch::Tensor out = torch::empty_like(q);
 
-    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
+    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, neox_rotary_style,
+                   rotary_embedding_dim, neox_rotary_style, q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index c4e026e..4ec3082 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -57,7 +57,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     input_ids = tokenizer("Hello, my dog is cute and ",
                           return_tensors="pt").input_ids.to(device=device)
     max_length = 30
-    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
+    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
 
     # Slow generation for reference