Import torch before flash_attn_2_cuda
This commit is contained in:
parent
0e8c46ae08
commit
d431f16751
@ -1,8 +1,12 @@
|
|||||||
import flash_attn_2_cuda as flash_attn_cuda
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
# isort: off
|
||||||
|
# The CUDA kernels must be imported after torch (presumably the compiled extension depends on shared libraries that torch loads first)
|
||||||
|
import flash_attn_2_cuda as flash_attn_cuda
|
||||||
|
# isort: on
|
||||||
|
|
||||||
|
|
||||||
def _get_block_size(device, head_dim, is_dropout, is_causal):
|
def _get_block_size(device, head_dim, is_dropout, is_causal):
|
||||||
# This should match the block sizes in the CUDA kernel
|
# This should match the block sizes in the CUDA kernel
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user