minify torch.torch.int32 to torch.int32 (#1237)

Author: Zhihao Shen, 2024-09-18 15:32:59 +08:00 (committed by GitHub)
commit 30e1ef0f79 (parent 83e41b3ca4)
2 changed files with 4 additions and 4 deletions
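
Why this is a pure spelling cleanup rather than a behavior change: the torch package exposes itself as an attribute of its own namespace, so torch.torch resolves back to the torch module and torch.torch.int32 names the very same dtype object as torch.int32. A quick standalone check (a sketch of mine, not part of the commit, and it assumes the module self-reference that made the old spelling work at all):

import torch

# torch.torch is the torch module itself, so both spellings
# pick out one and the same dtype object.
assert torch.torch is torch
assert torch.torch.int32 is torch.int32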


@@ -113,7 +113,7 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None):
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -187,7 +187,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
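
For readers unfamiliar with the pattern the fix touches: torch.cumsum turns per-sequence lengths into running totals, and F.pad(..., (1, 0)) prepends a zero, yielding the start offset of every sequence in the packed (unpadded) token stream plus the total length at the end. A minimal standalone sketch of what the corrected line computes, with made-up lengths:

import torch
import torch.nn.functional as F

# Hypothetical lengths of three sequences in a batch.
seqlens_in_batch = torch.tensor([3, 1, 2], dtype=torch.int32)

# Prepend 0 to the running totals: entry i is where sequence i starts
# in the packed stream; the final entry is the total token count.
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 4, 6], dtype=torch.int32)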


@@ -189,13 +189,13 @@ class BertEncoder(nn.Module):
).flatten()
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(
- torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
else:
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(
- torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
hidden_states_subset, hidden_states = index_first_axis_residual(
hidden_states, subset_idx
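
The BertEncoder hunk applies the same cumsum-and-pad recipe to a token subset selected by subset_mask: intersect it with the key-padding mask, count survivors per sequence, and build the offsets. A sketch with hypothetical masks (the variable names mirror the diff; the values are invented):

import torch
import torch.nn.functional as F

# Hypothetical (batch=2, seqlen=4) masks; True marks tokens that survive.
subset_mask = torch.tensor([[1, 0, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
key_padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

# Tokens present in both masks count toward the subset; then the familiar
# pad-the-cumsum step produces the per-sequence offsets.
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(subset_seqlens)     # tensor([2, 2], dtype=torch.int32)
print(subset_cu_seqlens)  # tensor([0, 2, 4], dtype=torch.int32)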