Import torch before flash_attn_2_cuda
commit d431f16751
parent 0e8c46ae08
@@ -1,8 +1,12 @@
-import flash_attn_2_cuda as flash_attn_cuda
 import torch
 import torch.nn as nn
 from einops import rearrange
 
+# isort: off
+# We need to import the CUDA kernels after importing torch
+import flash_attn_2_cuda as flash_attn_cuda
+# isort: on
+
 
 def _get_block_size(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
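Why the ordering matters: flash_attn_2_cuda is a compiled extension that links against PyTorch's shared libraries, so `import torch` must run first to load those libraries into the process; importing the extension before torch can fail with an undefined-symbol ImportError. The `# isort: off` / `# isort: on` markers keep the import sorter from moving the extension import back above torch. The sketch below is an illustration of the same torch-first pattern, not code from this commit; the try/except guard and its error message are assumptions added for clarity.

# Illustrative sketch, not part of the commit: the same torch-first
# ordering, plus a hypothetical guard that gives a clearer error when
# the compiled extension cannot be loaded.
import torch  # loads torch's shared libraries before the extension

# isort: off
# The sorter must not hoist this import back above torch.
try:
    import flash_attn_2_cuda as flash_attn_cuda
except ImportError as exc:  # hypothetical guard, not in the commit
    raise ImportError(
        "flash_attn_2_cuda failed to load; ensure torch is imported "
        "first and the extension was built against this torch version"
    ) from exc
# isort: on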