From 7153673c1a3c7753c38e4c10ef2c98a02be5f778 Mon Sep 17 00:00:00 2001 From: Neil Tenenholtz Date: Fri, 15 Nov 2024 19:23:40 -0500 Subject: [PATCH] Fix swiglu backwards return type (#1337) --- flash_attn/ops/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/ops/activations.py b/flash_attn/ops/activations.py index b00063b..7c09649 100644 --- a/flash_attn/ops/activations.py +++ b/flash_attn/ops/activations.py @@ -110,7 +110,7 @@ template T swiglu_fwd(T x, T y) { } """ swiglu_bwd_codestring = """ -template T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { +template void swiglu_bwd(T x, T y, T g, T& dx, T& dy) { float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); dy = float(x) * x_sigmoid * float(g);