From 30e1ef0f79418af5ad52987e49d691f0d4519c46 Mon Sep 17 00:00:00 2001
From: Zhihao Shen
Date: Wed, 18 Sep 2024 15:32:59 +0800
Subject: [PATCH] minify torch.torch.int32 to torch.int32 (#1237)

---
 flash_attn/bert_padding.py | 4 ++--
 flash_attn/models/bert.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py
index ce8e4ca..3c2d351 100644
--- a/flash_attn/bert_padding.py
+++ b/flash_attn/bert_padding.py
@@ -113,7 +113,7 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None):
     used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
     # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -187,7 +187,7 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
     seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
     indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
     # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py
index 6a78b1e..0904631 100644
--- a/flash_attn/models/bert.py
+++ b/flash_attn/models/bert.py
@@ -189,13 +189,13 @@ class BertEncoder(nn.Module):
                 ).flatten()
                 subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                 )
             else:
                 subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                 subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                 )
             hidden_states_subset, hidden_states = index_first_axis_residual(
                 hidden_states, subset_idx
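
Why the old spelling worked at all: the torch package happens to re-expose itself
as an attribute, so `torch.torch` is the same module object as `torch`, and the
change is purely cosmetic. A minimal sanity check (not part of the patch; assumes
any recent PyTorch build):

    import torch

    # `torch.torch` resolves to the torch module itself, so both spellings
    # name the identical dtype object; the patch only removes the redundancy.
    assert torch.torch is torch
    assert torch.torch.int32 is torch.int32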