diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py
index 74c2ea5..b071416 100644
--- a/flash_attn/bert_padding.py
+++ b/flash_attn/bert_padding.py
@@ -1,7 +1,5 @@
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
 
-import numpy as np
-
 import torch
 import torch.nn.functional as F
 
@@ -15,7 +13,7 @@ class IndexFirstAxis(torch.autograd.Function):
         ctx.save_for_backward(indices)
         assert input.ndim >= 2
         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = np.prod(other_shape)
+        second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         # return input[indices]
         return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
@@ -71,7 +69,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
         ctx.save_for_backward(indices)
         assert input.ndim >= 2
         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = np.prod(other_shape)
+        second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         output = input[indices]
         # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
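
A minimal standalone sketch of why the substitution is safe: input.shape[1:] is a torch.Size, and torch.Size.numel() returns the product of its entries, i.e. the same value np.prod(other_shape) produced, with no numpy dependency. The example tensor, indices, and shapes below are illustrative choices, not code from this PR; the gather pattern mirrors IndexFirstAxis.forward.

    # Sketch only (assumed example values, not from the PR).
    import torch
    from einops import rearrange, repeat

    x = torch.randn(4, 3, 5)
    other_shape = x.shape[1:]             # torch.Size([3, 5])
    assert other_shape.numel() == 3 * 5   # product of trailing dims, no numpy needed

    indices = torch.tensor([0, 2])
    # Flatten trailing dims, gather the selected rows, then restore the trailing shape,
    # as IndexFirstAxis.forward does.
    flat = rearrange(x, 'b ... -> b (...)')                                  # (4, 15)
    rows = torch.gather(flat, 0, repeat(indices, 'z -> z d', d=other_shape.numel()))
    out = rows.reshape(-1, *other_shape)                                     # (2, 3, 5)
    assert torch.equal(out, x[indices])   # same result as plain fancy indexing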