import re

import pytest
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224


@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
    """Check that our implementation of ViT matches timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm's forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"

    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs["fused_mlp"] = fused_mlp
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

    # timm references: an fp32 copy as ground truth and an fp16 copy for comparison.
    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

    # Load the pretrained timm weights into our model so the outputs are comparable.
    model.load_state_dict(model_ref.state_dict())
    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
    print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
    # Our fp16 error vs. the fp32 reference should stay within a small multiple of
    # timm's own fp16 error; the fused MLP perturbs the numerics more, hence the
    # looser multiplier.
    rtol = 2 if not fused_mlp else 8
    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()
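# Optional convenience entry point, a minimal sketch for running one configuration
# directly without pytest. It assumes a CUDA GPU is available and that timm and
# flash-attn (with its fused CUDA extensions built, since the optimized path uses
# fused_bias_fc / fused_dropout_add_ln / fused_mlp) are installed.
if __name__ == "__main__":
    test_vit(optimized=True, fused_mlp=False)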