diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 1f895d7..1782338 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -4,8 +4,6 @@ from typing import Tuple, Optional, Union import torch -from einops import rearrange - import triton import triton.language as tl