From d431f16751bf42033a67f7c98251f70d225ab62f Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 19 Aug 2023 21:07:33 -0700 Subject: [PATCH] Import torch before flash_attn_2_cuda --- flash_attn/flash_attn_interface.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 2124e6b..ffc6440 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1,8 +1,12 @@ -import flash_attn_2_cuda as flash_attn_cuda import torch import torch.nn as nn from einops import rearrange +# isort: off +# We need to import the CUDA kernels after importing torch +import flash_attn_2_cuda as flash_attn_cuda +# isort: on + def _get_block_size(device, head_dim, is_dropout, is_causal): # This should match the block sizes in the CUDA kernel