remove numpy dependency
According to the `setup.py` file, the only dependencies are torch and einops. However, `bert_padding.py` imports `numpy` solely to multiply the elements of a `torch.Size` object. This change allows FlashAttention to be used without numpy.
parent 88dc2040a0
commit 4e38df059e
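A quick way to see why the replacement works: `torch.Size` is a tuple subclass whose `numel()` method returns the product of its entries, which is exactly what `np.prod` was computing here. A minimal check (the tensor shape below is arbitrary, chosen only for illustration):

    import torch

    x = torch.randn(4, 3, 8)
    other_shape = x.shape[1:]            # torch.Size([3, 8])

    # torch.Size.numel() multiplies the dimensions together, so numpy is not needed:
    second_dim = other_shape.numel()     # 24
    assert second_dim == 3 * 8
    # The line it replaces was: second_dim = np.prod(other_shape)  # also 24, but pulls in numpy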
bert_padding.py

@@ -1,7 +1,5 @@
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
-
-import numpy as np
 import torch
 import torch.nn.functional as F
 
@@ -15,7 +13,7 @@ class IndexFirstAxis(torch.autograd.Function):
         ctx.save_for_backward(indices)
         assert input.ndim >= 2
         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = np.prod(other_shape)
+        second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         # return input[indices]
         return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
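For context, here is a minimal sketch of the gather-based row selection that this hunk touches. The function name `index_first_axis_forward` and the `unsqueeze`/`expand` index construction are assumptions for illustration; the repo's actual code is an autograd.Function that also saves tensors for the backward pass.

    import torch
    from einops import rearrange

    # Simplified sketch, not the repo's autograd.Function.
    def index_first_axis_forward(input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        assert input.ndim >= 2
        other_shape = input.shape[1:]
        second_dim = other_shape.numel()  # product of the trailing dims, no numpy required
        # Flatten everything after the first axis, gather the selected rows, restore the shape.
        flat = rearrange(input, 'b ... -> b (...)')
        idx = indices.unsqueeze(-1).expand(-1, second_dim)
        return torch.gather(flat, 0, idx).reshape(-1, *other_shape)

    # Usage: keep only selected rows along the first axis.
    hidden = torch.randn(10, 3, 8)
    keep = torch.tensor([0, 2, 5, 7])
    out = index_first_axis_forward(hidden, keep)   # shape (4, 3, 8)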
@@ -71,7 +69,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
         ctx.save_for_backward(indices)
         assert input.ndim >= 2
         ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-        second_dim = np.prod(other_shape)
+        second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         output = input[indices]
         # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last