Import torch before flash_attn_2_cuda

Tri Dao 2023-08-19 21:07:33 -07:00
parent 0e8c46ae08
commit d431f16751


@@ -1,8 +1,12 @@
-import flash_attn_2_cuda as flash_attn_cuda
 import torch
 import torch.nn as nn
 from einops import rearrange
+
+# isort: off
+# We need to import the CUDA kernels after importing torch
+import flash_attn_2_cuda as flash_attn_cuda
+# isort: on
 
 
 def _get_block_size(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
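
Why the order matters: a compiled extension such as flash_attn_2_cuda links against PyTorch's shared libraries (libtorch, libc10, and the CUDA runtime), so torch must be imported first to load those libraries into the process before the extension's shared object is opened; importing the extension first can fail with undefined-symbol errors. The # isort: off / # isort: on action comments keep isort from alphabetizing the extension import back above torch. A minimal sketch of the same pattern, using a hypothetical extension module my_ext_cuda (not part of this repo):

# isort: off
# Importing torch first loads libtorch/libc10 (and the CUDA runtime)
# into the process, so the extension's symbols can resolve when its
# shared object is loaded.
import torch

# my_ext_cuda is a hypothetical compiled extension linking against
# libtorch; importing it before torch could raise an ImportError
# about undefined symbols.
import my_ext_cuda
# isort: on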