# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import numpy as np

import torch
import torch.nn.functional as F

from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    """Select rows along the first axis (``input[indices]``) with a hand-written backward.

    The forward flattens all trailing dimensions, gathers the requested rows and
    restores the trailing shape. The backward scatters the incoming gradient into
    a zero tensor that has the first-axis size of the original input.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return the rows of ``input`` selected by ``indices``.

        Arguments:
            input: tensor with at least 2 dimensions, shape (first_axis_dim, ...).
            indices: 1-D integer tensor of row positions (it is expanded with the
                pattern 'z -> z d', so it must be one-dimensional).
        Return:
            tensor of shape (len(indices), ...) with the same trailing shape as ``input``.
        """
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        # Plain Python int so that einops receives a regular axis length
        # rather than a NumPy scalar.
        second_dim = int(np.prod(other_shape))
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        flat_input = rearrange(input, 'b ... -> b (...)')
        gather_index = repeat(indices, 'z -> z d', d=second_dim)
        return torch.gather(flat_input, 0, gather_index).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` back to the rows it was gathered from.

        Rows that were not selected in the forward receive a zero gradient.
        Returns ``(grad_input, None)``; ``indices`` gets no gradient.
        """
        indices, = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, 'b ... -> b (...)')
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        # NOTE(review): scatter_ writes rather than accumulates, so this presumably
        # relies on `indices` having no duplicates -- confirm against callers.
        grad_input.scatter_(
            0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output
        )
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):
    """Write ``values`` into the rows ``indices`` of a fresh zero tensor.

    This is the inverse layout operation of ``IndexFirstAxis``: the output has
    ``first_axis_dim`` rows, and rows not listed in ``indices`` stay zero.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        """Return a zero tensor of shape (first_axis_dim, ...) with ``values`` placed at ``indices``.

        Arguments:
            values: tensor with at least 2 dimensions, shape (len(indices), ...).
            indices: 1-D integer tensor of destination row positions.
            first_axis_dim: int, size of the first axis of the output.
        """
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # Plain index assignment is what is used here; the scatter_ variant from
        # the TD [2022-03-04] note is kept below for reference only.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Pick the gradient rows that correspond to ``values``.

        Returns ``(grad_values, None, None)``; ``indices`` and ``first_axis_dim``
        get no gradient.
        """
        indices, = ctx.saved_tensors
        # Plain indexing is what is used here; the gather variant from the
        # TD [2022-03-04] note is kept below for reference only.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


class IndexFirstAxisResidual(torch.autograd.Function):
    """Like ``IndexFirstAxis`` but additionally returns the (detached) input as a residual.

    The backward receives one gradient for the selected rows and one for the
    residual and combines them into a single gradient for the input.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return ``(input[indices], input.detach())``.

        Arguments:
            input: tensor with at least 2 dimensions, shape (first_axis_dim, ...).
            indices: integer tensor of row positions along the first axis.
        """
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim = input.shape[0]
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        """Add ``grad_output`` onto the selected rows of ``grad_residual``.

        NOTE: ``grad_residual`` is updated in place (``scatter_add_``) and reused
        as the returned input gradient; no copy is made.
        Returns ``(grad_input, None)``; ``indices`` gets no gradient.
        """
        indices, = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # Equivalent to: grad_input[indices] += grad_output
        # Give indices trailing singleton dims, then broadcast it to grad_output's
        # shape so it can serve as the scatter_add_ index.
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask):
    """Remove the padded positions from a batch of sequences.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), positions of the selected tokens in the flattened (batch * seqlen) axis.
        cu_seqlens: (batch + 1), int32, the cumulative sequence lengths (with a leading 0),
            used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prepend a zero so that cu_seqlens[i]:cu_seqlens[i + 1] delimits sequence i.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    unpadded = index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices)
    return unpadded, indices, cu_seqlens, max_seqlen_in_batch


def pad_input(hidden_states, indices, batch, seqlen):
    """Place unpadded tokens back into a zero-filled (batch, seqlen, ...) tensor.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), positions in the flattened (batch * seqlen) axis.
        batch: int, batch size of the padded output.
        seqlen: int, sequence length of the padded output.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, '(b s) ... -> b s ...', b=batch)