[ViT] Support timm checkpoint, add tests

parent 2ec7d3f72c
commit 780e8eeabb

flash_attn/models/vit.py
@@ -1,9 +1,12 @@
 # Copyright (c) 2022, Tri Dao.
 # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 import math
+import re
 from functools import partial
 from copy import deepcopy

+from collections import OrderedDict
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -218,6 +221,7 @@ class VisionTransformer(nn.Module):
         hidden_states = self._pos_embed(x)
         residual = None
         if self.global_pool != 'token' or all_tokens:
+        # if True:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
         else:
@@ -225,10 +229,8 @@ class VisionTransformer(nn.Module):
                 hidden_states, residual = block(hidden_states, residual)
             # For the last layer, we only want the 1st token of the output. So we do cross-attention
             # where the query is the 1st token and the key/value is the whole sequence.
-            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
-            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
-            hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
-                                                      mixer_kwargs={'x_kv': hidden_states})
+            hidden_states, residual = self.blocks[-1](hidden_states, residual,
+                                                      mixer_subset=slice(0, 1))
         if not self.fused_dropout_add_ln:
             residual = self.drop_path(self.dropout(hidden_states)) + residual
             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
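A minimal sketch (not part of this diff), in plain PyTorch, of what the new mixer_subset=slice(0, 1) path computes: the query projection sees only the CLS position while keys and values still see every token, so the last block evaluates a single row of the attention matrix instead of the full seqlen x seqlen map. Names such as q_proj and kv_proj and the toy dimensions are illustrative, not this library's API.

    import torch

    batch, seqlen, dim, nheads = 2, 197, 64, 4
    head_dim = dim // nheads
    x = torch.randn(batch, seqlen, dim)

    q_proj = torch.nn.Linear(dim, dim)
    kv_proj = torch.nn.Linear(dim, 2 * dim)

    subset = slice(0, 1)                   # same role as mixer_subset=slice(0, 1)
    q = q_proj(x[:, subset])               # queries: CLS token only -> (batch, 1, dim)
    k, v = kv_proj(x).chunk(2, dim=-1)     # keys/values: full sequence -> (batch, seqlen, dim)

    q = q.view(batch, 1, nheads, head_dim).transpose(1, 2)
    k = k.view(batch, seqlen, nheads, head_dim).transpose(1, 2)
    v = v.view(batch, seqlen, nheads, head_dim).transpose(1, 2)

    attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(batch, 1, dim)   # only the CLS row is computed

The removed lines built the CLS-only query by slicing hidden_states by hand and passing the full sequence through mixer_kwargs={'x_kv': ...}; pushing the slice down via mixer_subset keeps the residual bookkeeping inside Block.forward instead.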
@@ -258,6 +260,29 @@ class VisionTransformer(nn.Module):
         x = self.forward_head(x)
         return x

+    def load_state_dict(self, state_dict, strict=True):
+        patch_embed_weight = state_dict['patch_embed.proj.weight']
+        if patch_embed_weight.dim() == 4:
+            # convert from Conv2d to Linear
+            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
+                                                              'o c h w -> o (c h w)')
+        def key_mapping_attn(key):
+            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
+            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
+            return key
+        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+        n_layer = len(self.blocks)
+        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
+        if (self.blocks[-1].mixer.cross_attn
+            and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
+            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
+            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
+        return super().load_state_dict(state_dict, strict=strict)
+

 def init_weights_vit_timm(module: nn.Module, name: str = ''):
     """ ViT weight initialization, original timm impl (for reproducibility) """
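A hedged usage sketch (not part of this diff) of the load_state_dict override added above, mirroring what tests/models/test_vit.py does. The override reshapes the Conv2d patch-embedding weight into a Linear weight ('o c h w -> o (c h w)'), renames attn.qkv / attn.proj to mixer.Wqkv / mixer.out_proj, and, when the last block uses cross-attention, splits its fused Wqkv into Wq and Wkv. A CUDA device and a downloaded timm checkpoint are assumed.

    import torch
    from timm.models.vision_transformer import vit_base_patch16_224
    from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

    model_timm = vit_base_patch16_224(pretrained=True)      # standard timm ViT-B/16
    model = flash_vit_base_patch16_224()                     # this repo's ViT, default (non-fused) options
    model.load_state_dict(model_timm.state_dict())           # key remapping happens inside the override
    model = model.to(device='cuda', dtype=torch.float16).eval()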

flash_attn/modules/block.py
@@ -89,12 +89,15 @@ class Block(nn.Module):
                 p._shared_params = True

     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
-                mixer_kwargs=None):
+                mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.

         Args:
             hidden_states: the sequence to the encoder layer (required).
             residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
         """
         if self.prenorm:
             if not self.fused_dropout_add_ln:
@@ -116,8 +119,12 @@ class Block(nn.Module):
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
-            hidden_states = self.mixer(hidden_states,
-                                       **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if mixer_kwargs is None:
+                mixer_kwargs = {}
+            mixer_kwargs['mixer_subset'] = mixer_subset
+            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
+            if mixer_subset is not None:
+                residual = residual[:, mixer_subset]
             if not isinstance(self.mlp, nn.Identity):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
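A self-contained sketch (not part of this diff) of the control flow added above, with a hypothetical stub standing in for the MHA mixer. It illustrates why residual is sliced with the same mixer_subset: the mixer now returns only the queried positions, so the residual stream must shrink to match before the MLP and the final norm.

    import torch

    def stub_mixer(hidden_states, mixer_subset=None):
        # Stand-in for MHA: the output covers only the queried subset of positions.
        return hidden_states if mixer_subset is None else hidden_states[:, mixer_subset]

    batch, seqlen, dim = 2, 197, 768
    hidden_states = torch.randn(batch, seqlen, dim)
    residual = torch.randn(batch, seqlen, dim)

    mixer_subset = slice(0, 1)        # CLS token only, as in VisionTransformer.forward_features
    mixer_kwargs = {'mixer_subset': mixer_subset}
    hidden_states = stub_mixer(hidden_states, **mixer_kwargs)
    if mixer_subset is not None:
        residual = residual[:, mixer_subset]
    assert hidden_states.shape == residual.shape == (batch, 1, dim)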

flash_attn/modules/mha.py
@@ -420,7 +420,7 @@ class MHA(nn.Module):
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                inference_params=None, **kwargs):
+                mixer_subset=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -433,6 +433,9 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
             inference_params: for generation. Adapted from Megatron-LM (and Apex)
                 https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
@@ -454,6 +457,7 @@ class MHA(nn.Module):
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
+            assert x_kv is None and mixer_subset is None
             if not self.return_residual:
                 qkv = self.Wqkv(x)
             else:
@@ -491,14 +495,14 @@ class MHA(nn.Module):
                 context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                 kv = self.Wkv(x_kv if x_kv is not None else x)
             else:
                 if x_kv is not None:
                     kv, x_kv = self.Wkv(x_kv)
                 else:
                     kv, x = self.Wkv(x)
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
             q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
             kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
             if self.dwconv:
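A small numerical check (not part of this diff), in plain PyTorch with illustrative toy dimensions, of the weight split that the new load_state_dict performs for this cross-attention path: the first embed_dim output rows of a fused Wqkv behave as Wq and the remaining 2 * embed_dim rows as Wkv, so splitting the checkpoint tensors leaves the projections unchanged.

    import torch
    import torch.nn as nn

    embed_dim = 8
    Wqkv = nn.Linear(embed_dim, 3 * embed_dim)
    Wq = nn.Linear(embed_dim, embed_dim)
    Wkv = nn.Linear(embed_dim, 2 * embed_dim)

    # Same slicing as the new load_state_dict: rows [0, embed_dim) -> Wq, the rest -> Wkv.
    with torch.no_grad():
        Wq.weight.copy_(Wqkv.weight[:embed_dim])
        Wq.bias.copy_(Wqkv.bias[:embed_dim])
        Wkv.weight.copy_(Wqkv.weight[embed_dim:])
        Wkv.bias.copy_(Wqkv.bias[embed_dim:])

    x = torch.randn(2, 5, embed_dim)
    q_ref, kv_ref = Wqkv(x).split([embed_dim, 2 * embed_dim], dim=-1)
    assert torch.allclose(Wq(x), q_ref, atol=1e-6) and torch.allclose(Wkv(x), kv_ref, atol=1e-6)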

tests/models/test_opt.py
@@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_optimized(model_name):
-    """Check that our implementation of OPT (without any optimizations enabled) matches the
+    """Check that our implementation of OPT (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """

tests/models/test_vit.py (new file, 49 lines added)
@@ -0,0 +1,49 @@
+import re
+
+import torch
+import pytest
+
+from timm.models.vision_transformer import vit_base_patch16_224
+
+from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
+
+
+@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
+# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
+@pytest.mark.parametrize('optimized', [False, True])
+# @pytest.mark.parametrize('optimized', [True])
+def test_vit(optimized, fused_dense_gelu_dense):
+    """Check that our implementation of ViT matches timm's implementation:
+    the output of our forward pass in fp16 should be around the same as
+    timm's forward pass in fp16, when compared to timm's forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+
+    kwargs = {}
+    if optimized:
+        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
+    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
+    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
+
+    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
+    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
+
+    model.load_state_dict(model_ref.state_dict())
+
+    model.eval()
+    model_ref.eval()
+    model_timm.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
+    out = model(x)
+    out_timm = model_timm(x)
+    out_ref = model_ref(x.float())
+
+    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+    print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
+    print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
+    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
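Usage note (not part of this diff): the new test needs a CUDA GPU and downloads the pretrained timm checkpoint on first use; a typical invocation is pytest -q -s tests/models/test_vit.py, where -s shows the printed max/mean diffs. The final assert allows the fp16 FlashAttention model to deviate from timm's fp32 output by at most 3x the deviation of timm's own fp16 model.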