From c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b Mon Sep 17 00:00:00 2001
From: Avelina9X <37878580+Avelina9X@users.noreply.github.com>
Date: Sat, 27 Jan 2024 17:16:25 +0000
Subject: [PATCH] Updated missing docstrings for args and returns in
 bert_padding.py (#795)

* Updated docstrings of bert_padding.py

Added docstrings for missing arguments in the unpad and pad methods.

* Update bert_padding.py

Fixed spelling mistakes
---
 flash_attn/bert_padding.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py
index eff42c9..1d447d3 100644
--- a/flash_attn/bert_padding.py
+++ b/flash_attn/bert_padding.py
@@ -102,6 +102,7 @@ def unpad_input(hidden_states, attention_mask):
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
     Return:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
     """
@@ -170,6 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
         attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
     Return:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
     """
@@ -198,7 +200,9 @@ def pad_input(hidden_states, indices, batch, seqlen):
     """
     Arguments:
         hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz)
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
     Return:
         hidden_states: (batch, seqlen, ...)
     """
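
Note for reviewers: a minimal round-trip sketch of the API documented above.
The unpacking order follows the Return sections in this patch (hidden_states,
indices, cu_seqlens, max_seqlen_in_batch); the shapes, tensor contents, and
variable names are illustrative assumptions, not part of the patch.

import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 8, 16
hidden_states = torch.randn(batch, seqlen, dim)
# 1 marks a valid token, 0 marks padding: 5 valid tokens in row 0, 8 in row 1
attention_mask = torch.tensor([[1] * 5 + [0] * 3, [1] * 8])

# Strip padding: valid tokens are flattened to (total_nnz, dim), and indices
# records which positions of the flattened (batch * seqlen) input were kept.
hidden_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
    hidden_states, attention_mask
)
# Here total_nnz = 13, cu_seqlens = [0, 5, 13], max_seqlen_in_batch = 8.

# ... run variable-length attention over hidden_unpad ...

# Restore the padded (batch, seqlen, dim) layout; masked positions come back
# as zeros.
hidden_repad = pad_input(hidden_unpad, indices, batch, seqlen)

The indices tensor returned by unpad_input is exactly what pad_input expects,
which is why this patch documents it in both docstrings.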