[FT] Fix FT's single query attention for bf16 hdim128 rotary
This commit is contained in:
parent
4d87e4d875
commit
f5d0fbd468
@ -1669,22 +1669,6 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, i
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template<>
|
||||
__device__ __inline__ void
|
||||
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __inline__ void
|
||||
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
|
||||
{
|
||||
@ -1776,6 +1760,20 @@ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpo
|
||||
smem[transpose_idx] = vec.x;
|
||||
smem[smem_pitch + transpose_idx] = vec.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __inline__ void
|
||||
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
|
||||
{
|
||||
write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __inline__ void
|
||||
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
|
||||
{
|
||||
write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<>
|
||||
|
||||
@ -494,7 +494,8 @@ class MHA(nn.Module):
|
||||
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
|
||||
*inference_params.key_value_memory_dict[self.layer_idx],
|
||||
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
|
||||
self.rotary_emb_dim
|
||||
self.rotary_emb_dim,
|
||||
not self.rotary_emb.interleaved # neox_rotary_style
|
||||
)
|
||||
context = rearrange(context, 'b h d -> b 1 h d')
|
||||
else:
|
||||
@ -607,7 +608,8 @@ class ParallelMHA(nn.Module):
|
||||
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
|
||||
*inference_params.key_value_memory_dict[self.layer_idx],
|
||||
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
|
||||
self.rotary_emb_dim
|
||||
self.rotary_emb_dim,
|
||||
not self.rotary_emb.interleaved # neox_rotary_style
|
||||
)
|
||||
context = rearrange(context, 'b h d -> b 1 h d')
|
||||
if seqlen is None:
|
||||
|
||||
@ -82,6 +82,8 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
Arguments:
|
||||
input_ids: (batch, seq_len)
|
||||
max_length: int
|
||||
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
||||
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
||||
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
||||
sequences: (batch, max_length)
|
||||
scores: tuples of (batch, vocab_size)
|
||||
@ -111,7 +113,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
start = time.time()
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
scores.append(logits)
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
if teacher_outputs is None or teacher_output_len <= seqlen_og:
|
||||
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
else:
|
||||
@ -129,7 +131,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
inference_params.sequence_len_offset)
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
scores.append(logits)
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
|
||||
next_token = sample(logits, top_k=top_k, temperature=temperature)
|
||||
else:
|
||||
|
||||
@ -15,7 +15,6 @@ from flash_attn.models.gpt import GPTLMHeadModel
|
||||
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
|
||||
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from flash_attn.utils.distributed import all_gather_raw
|
||||
from flash_attn.utils.generation import update_graph_cache
|
||||
|
||||
|
||||
@ -61,7 +60,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
|
||||
torch.manual_seed(0)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
input_ids = tokenizer("Hello, my dog is cute and",
|
||||
return_tensors="pt").input_ids.to(device=device)
|
||||
return_tensors="pt").input_ids.to(device=device)
|
||||
max_length = 30
|
||||
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
|
||||
# max_length = input_ids.shape[1] + 40
|
||||
|
||||
Loading…
Reference in New Issue
Block a user