From a81900d4c1196ac66cb5e64ff0ecc8c59bee84b6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 17 Aug 2023 17:25:34 -0700 Subject: [PATCH] [ViT] Minor fix so it runs --- flash_attn/models/vit.py | 2 +- tests/models/test_vit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_attn/models/vit.py b/flash_attn/models/vit.py index 312f5e4..7135ae8 100644 --- a/flash_attn/models/vit.py +++ b/flash_attn/models/vit.py @@ -31,7 +31,7 @@ except ImportError: def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False): - mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias, + mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, qkv_proj_bias=qkv_bias, dropout=attn_drop, fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn) return mixer_cls diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py index b1cb959..9898106 100644 --- a/tests/models/test_vit.py +++ b/tests/models/test_vit.py @@ -46,5 +46,5 @@ def test_vit(optimized, fused_mlp): print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}') print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}') - rtol = 2 if not fused_mlp else 4 + rtol = 2 if not fused_mlp else 8 assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()