From 4e38df059e911ef3cec872adefbe32779d422e3a Mon Sep 17 00:00:00 2001 From: Antoine Adam <108952645+ajfadam@users.noreply.github.com> Date: Thu, 6 Oct 2022 19:17:15 +0200 Subject: [PATCH] remove numpy dependency According to the `setup.py` file, the only dependencies are torch and einops. But the `bert_padding.py` file requires `numpy` only to multiply the elements of a `torch.Size` object. This change allows the use of FlashAttention without numpy. --- flash_attn/bert_padding.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 74c2ea5..b071416 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -1,7 +1,5 @@ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py -import numpy as np - import torch import torch.nn.functional as F @@ -15,7 +13,7 @@ class IndexFirstAxis(torch.autograd.Function): ctx.save_for_backward(indices) assert input.ndim >= 2 ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = np.prod(other_shape) + second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. # return input[indices] return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, @@ -71,7 +69,7 @@ class IndexFirstAxisResidual(torch.autograd.Function): ctx.save_for_backward(indices) assert input.ndim >= 2 ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = np.prod(other_shape) + second_dim = other_shape.numel() # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. output = input[indices] # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last