From 2800efc71fbd56436829eafad90c795a4fe6a73f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 6 Jul 2023 15:33:33 -0700 Subject: [PATCH] [FT] rotary_cos/sin should have batch_size dimension --- ...coder_masked_multihead_attention_template.hpp | 16 ++++++++++++---- csrc/ft_attention/ft_attention.cpp | 6 +++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp index f99b818..f0e085c 100644 --- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp +++ b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp @@ -1065,14 +1065,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params