[ViT] Support timm checkpoint, add tests

Tri Dao 2023-01-16 01:20:04 -08:00
parent 2ec7d3f72c
commit 780e8eeabb
5 changed files with 96 additions and 11 deletions

View File

@ -1,9 +1,12 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from functools import partial
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -218,6 +221,7 @@ class VisionTransformer(nn.Module):
hidden_states = self._pos_embed(x)
residual = None
if self.global_pool != 'token' or all_tokens:
# if True:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
else:
@ -225,10 +229,8 @@ class VisionTransformer(nn.Module):
hidden_states, residual = block(hidden_states, residual)
# For the last layer, we only want the 1st token of the output. So we do cross-attention
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
mixer_kwargs={'x_kv': hidden_states})
hidden_states, residual = self.blocks[-1](hidden_states, residual,
mixer_subset=slice(0, 1))
if not self.fused_dropout_add_ln:
residual = self.drop_path(self.dropout(hidden_states)) + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
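The comment above captures the idea behind the new mixer_subset path: with token pooling, only the CLS token's output of the last block is needed, so the query can be restricted to that token while the keys/values still cover the whole sequence. A minimal plain-PyTorch sketch of that equivalence, not part of this commit (single head, no projections):

import torch

torch.manual_seed(0)
b, s, d = 2, 5, 8
x = torch.randn(b, s, d)
scores = torch.einsum('bqd,bkd->bqk', x, x) / d ** 0.5
full = torch.softmax(scores, dim=-1) @ x                 # self-attention with all queries
cls_only = torch.softmax(scores[:, :1], dim=-1) @ x      # query restricted to the CLS token
assert torch.allclose(full[:, :1], cls_only, atol=1e-6)  # same result for the CLS position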
@ -258,6 +260,29 @@ class VisionTransformer(nn.Module):
x = self.forward_head(x)
return x
def load_state_dict(self, state_dict, strict=True):
patch_embed_weight = state_dict['patch_embed.proj.weight']
if patch_embed_weight.dim() == 4:
# convert from Conv2d to Linear
state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
'o c h w -> o (c h w)')
def key_mapping_attn(key):
key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_layer = len(self.blocks)
# Convert from Wqkv to Wq and Wkv for cross attention (last layer)
if (self.blocks[-1].mixer.cross_attn
and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
return super().load_state_dict(state_dict, strict=strict)
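With this override, a stock timm checkpoint can be loaded directly: the patch-embedding weight is flattened from Conv2d to Linear layout, attn.qkv / attn.proj keys are renamed to mixer.Wqkv / mixer.out_proj, and, if the last block uses cross-attention, its Wqkv is split into Wq and Wkv. A usage sketch mirroring the new test below:

from timm.models.vision_transformer import vit_base_patch16_224 as timm_vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224

model = vit_base_patch16_224()  # flash_attn ViT with default settings
model.load_state_dict(timm_vit_base_patch16_224(pretrained=True).state_dict())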
def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """

View File

@ -89,12 +89,15 @@ class Block(nn.Module):
p._shared_params = True
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
mixer_kwargs=None):
mixer_subset=None, mixer_kwargs=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual))
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
if self.prenorm:
if not self.fused_dropout_add_ln:
@ -116,8 +119,12 @@ class Block(nn.Module):
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if mixer_kwargs is None:
mixer_kwargs = {}
mixer_kwargs['mixer_subset'] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
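Because the mixer now returns outputs only for the selected queries, the residual stream has to be narrowed to the same subset before the next residual add, which is what the residual[:, mixer_subset] slice above does. A small shape sketch, not part of this commit:

import torch

b, s, d = 2, 197, 768
mixer_subset = slice(0, 1)
hidden_states = torch.randn(b, 1, d)  # mixer output when mixer_subset selects only the CLS token
residual = torch.randn(b, s, d)       # residual still covers the full sequence
residual = residual[:, mixer_subset]  # narrow to (b, 1, d) so shapes match downstream
assert (hidden_states + residual).shape == (b, 1, d)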

View File

@ -420,7 +420,7 @@ class MHA(nn.Module):
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
inference_params=None, **kwargs):
mixer_subset=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@ -433,6 +433,9 @@ class MHA(nn.Module):
max_seqlen: int. Maximum sequence length in the batch.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
"""
@ -454,6 +457,7 @@ class MHA(nn.Module):
kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
if not self.cross_attn:
assert x_kv is None and mixer_subset is None
if not self.return_residual:
qkv = self.Wqkv(x)
else:
@ -491,14 +495,14 @@ class MHA(nn.Module):
context = rearrange(context, 'b h d -> b 1 h d')
else:
if not self.return_residual:
q = self.Wq(x)
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
kv = self.Wkv(x_kv if x_kv is not None else x)
else:
if x_kv is not None:
kv, x_kv = self.Wkv(x_kv)
else:
kv, x = self.Wkv(x)
q = self.Wq(x)
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
if self.dwconv:
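The cross-attention branch is where mixer_subset takes effect: only the selected slice of x goes through the query projection, while the key/value projection still sees the full sequence (x_kv if given, else x). A plain-PyTorch sketch of the resulting shapes, with stand-in Linear layers rather than the module's own Wq / Wkv:

import torch
import torch.nn as nn

b, s, d = 2, 197, 768
x = torch.randn(b, s, d)
Wq, Wkv = nn.Linear(d, d), nn.Linear(d, 2 * d)  # stand-ins for the Wq / Wkv projections
mixer_subset = slice(0, 1)
q = Wq(x[:, mixer_subset])  # (b, 1, d): queries only for the CLS token
kv = Wkv(x)                 # (b, s, 2*d): keys/values for the whole sequence
assert q.shape == (b, 1, d) and kv.shape == (b, s, 2 * d)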

View File

@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (without any optimizations enabled) matches the
"""Check that our implementation of OPT (without all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""

tests/models/test_vit.py (new file, 49 lines added)
View File

@ -0,0 +1,49 @@
import re
import torch
import pytest
from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
"""Check that our implementation of ViT matches the timm's implementation:
the output of our forward pass in fp16 should be around the same as
timm' forward pass in fp16, when compared to timm's forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
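# model_ref stays in fp32 and is the reference; model_timm runs in fp16 and sets the baseline numerical error.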
model.load_state_dict(model_ref.state_dict())
model.eval()
model_ref.eval()
model_timm.eval()
torch.manual_seed(0)
batch_size = 2
x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
out = model(x)
out_timm = model_timm(x)
out_ref = model_ref(x.float())
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
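# Our fp16 output should be at most ~3x further from the fp32 reference than timm's own fp16 output.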
assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()