import math
from functools import partial
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

import hydra

from einops import reduce, rearrange


def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=True):
    """Pool a sequence of hidden states into a single vector per example.

    x: (B, S, D) if batch_first else (S, B, D)
    key_padding_mask: optional mask object exposing `bool_matrix` (B, S) and
        `_lengths` (B,), used to ignore padded positions for MEAN/SUM/LAST.
    """
    if pooling_mode not in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN']:
        raise NotImplementedError('pooling_mode must be MEAN, SUM, CLS, LAST, FLATTEN')
    if pooling_mode in ['MEAN', 'SUM']:
        if key_padding_mask is not None:
            mask = rearrange(~key_padding_mask.bool_matrix,
                             'b s -> b s 1' if batch_first else 'b s -> s b 1')
            x = x.masked_fill(mask, 0)
        s = reduce(x, 'b s ... -> b ...' if batch_first else 's b ... -> b ...', 'sum')
        if pooling_mode == 'SUM':
            return s
        else:
            if key_padding_mask is None:
                return s / x.shape[1 if batch_first else 0]
            else:
                lengths = rearrange(key_padding_mask._lengths, 'b -> b 1')
                return s / lengths
    elif pooling_mode == 'CLS':
        return x[:, 0] if batch_first else x[0]
    elif pooling_mode == 'LAST':
        if key_padding_mask is None:
            return x[:, -1] if batch_first else x[-1]
        else:
            lengths = key_padding_mask._lengths
            if batch_first:
                batch_size = x.shape[0]
                return x[torch.arange(batch_size, device=x.device), lengths - 1]
            else:
                batch_size = x.shape[1]
                return x[lengths - 1, torch.arange(batch_size, device=x.device)]
    elif pooling_mode == 'FLATTEN':
        return rearrange(x, 'b ... -> b (...)' if batch_first else 's b ... -> b (s ...)')
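

# Illustrative usage sketch for `pooling` (the shapes below are assumptions made up
# for the example), shown without a key_padding_mask:
#
#     x = torch.randn(4, 128, 256)  # (batch, seq, d_model), batch_first layout
#     pooling(x, pooling_mode='MEAN', batch_first=True).shape     # (4, 256)
#     pooling(x, pooling_mode='CLS', batch_first=True).shape      # (4, 256)
#     pooling(x, pooling_mode='FLATTEN', batch_first=True).shape  # (4, 128 * 256)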


class ClassificationHeadLinear(nn.Module):
    """Linear-only head for sentence-level classification tasks."""

    def __init__(self, d_model, num_classes, pooling_mode='MEAN',
                 batch_first=False, **kwargs):
        super().__init__()
        assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.out_proj = nn.Linear(d_model, num_classes)

    def forward(self, hidden_states, key_padding_mask=None, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,
                                key_padding_mask=key_padding_mask, batch_first=self.batch_first)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/models/reformer/modeling_reformer.py
class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
                 batch_first=False):
        super().__init__()
        assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.dense = nn.Linear(d_model, d_inner)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_inner, num_classes)

    def forward(self, hidden_states, key_padding_mask=None, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,
                                key_padding_mask=key_padding_mask, batch_first=self.batch_first)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        # Huggingface uses tanh instead of relu
        hidden_states = torch.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
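

# Illustrative usage sketch for ClassificationHead (d_model / d_inner / num_classes
# values are assumptions made up for the example):
#
#     head = ClassificationHead(d_model=256, d_inner=256, num_classes=10,
#                               pooling_mode='MEAN', batch_first=True)
#     logits = head(torch.randn(4, 128, 256))  # -> (4, 10)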


class ClassificationHeadDual(nn.Module):
    """Head for sentence-pair classification tasks (two encoded sequences)."""

    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
                 batch_first=False, interaction='NLI'):
        super().__init__()
        assert pooling_mode in ['MEAN', 'SUM', 'CLS'], 'pooling_mode not supported'
        assert interaction in [None, 'NLI'], 'interaction not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.interaction = interaction
        self.dense = nn.Linear(d_model * (4 if self.interaction == 'NLI' else 2), d_inner)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_inner, num_classes)

    def forward(self, hidden_states1, hidden_states2,
                key_padding_mask1=None, key_padding_mask2=None, **kwargs):
        """
        hidden_states1, hidden_states2: (B, S, D) if batch_first else (S, B, D)
        """
        x1 = pooling(hidden_states1, pooling_mode=self.pooling_mode,
                     key_padding_mask=key_padding_mask1, batch_first=self.batch_first)
        x2 = pooling(hidden_states2, pooling_mode=self.pooling_mode,
                     key_padding_mask=key_padding_mask2, batch_first=self.batch_first)
        hidden_states = (torch.cat([x1, x2, x1 * x2, x1 - x2], dim=-1) if self.interaction == 'NLI'
                         else torch.cat([x1, x2], dim=-1))
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        # Huggingface uses tanh instead of relu
        hidden_states = torch.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
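

# Illustrative sketch of the 'NLI' interaction: the two pooled vectors x1, x2 (each
# d_model wide) are combined as [x1, x2, x1 * x2, x1 - x2], so `dense` receives
# 4 * d_model features. The sizes below are assumptions made up for the example:
#
#     head = ClassificationHeadDual(d_model=256, d_inner=256, num_classes=3,
#                                   pooling_mode='MEAN', batch_first=True, interaction='NLI')
#     logits = head(torch.randn(4, 100, 256), torch.randn(4, 120, 256))  # -> (4, 3)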


class LMHead(nn.Module):
    """Language-modeling head: projects hidden states to per-token logits."""

    def __init__(self, d_model, num_classes, batch_first=True, bias=True):
        super().__init__()
        self.lm_head = nn.Linear(d_model, num_classes, bias=bias)

    def forward(self, hidden_states, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(self.lm_head(hidden_states))
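

# Illustrative usage sketch (the vocabulary size and shapes are assumptions made up
# for the example); the returned namedtuple mimics HF-style model outputs:
#
#     lm_head = LMHead(d_model=256, num_classes=50257)
#     out = lm_head(torch.randn(4, 128, 256))
#     out.logits.shape  # (4, 128, 50257)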


def sinusoidal_init_(tensor):
    """Fill `tensor` in-place with sinusoidal positional encodings.

    tensor: (max_len, d_model); d_model is assumed to be even.
    """
    max_len, d_model = tensor.shape
    position = rearrange(torch.arange(0.0, max_len), 's -> s 1')
    div_term = torch.exp(-math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model)
    tensor[:, 0::2] = torch.sin(position * div_term)
    tensor[:, 1::2] = torch.cos(position * div_term)
    return tensor
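

# Illustrative check of `sinusoidal_init_` against the usual formula
# PE[pos, 2i] = sin(pos / 10000^(2i / d_model)); sizes are assumptions for the example:
#
#     pe = sinusoidal_init_(torch.empty(5000, 256))
#     pe[0, 0::2]  # all zeros (sin(0))
#     pe[0, 1::2]  # all ones (cos(0))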


# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding index.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False, initializer=None):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.empty(max_len, d_model)
        if initializer is None:
            sinusoidal_init_(pe)
            pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')
            self.register_buffer('pe', pe)
        else:
            hydra.utils.call(initializer, pe)
            pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')
            self.pe = nn.Parameter(pe)

    def forward(self, x):
        r"""Inputs of forward function

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
            output: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + (self.pe[:, :x.size(1)] if self.batch_first else self.pe[:x.size(0)])
        return self.dropout(x)
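

# Illustrative usage sketch with batch_first=True (d_model and shapes are assumptions
# made up for the example):
#
#     pos_encoder = PositionalEncoding(d_model=256, batch_first=True)
#     y = pos_encoder(torch.randn(4, 128, 256))  # same shape, positions 0..127 encoded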


# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 act_fn=None, drop=0., device=None, dtype=None):
        """TD [2021-10-27] act_fn takes precedence over act_layer if set.
        This is to support the PyTorch 1.10 Transformer interface, which constructs the
        activation *function*, not the activation *layer*.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = _pair(drop)
        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.act = act_layer() if act_fn is None else act_fn
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
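

# Illustrative sketch of act_fn taking precedence over act_layer (sizes are assumptions
# made up for the example):
#
#     mlp = Mlp(in_features=256, hidden_features=1024, act_fn=F.relu)  # uses F.relu directly
#     mlp = Mlp(in_features=256, hidden_features=1024)                 # uses act_layer() == nn.GELU()
#     y = mlp(torch.randn(4, 128, 256))  # -> (4, 128, 256)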


class MlpBig(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 act_fn=None, drop=0., device=None, dtype=None):
        """Copied from Mlp above, but stacks four Linear/activation/dropout blocks,
        doubling the hidden width each time, before the final projection.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        cur_hidden_features = hidden_features
        layers = []
        for _ in range(4):
            layers.append(nn.Linear(in_features, cur_hidden_features, **factory_kwargs))
            layers.append(act_layer())
            layers.append(nn.Dropout(drop))
            in_features = cur_hidden_features
            cur_hidden_features *= 2
        layers.append(nn.Linear(in_features, out_features, **factory_kwargs))
        layers.append(nn.Dropout(drop))
        self.fwd = nn.Sequential(*layers)

    def forward(self, x):
        return self.fwd(x)
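

# Illustrative sketch of the layer widths MlpBig builds (numbers are assumptions made up
# for the example): with in_features=256 and hidden_features=256, the Linear widths are
# 256->256, 256->512, 512->1024, 1024->2048, then 2048->out_features.
#
#     mlp = MlpBig(in_features=256, hidden_features=256, out_features=256)
#     y = mlp(torch.randn(4, 128, 256))  # -> (4, 128, 256)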


class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x, gates = x.chunk(2, dim=-1)
        x = x * self.act(gates)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
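

# Illustrative GLU gating sketch: fc1 projects to hidden_features, which is split in half
# into values and gates, so fc2 only sees hidden_features // 2 (sizes are assumptions
# made up for the example):
#
#     glu = GluMlp(in_features=256, hidden_features=1024)  # hidden_features must be even
#     y = glu(torch.randn(4, 128, 256))  # -> (4, 128, 256)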


class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 gate_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.gate(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
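

# Note on the gate contract (inferred from the `hidden_features // 2` line above): a custom
# gate_layer is constructed with hidden_features and is expected to halve the feature dim,
# as a gMLP-style spatial gating unit does. Without a gate it degenerates to a plain MLP
# (sizes below are assumptions made up for the example):
#
#     mlp = GatedMlp(in_features=256, hidden_features=1024)  # gate = nn.Identity()
#     y = mlp(torch.randn(4, 128, 256))  # -> (4, 128, 256)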


class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
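

# Illustrative usage sketch: ConvMlp operates on NCHW feature maps and preserves the
# spatial dims (channel counts and map size are assumptions made up for the example):
#
#     conv_mlp = ConvMlp(in_features=64, hidden_features=128, out_features=64)
#     y = conv_mlp(torch.randn(4, 64, 32, 32))  # -> (4, 64, 32, 32)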