[FT] Fix FT's single query attention for bf16 hdim128 rotary

Tri Dao 2023-03-28 21:27:00 -07:00
parent 4d87e4d875
commit f5d0fbd468
4 changed files with 23 additions and 22 deletions

View File

@@ -1669,22 +1669,6 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
return;
}
-#ifdef ENABLE_BF16
-template<>
-__device__ __inline__ void
-write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
-{
-return;
-}
-template<>
-__device__ __inline__ void
-write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
-{
-return;
-}
-#endif
template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
@@ -1776,6 +1760,20 @@ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
+}
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
+}
#endif
template<>
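
Context on the hunks above: the removed bf16_4_t / bf16_8_t specializations were silent no-ops (they returned without writing to shared memory), which is presumably why the bf16 head-dim-128 rotary path produced wrong results; the new versions forward to the existing uint16_t overloads via reinterpret_cast, which is bit-exact because __nv_bfloat16 is a plain 16-bit type. A rough PyTorch analogue of that reinterpretation (illustration only, not the kernel code):

import torch

# Viewing bf16 storage as raw 16-bit words is lossless, which is why the kernel
# can delegate the bf16 shared-memory transpose to the uint16_t overload.
x = torch.randn(8, dtype=torch.bfloat16)
bits = x.view(torch.int16)                        # same storage, reinterpreted
assert torch.equal(bits.view(torch.bfloat16), x)  # round-trips bit-exactly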

View File

@@ -494,7 +494,8 @@ class MHA(nn.Module):
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-self.rotary_emb_dim
+self.rotary_emb_dim,
+not self.rotary_emb.interleaved  # neox_rotary_style
)
context = rearrange(context, 'b h d -> b 1 h d')
else:
@@ -607,7 +608,8 @@ class ParallelMHA(nn.Module):
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-self.rotary_emb_dim
+self.rotary_emb_dim,
+not self.rotary_emb.interleaved  # neox_rotary_style
)
context = rearrange(context, 'b h d -> b 1 h d')
if seqlen is None:

View File

@@ -82,6 +82,8 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
Arguments:
input_ids: (batch, seq_len)
max_length: int
+teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
+logits, the next token is taken from the teacher_outputs. Useful for testing.
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
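
The teacher_outputs behaviour documented above amounts to teacher forcing; roughly (made-up helper, not the library's actual decode loop):

def pick_next_token(logits, step, teacher_outputs=None):
    # If a reference sequence is supplied and still covers this step, force its
    # token; otherwise fall back to sampling (greedy here for brevity).
    if teacher_outputs is not None and step < teacher_outputs.shape[1]:
        return teacher_outputs[:, step]
    return logits.argmax(dim=-1)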
@@ -111,7 +113,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
start = time.time()
if vocab_size is not None:
logits = logits[..., :vocab_size]
-scores.append(logits)
+scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= seqlen_og:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
@@ -129,7 +131,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
inference_params.sequence_len_offset)
if vocab_size is not None:
logits = logits[..., :vocab_size]
-scores.append(logits)
+scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
next_token = sample(logits, top_k=top_k, temperature=temperature)
else:
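
Both scores.append changes guard the CUDA-graph decoding path (cg=True), where the logits returned at each step live in a static output buffer that graph replay overwrites in place; appending that tensor directly would leave every entry of scores aliasing the final step's values. A CPU-only sketch of the aliasing problem (the reused buffer below just stands in for the graph's output tensor):

import torch

static_logits = torch.zeros(1, 4)          # stand-in for the graph's static output buffer
scores_aliased, scores_cloned = [], []
for step in range(3):
    static_logits.fill_(float(step))       # each "replay" overwrites the buffer in place
    scores_aliased.append(static_logits)           # every entry aliases one buffer
    scores_cloned.append(static_logits.clone())    # snapshot the values at this step

print([s[0, 0].item() for s in scores_aliased])    # [2.0, 2.0, 2.0] -- wrong
print([s[0, 0].item() for s in scores_cloned])     # [0.0, 1.0, 2.0] -- correct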

View File

@@ -15,7 +15,6 @@ from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
-from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
@@ -61,7 +60,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and",
-return_tensors="pt").input_ids.to(device=device)
+return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40