[ViT] Support timm checkpoint, add tests

Author: Tri Dao · 2023-01-16 01:20:04 -08:00
commit 780e8eeabb · parent 2ec7d3f72c
5 changed files with 96 additions and 11 deletions

flash_attn/models/vit.py

@@ -1,9 +1,12 @@
 # Copyright (c) 2022, Tri Dao.
 # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 import math
+import re
 from functools import partial
 from copy import deepcopy
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -218,6 +221,7 @@ class VisionTransformer(nn.Module):
         hidden_states = self._pos_embed(x)
         residual = None
         if self.global_pool != 'token' or all_tokens:
+        # if True:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
         else:
@@ -225,10 +229,8 @@ class VisionTransformer(nn.Module):
                 hidden_states, residual = block(hidden_states, residual)
             # For the last layer, we only want the 1st token of the output. So we do cross-attention
             # where the query is the 1st token and the key/value is the whole sequence.
-            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
-            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
-            hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
-                                                      mixer_kwargs={'x_kv': hidden_states})
+            hidden_states, residual = self.blocks[-1](hidden_states, residual,
+                                                      mixer_subset=slice(0, 1))
         if not self.fused_dropout_add_ln:
             residual = self.drop_path(self.dropout(hidden_states)) + residual
             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
@@ -258,6 +260,29 @@ class VisionTransformer(nn.Module):
         x = self.forward_head(x)
         return x

+    def load_state_dict(self, state_dict, strict=True):
+        patch_embed_weight = state_dict['patch_embed.proj.weight']
+        if patch_embed_weight.dim() == 4:
+            # convert from Conv2d to Linear
+            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
+                                                              'o c h w -> o (c h w)')
+
+        def key_mapping_attn(key):
+            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
+            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
+            return key
+        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
+        n_layer = len(self.blocks)
+        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
+        if (self.blocks[-1].mixer.cross_attn
+                and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
+            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
+            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
+        return super().load_state_dict(state_dict, strict=strict)
+

 def init_weights_vit_timm(module: nn.Module, name: str = ''):
     """ ViT weight initialization, original timm impl (for reproducibility) """

flash_attn/modules/block.py

@@ -89,12 +89,15 @@ class Block(nn.Module):
                 p._shared_params = True

     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
-                mixer_kwargs=None):
+                mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.

         Args:
             hidden_states: the sequence to the encoder layer (required).
             residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
         """
         if self.prenorm:
             if not self.fused_dropout_add_ln:
@@ -116,8 +119,12 @@ class Block(nn.Module):
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
-            hidden_states = self.mixer(hidden_states,
-                                       **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if mixer_kwargs is None:
+                mixer_kwargs = {}
+            mixer_kwargs['mixer_subset'] = mixer_subset
+            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
+            if mixer_subset is not None:
+                residual = residual[:, mixer_subset]
             if not isinstance(self.mlp, nn.Identity):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
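The mixer_subset mechanics are easiest to see outside these classes: queries are
projected from a slice of the sequence while keys/values still cover every token,
so the attention output (and the sliced residual) has the subset's length. A toy
sketch in plain PyTorch/einops, independent of Block/MHA (all names here are
illustrative):

    import torch
    from einops import rearrange

    batch, seqlen, nheads, head_dim = 2, 197, 12, 64
    dim = nheads * head_dim
    x = torch.randn(batch, seqlen, dim)
    Wq, Wkv = torch.nn.Linear(dim, dim), torch.nn.Linear(dim, 2 * dim)

    subset = slice(0, 1)  # the CLS token, as in the ViT forward above
    q = rearrange(Wq(x[:, subset]), 'b s (h d) -> b h s d', d=head_dim)
    k, v = rearrange(Wkv(x), 'b s (two h d) -> two b h s d', two=2, d=head_dim)
    attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
    out = attn @ v  # (batch, nheads, 1, head_dim): only the CLS token is computed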

flash_attn/modules/mha.py

@@ -420,7 +420,7 @@ class MHA(nn.Module):
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                inference_params=None, **kwargs):
+                mixer_subset=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -433,6 +433,9 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
             inference_params: for generation. Adapted from Megatron-LM (and Apex)
             https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
@@ -454,6 +457,7 @@ class MHA(nn.Module):
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
+            assert x_kv is None and mixer_subset is None
             if not self.return_residual:
                 qkv = self.Wqkv(x)
             else:
@@ -491,14 +495,14 @@ class MHA(nn.Module):
                 context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                 kv = self.Wkv(x_kv if x_kv is not None else x)
             else:
                 if x_kv is not None:
                     kv, x_kv = self.Wkv(x_kv)
                 else:
                     kv, x = self.Wkv(x)
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
             q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
             kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
             if self.dwconv:
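Because the last layer now uses separate Wq/Wkv projections, load_state_dict in
vit.py splits timm's fused Wqkv for that layer. A quick sanity check of that
split, assuming the fused weight is laid out as [q; k; v] along the output
dimension as in timm (names here are illustrative):

    import torch

    embed_dim = 8
    x = torch.randn(4, embed_dim)
    Wqkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
    Wq = torch.nn.Linear(embed_dim, embed_dim)
    Wkv = torch.nn.Linear(embed_dim, 2 * embed_dim)
    with torch.no_grad():
        # same slicing as load_state_dict: first embed_dim rows -> Wq, rest -> Wkv
        Wq.weight.copy_(Wqkv.weight[:embed_dim]); Wq.bias.copy_(Wqkv.bias[:embed_dim])
        Wkv.weight.copy_(Wqkv.weight[embed_dim:]); Wkv.bias.copy_(Wqkv.bias[embed_dim:])
    # the split projections reproduce the fused one exactly
    q, kv = Wqkv(x).split([embed_dim, 2 * embed_dim], dim=-1)
    assert torch.allclose(Wq(x), q) and torch.allclose(Wkv(x), kv)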

tests/models/test_opt.py

@@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_optimized(model_name):
-    """Check that our implementation of OPT (without any optimizations enabled) matches the
+    """Check that our implementation of OPT (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """

tests/models/test_vit.py (new file, 49 lines)

@@ -0,0 +1,49 @@
+import re
+
+import torch
+import pytest
+
+from timm.models.vision_transformer import vit_base_patch16_224
+from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
+
+
+@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
+# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
+@pytest.mark.parametrize('optimized', [False, True])
+# @pytest.mark.parametrize('optimized', [True])
+def test_vit(optimized, fused_dense_gelu_dense):
+    """Check that our implementation of ViT matches timm's implementation:
+    the output of our forward pass in fp16 should be around the same as
+    timm's forward pass in fp16, when compared to timm's forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+
+    kwargs = {}
+    if optimized:
+        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
+    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
+    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
+
+    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
+    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
+
+    model.load_state_dict(model_ref.state_dict())
+    model.eval()
+    model_ref.eval()
+    model_timm.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
+    out = model(x)
+    out_timm = model_timm(x)
+    out_ref = model_ref(x.float())
+
+    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+    print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
+    print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
+
+    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
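Running the new test requires a CUDA GPU and timm; the optimized variants also
assume flash-attn's fused CUDA extensions (e.g. fused_dense_lib and
dropout_layer_norm) are built. One way to invoke it, via pytest's Python entry
point:

    import pytest
    # -s keeps the max/mean diff printouts visible
    pytest.main(['-s', 'tests/models/test_vit.py'])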