Fix import in test

This commit is contained in:
Tri Dao 2024-08-01 02:14:25 -07:00
parent bafe253042
commit c33de664a1

View File

@ -6,8 +6,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
# from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
from test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
ABS_TOL = 5e-3
REL_TOL = 1e-1