From c33de664a105533853cfe807f2caa50a05dd46e8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 1 Aug 2024 02:14:25 -0700 Subject: [PATCH] Fix import in test --- hopper/test_flash_attn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 5a3f590..8c90988 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -6,8 +6,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from flash_attn_interface import flash_attn_func, flash_attn_varlen_func -# from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref -from test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref +from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref ABS_TOL = 5e-3 REL_TOL = 1e-1