minify torch.torch.int32 to torch.int32 (#1237)
commit 30e1ef0f79 (parent 83e41b3ca4)
@@ -113,7 +113,7 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None):
     used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
     # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
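For context on the change itself: the `torch` package exposes itself as an attribute, so `torch.torch` resolves to the `torch` module and `torch.torch.int32` already named the same dtype object as `torch.int32`. The patch is a cosmetic cleanup with no behavior change. A minimal check (plain PyTorch, no flash-attn imports needed):

```python
import torch

# torch.torch is just the torch module re-exposed as an attribute of itself,
# so both spellings name the identical dtype object; the diff above is a no-op cleanup.
assert torch.torch is torch
assert torch.torch.int32 is torch.int32
```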
@@ -187,7 +187,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
     seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
     indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
     # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
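To make the patched line concrete: `cu_seqlens` is the prefix sum of the per-sequence token counts, with a leading zero added by `F.pad`, giving the cumulative offsets that varlen attention kernels expect. A small sketch with a toy padding mask (hypothetical values):

```python
import torch
import torch.nn.functional as F

# Toy padding mask for a batch of 3 sequences with lengths 2, 4, 1.
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1],
                               [1, 0, 0, 0]])
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # tensor([2, 4, 1])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 2, 6, 7]) -- start offset of each sequence, plus total
```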
@@ -189,13 +189,13 @@ class BertEncoder(nn.Module):
                 ).flatten()
                 subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                 )
             else:
                 subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                 subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                 )
             hidden_states_subset, hidden_states = index_first_axis_residual(
                 hidden_states, subset_idx
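The `# TD [2022-03-04]` comment repeated in the first two hunks is the rationale for the indexing style used throughout: gather rows with integer indices from `torch.nonzero` instead of indexing with the boolean mask directly. A rough sketch of the equivalence it relies on (toy shapes; `index_first_axis_residual` itself is flash-attn's custom autograd function and is not reproduced here):

```python
import torch

hidden_states = torch.randn(7, 8)  # (total_tokens, hidden_dim), toy shapes
mask = torch.tensor([1, 1, 0, 1, 0, 0, 1], dtype=torch.bool)

# Same result either way, but the integer-index path avoids letting
# PyTorch expand the bool mask and call nonzero internally on every use.
indices = torch.nonzero(mask, as_tuple=False).flatten()
gathered = hidden_states.index_select(0, indices)
assert torch.equal(gathered, hidden_states[mask])
```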