[ViT] Support timm checkpoint, add tests
parent 2ec7d3f72c
commit 780e8eeabb
@@ -1,9 +1,12 @@
 # Copyright (c) 2022, Tri Dao.
 # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 import math
+import re
 from functools import partial
 from copy import deepcopy
+
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -218,6 +221,7 @@ class VisionTransformer(nn.Module):
         hidden_states = self._pos_embed(x)
         residual = None
         if self.global_pool != 'token' or all_tokens:
+            # if True:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
         else:
@@ -225,10 +229,8 @@
                 hidden_states, residual = block(hidden_states, residual)
             # For the last layer, we only want the 1st token of the output. So we do cross-attention
             # where the query is the 1st token and the key/value is the whole sequence.
-            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
-            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
-            hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
-                                                      mixer_kwargs={'x_kv': hidden_states})
+            hidden_states, residual = self.blocks[-1](hidden_states, residual,
+                                                      mixer_subset=slice(0, 1))
         if not self.fused_dropout_add_ln:
             residual = self.drop_path(self.dropout(hidden_states)) + residual
             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
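
The comment above describes the trick the new call relies on: the last block still sees the full sequence, but queries are computed only for the CLS position, so the block's output has sequence length 1. A minimal shape sketch in plain PyTorch (toy ViT-B/16-like sizes; this is not the repo's attention code) of that single-query cross-attention:

import torch

batch, seqlen, dim, nheads = 2, 197, 768, 12
head_dim = dim // nheads
x = torch.randn(batch, seqlen, dim)
Wq = torch.nn.Linear(dim, dim)
Wkv = torch.nn.Linear(dim, 2 * dim)

q = Wq(x[:, slice(0, 1)])                       # query from the CLS token only: (batch, 1, dim)
kv = Wkv(x).reshape(batch, seqlen, 2, nheads, head_dim)
q = q.reshape(batch, 1, nheads, head_dim)
k, v = kv[:, :, 0], kv[:, :, 1]                 # keys/values still cover all 197 tokens
scores = torch.einsum('bshd,bthd->bhst', q, k) / head_dim ** 0.5
out = torch.einsum('bhst,bthd->bshd', scores.softmax(dim=-1), v)
print(out.shape)                                # torch.Size([2, 1, 12, 64]), one output token
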
@@ -258,6 +260,29 @@ class VisionTransformer(nn.Module):
         x = self.forward_head(x)
         return x
 
+    def load_state_dict(self, state_dict, strict=True):
+        patch_embed_weight = state_dict['patch_embed.proj.weight']
+        if patch_embed_weight.dim() == 4:
+            # convert from Conv2d to Linear
+            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
+                                                              'o c h w -> o (c h w)')
+        def key_mapping_attn(key):
+            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
+            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
+            return key
+        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+        n_layer = len(self.blocks)
+        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
+        if (self.blocks[-1].mixer.cross_attn
+            and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
+            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
+            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
+        return super().load_state_dict(state_dict, strict=strict)
+
 
 def init_weights_vit_timm(module: nn.Module, name: str = ''):
     """ ViT weight initialization, original timm impl (for reproducibility) """
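
A hedged, toy walk-through of the same three conversion steps this load_state_dict performs on a timm checkpoint (a made-up three-key state dict with embed_dim=8; the key names and the row-wise Wqkv split follow the diff, everything else is illustrative):

import re
from collections import OrderedDict

import torch
from einops import rearrange

embed_dim = 8  # toy size; the real ViT-B/16 uses 768
sd = OrderedDict({
    'patch_embed.proj.weight': torch.randn(embed_dim, 3, 16, 16),   # timm stores a Conv2d kernel
    'blocks.0.attn.qkv.weight': torch.randn(3 * embed_dim, embed_dim),
    'blocks.0.attn.proj.weight': torch.randn(embed_dim, embed_dim),
})

# 1) Conv2d patch embedding -> Linear over flattened patches
sd['patch_embed.proj.weight'] = rearrange(sd['patch_embed.proj.weight'], 'o c h w -> o (c h w)')

# 2) timm attention names -> flash_attn mixer names
def key_mapping_attn(key):
    key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
    key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
    return key
sd = OrderedDict((key_mapping_attn(k), v) for k, v in sd.items())

# 3) For a cross-attention last block, the fused Wqkv splits row-wise into Wq and Wkv
Wqkv = sd.pop('blocks.0.mixer.Wqkv.weight')
sd['blocks.0.mixer.Wq.weight'] = Wqkv[:embed_dim]    # first embed_dim rows: queries
sd['blocks.0.mixer.Wkv.weight'] = Wqkv[embed_dim:]   # remaining 2*embed_dim rows: keys/values
print(list(sd.keys()))
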
@@ -89,12 +89,15 @@ class Block(nn.Module):
                 p._shared_params = True
 
     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
-                mixer_kwargs=None):
+                mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
 
         Args:
             hidden_states: the sequence to the encoder layer (required).
             residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
         """
         if self.prenorm:
             if not self.fused_dropout_add_ln:
@@ -116,8 +119,12 @@ class Block(nn.Module):
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
-            hidden_states = self.mixer(hidden_states,
-                                       **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if mixer_kwargs is None:
+                mixer_kwargs = {}
+            mixer_kwargs['mixer_subset'] = mixer_subset
+            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
+            if mixer_subset is not None:
+                residual = residual[:, mixer_subset]
             if not isinstance(self.mlp, nn.Identity):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
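
A small sketch with toy tensors (not the Block class itself) of why the new branch also narrows the residual: once the mixer returns only the selected tokens, the running prenorm residual must be sliced to the same positions before the dropout-add that follows:

import torch

batch, seqlen, dim = 2, 197, 768
residual = torch.randn(batch, seqlen, dim)        # residual carried across the whole sequence so far
mixer_out = torch.randn(batch, 1, dim)            # mixer output after mixer_subset=slice(0, 1)

mixer_subset = slice(0, 1)
residual = residual[:, mixer_subset]              # (batch, 1, dim): keep only the CLS position
hidden_states = mixer_out + residual              # shapes now match for the residual add
print(hidden_states.shape)                        # torch.Size([2, 1, 768])
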
@@ -420,7 +420,7 @@ class MHA(nn.Module):
         return _update_kv_cache(kv, inference_params, self.layer_idx)
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                inference_params=None, **kwargs):
+                mixer_subset=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -433,6 +433,9 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
             inference_params: for generation. Adapted from Megatron-LM (and Apex)
                 https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
@@ -454,6 +457,7 @@ class MHA(nn.Module):
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
+            assert x_kv is None and mixer_subset is None
             if not self.return_residual:
                 qkv = self.Wqkv(x)
             else:
@@ -491,14 +495,14 @@ class MHA(nn.Module):
                 context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                 kv = self.Wkv(x_kv if x_kv is not None else x)
             else:
                 if x_kv is not None:
                     kv, x_kv = self.Wkv(x_kv)
                 else:
                     kv, x = self.Wkv(x)
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
             q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
             kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
             if self.dwconv:
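
Rough, hedged arithmetic for what the subset query buys in the last layer of a ViT-B/16 (197 tokens, hidden size 768): both the query projection and the attention matmuls shrink to a single query row.

seqlen, dim = 197, 768  # ViT-B/16: 196 patches + CLS, hidden size 768

wq_full = 2 * seqlen * dim * dim            # Wq applied to every token
wq_cls = 2 * 1 * dim * dim                  # Wq applied only to x[:, mixer_subset]
attn_full = 2 * 2 * seqlen * seqlen * dim   # QK^T plus attn @ V over all queries
attn_cls = 2 * 2 * 1 * seqlen * dim         # the same two matmuls with a single query

print(f'query projection: {wq_full / wq_cls:.0f}x fewer FLOPs')      # ~197x
print(f'attention matmuls: {attn_full / attn_cls:.0f}x fewer FLOPs')  # ~197x
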
@@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_optimized(model_name):
-    """Check that our implementation of OPT (without any optimizations enabled) matches the
+    """Check that our implementation of OPT (without all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
tests/models/test_vit.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+import re
+
+import torch
+import pytest
+
+from timm.models.vision_transformer import vit_base_patch16_224
+
+from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
+
+
+@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
+# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
+@pytest.mark.parametrize('optimized', [False, True])
+# @pytest.mark.parametrize('optimized', [True])
+def test_vit(optimized, fused_dense_gelu_dense):
+    """Check that our implementation of ViT matches the timm's implementation:
+    the output of our forward pass in fp16 should be around the same as
+    timm' forward pass in fp16, when compared to timm's forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+
+    kwargs = {}
+    if optimized:
+        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
+    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
+    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
+
+    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
+    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
+
+    model.load_state_dict(model_ref.state_dict())
+
+    model.eval()
+    model_ref.eval()
+    model_timm.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
+    out = model(x)
+    out_timm = model_timm(x)
+    out_ref = model_ref(x.float())
+
+    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+    print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
+    print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
+    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
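
Outside pytest, the same loading path the test exercises can be used directly. A hedged usage sketch (the constructor kwargs mirror the ones in the test above; a CUDA device and a downloadable timm checkpoint are assumed):

import torch
from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

# Build the flash_attn ViT-B/16 and pull in pretrained timm weights through the
# remapping load_state_dict added in this commit.
model = flash_vit_base_patch16_224(use_flash_attn=True, fused_bias_fc=True,
                                   fused_dropout_add_ln=True)
model.load_state_dict(vit_base_patch16_224(pretrained=True).state_dict())
model = model.to(device='cuda', dtype=torch.float16).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224, device='cuda', dtype=torch.float16)
    logits = model(x)
print(logits.shape)  # expected torch.Size([1, 1000]) for the ImageNet-1k head
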