[Bugfix] Fix logit soft cap in flash-attn backend (#7425)

Author: Woosuk Kwon, 2024-08-12 09:58:28 -07:00, committed by GitHub
parent d2bc4510a4
commit cfba4def5d

@@ -563,6 +563,7 @@ class FlashAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
+                softcap=self.logits_soft_cap,
             ).squeeze(1)
         # Reshape the output tensor.
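
For context, flash-attn's softcap argument applies tanh-based soft capping to the attention logits before the softmax; the bug was that the backend's configured logits_soft_cap was never forwarded in this call, so the cap was silently ignored. Below is a minimal plain-PyTorch sketch of the capping operation itself, not vLLM code; the helper name and toy shapes are illustrative.

import torch

def soft_cap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Bound attention logits smoothly to (-softcap, softcap) via tanh.
    # A value of 0.0 is treated as "capping disabled", mirroring flash-attn's convention.
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)

# Toy usage: scores = (Q @ K^T) * scale, capped before the softmax.
q = torch.randn(1, 4, 8)   # (batch, q_len, head_dim)
k = torch.randn(1, 6, 8)   # (batch, kv_len, head_dim)
scale = 8 ** -0.5
scores = torch.matmul(q, k.transpose(-1, -2)) * scale
attn = torch.softmax(soft_cap_scores(scores, softcap=30.0), dim=-1)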