Fix swiglu backwards return type (#1337)

This commit is contained in:
Neil Tenenholtz 2024-11-15 19:23:40 -05:00 committed by GitHub
parent 641db759ab
commit 7153673c1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -110,7 +110,7 @@ template <typename T> T swiglu_fwd(T x, T y) {
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
template <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
dy = float(x) * x_sigmoid * float(g);