remove numpy dependency
According to the `setup.py` file, only dependencies are torch and einops. But the `bert_padding.py` file requires `numpy` only to multiply the elements of a `torch.Size` object. This change aims at allowing the use of FlashAttention without numpy.
This commit is contained in:
parent
88dc2040a0
commit
4e38df059e
@ -1,7 +1,5 @@
|
||||
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -15,7 +13,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
||||
ctx.save_for_backward(indices)
|
||||
assert input.ndim >= 2
|
||||
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
||||
second_dim = np.prod(other_shape)
|
||||
second_dim = other_shape.numel()
|
||||
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
||||
# return input[indices]
|
||||
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
|
||||
@ -71,7 +69,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
|
||||
ctx.save_for_backward(indices)
|
||||
assert input.ndim >= 2
|
||||
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
||||
second_dim = np.prod(other_shape)
|
||||
second_dim = other_shape.numel()
|
||||
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
||||
output = input[indices]
|
||||
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
|
||||
|
||||
Loading…
Reference in New Issue
Block a user