diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 2124e6b..ffc6440 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -1,8 +1,12 @@
-import flash_attn_2_cuda as flash_attn_cuda
 import torch
 import torch.nn as nn
 from einops import rearrange
 
+# isort: off
+# We need to import the CUDA kernels after importing torch
+import flash_attn_2_cuda as flash_attn_cuda
+# isort: on
+
 
 def _get_block_size(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel