diff --git a/README.md b/README.md index aaab7e1..53c9fd3 100644 --- a/README.md +++ b/README.md @@ -400,12 +400,13 @@ If you use this codebase, or otherwise found our work valuable, please cite: @inproceedings{dao2022flashattention, title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - booktitle={Advances in Neural Information Processing Systems}, + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, year={2022} } -@article{dao2023flashattention2, +@inproceedings{dao2023flashattention2, title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, author={Dao, Tri}, - year={2023} + booktitle={International Conference on Learning Representations (ICLR)}, + year={2024} } ``` diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index d5d1139..0d9120c 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -12,7 +12,12 @@ import torch.nn.functional as F from einops import rearrange, repeat from torch import Tensor from torch.profiler import ProfilerActivity, profile, record_function -from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput + +try: + from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput +except ImportError: + GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"]) + SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"]) @dataclass