From 780e8eeabb84fe3f41e8244f04521743b032ba35 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 16 Jan 2023 01:20:04 -0800 Subject: [PATCH] [ViT] Support timm checkpoint, add tests --- flash_attn/models/vit.py | 33 ++++++++++++++++++++++--- flash_attn/modules/block.py | 13 +++++++--- flash_attn/modules/mha.py | 10 +++++--- tests/models/test_opt.py | 2 +- tests/models/test_vit.py | 49 +++++++++++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 11 deletions(-) create mode 100644 tests/models/test_vit.py diff --git a/flash_attn/models/vit.py b/flash_attn/models/vit.py index 41765c0..646c5e7 100644 --- a/flash_attn/models/vit.py +++ b/flash_attn/models/vit.py @@ -1,9 +1,12 @@ # Copyright (c) 2022, Tri Dao. # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py import math +import re from functools import partial from copy import deepcopy +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F @@ -218,6 +221,7 @@ class VisionTransformer(nn.Module): hidden_states = self._pos_embed(x) residual = None if self.global_pool != 'token' or all_tokens: + # if True: for block in self.blocks: hidden_states, residual = block(hidden_states, residual) else: @@ -225,10 +229,8 @@ class VisionTransformer(nn.Module): hidden_states, residual = block(hidden_states, residual) # For the last layer, we only want the 1st token of the output. So we do cross-attention # where the query is the 1st token and the key/value is the whole sequence. - hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d') - residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d') - hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st, - mixer_kwargs={'x_kv': hidden_states}) + hidden_states, residual = self.blocks[-1](hidden_states, residual, + mixer_subset=slice(0, 1)) if not self.fused_dropout_add_ln: residual = self.drop_path(self.dropout(hidden_states)) + residual hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) @@ -258,6 +260,29 @@ class VisionTransformer(nn.Module): x = self.forward_head(x) return x + def load_state_dict(self, state_dict, strict=True): + patch_embed_weight = state_dict['patch_embed.proj.weight'] + if patch_embed_weight.dim() == 4: + # convert from Conv2d to Linear + state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight, + 'o c h w -> o (c h w)') + def key_mapping_attn(key): + key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key) + key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + n_layer = len(self.blocks) + # Convert from Wqkv to Wq and Wkv for cross attention (last layer) + if (self.blocks[-1].mixer.cross_attn + and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict): + Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight') + bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias') + state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim] + state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:] + state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim] + state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:] + return super().load_state_dict(state_dict, strict=strict) + def init_weights_vit_timm(module: nn.Module, name: str = ''): """ ViT weight initialization, original timm impl (for reproducibility) """ diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index 7911bbb..d692702 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -89,12 +89,15 @@ class Block(nn.Module): p._shared_params = True def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, - mixer_kwargs=None): + mixer_subset=None, mixer_kwargs=None): r"""Pass the input through the encoder layer. Args: hidden_states: the sequence to the encoder layer (required). residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. """ if self.prenorm: if not self.fused_dropout_add_ln: @@ -116,8 +119,12 @@ class Block(nn.Module): self.dropout1.p if self.training else 0.0, self.norm1.eps, rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32 ) - hidden_states = self.mixer(hidden_states, - **(mixer_kwargs if mixer_kwargs is not None else {})) + if mixer_kwargs is None: + mixer_kwargs = {} + mixer_kwargs['mixer_subset'] = mixer_subset + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + if mixer_subset is not None: + residual = residual[:, mixer_subset] if not isinstance(self.mlp, nn.Identity): if not self.fused_dropout_add_ln: dropped = self.drop_path2(self.dropout2(hidden_states)) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 4998d93..4eb5aaf 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -420,7 +420,7 @@ class MHA(nn.Module): return _update_kv_cache(kv, inference_params, self.layer_idx) def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, - inference_params=None, **kwargs): + mixer_subset=None, inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if @@ -433,6 +433,9 @@ class MHA(nn.Module): max_seqlen: int. Maximum sequence length in the batch. key_padding_mask: boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. inference_params: for generation. Adapted from Megatron-LM (and Apex) https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 """ @@ -454,6 +457,7 @@ class MHA(nn.Module): kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs} if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs}) if not self.cross_attn: + assert x_kv is None and mixer_subset is None if not self.return_residual: qkv = self.Wqkv(x) else: @@ -491,14 +495,14 @@ class MHA(nn.Module): context = rearrange(context, 'b h d -> b 1 h d') else: if not self.return_residual: - q = self.Wq(x) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) kv = self.Wkv(x_kv if x_kv is not None else x) else: if x_kv is not None: kv, x_kv = self.Wkv(x_kv) else: kv, x = self.Wkv(x) - q = self.Wq(x) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim) if self.dwconv: diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py index b0fc4f2..5b5529b 100644 --- a/tests/models/test_opt.py +++ b/tests/models/test_opt.py @@ -26,7 +26,7 @@ def test_opt_state_dict(model_name): @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"]) # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"]) def test_opt_optimized(model_name): - """Check that our implementation of OPT (without any optimizations enabled) matches the + """Check that our implementation of OPT (without all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. """ diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py new file mode 100644 index 0000000..7452721 --- /dev/null +++ b/tests/models/test_vit.py @@ -0,0 +1,49 @@ +import re + +import torch +import pytest + +from timm.models.vision_transformer import vit_base_patch16_224 + +from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224 + + +@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True]) +# @pytest.mark.parametrize('fused_dense_gelu_dense', [False]) +@pytest.mark.parametrize('optimized', [False, True]) +# @pytest.mark.parametrize('optimized', [True]) +def test_vit(optimized, fused_dense_gelu_dense): + """Check that our implementation of ViT matches the timm's implementation: + the output of our forward pass in fp16 should be around the same as + timm' forward pass in fp16, when compared to timm's forward pass in fp32. + """ + dtype = torch.float16 + device = 'cuda' + + kwargs = {} + if optimized: + kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) + kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense + model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype) + + model_ref = vit_base_patch16_224(pretrained=True).to(device=device) + model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype) + + model.load_state_dict(model_ref.state_dict()) + + model.eval() + model_ref.eval() + model_timm.eval() + + torch.manual_seed(0) + batch_size = 2 + x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype) + out = model(x) + out_timm = model_timm(x) + out_ref = model_ref(x.float()) + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') + print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') + assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()