Updated missing docstrings for args and returns in bert_padding.py (#795)
* Updated docstrings of bert_padding.py: added docstrings for missing arguments in the unpad and pad methods. * Update bert_padding.py: fixed spelling mistakes.
This commit is contained in:
parent
ffc8682dd5
commit
c94cd09744
@ -102,6 +102,7 @@ def unpad_input(hidden_states, attention_mask):
|
||||
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
||||
Return:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
||||
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
||||
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
||||
max_seqlen_in_batch: int
|
||||
"""
|
||||
@ -170,6 +171,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
|
||||
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
|
||||
Return:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
||||
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
||||
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
||||
max_seqlen_in_batch: int
|
||||
"""
|
||||
@ -198,7 +200,9 @@ def pad_input(hidden_states, indices, batch, seqlen):
|
||||
"""
|
||||
Arguments:
|
||||
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
|
||||
indices: (total_nnz)
|
||||
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
||||
batch: int, batch size for the padded sequence.
|
||||
seqlen: int, maximum sequence length for the padded sequence.
|
||||
Return:
|
||||
hidden_states: (batch, seqlen, ...)
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user