Updated missing docstrings for args and returns in bert_padding.py (#795)

* Updated docstrings of bert_padding.py

Added docstrings for missing arguments in the unpad and pad methods.

* Update bert_padding.py

Fixed spelling mistakes
Avelina9X 2024-01-27 17:16:25 +00:00 committed by GitHub
parent ffc8682dd5
commit c94cd09744

@@ -102,6 +102,7 @@ def unpad_input(hidden_states, attention_mask):
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int, the maximum sequence length in the batch.
"""
@@ -170,6 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3) means the length of a concatenated sequence in the b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int, the maximum sequence length in the batch.
"""
@@ -198,7 +200,9 @@ def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""