import re

import torch
import pytest

from timm.models.vision_transformer import vit_base_patch16_224

from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224


@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
    """Check that our implementation of ViT matches timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm's forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
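
    # The optimized path turns on FlashAttention plus the fused bias-FC and fused
    # dropout + add + LayerNorm kernels; the fused dense-GELU-dense MLP kernel is
    # controlled by its own parametrized flag.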
    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
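
    # Pretrained timm reference in fp32, plus the same timm model in fp16 to gauge
    # how much error comes from fp16 alone.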
    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)
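
    # The pretrained timm weights are loaded directly, so all three models share
    # the same parameters.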
    model.load_state_dict(model_ref.state_dict())

    model.eval()
    model_ref.eval()
    model_timm.eval()
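
    # One random fp16 batch for all models; the fp32 reference sees the same input
    # upcast to fp32.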
    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
    print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
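
    # Our fp16 output should be about as close to the fp32 reference as timm's own
    # fp16 output, within a 3x factor on the max absolute difference.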
    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()