Fix swiglu backwards return type (#1337)
This commit is contained in:
parent
641db759ab
commit
7153673c1a
@ -110,7 +110,7 @@ template <typename T> T swiglu_fwd(T x, T y) {
|
||||
}
|
||||
"""
|
||||
swiglu_bwd_codestring = """
|
||||
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
|
||||
template <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
|
||||
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
||||
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
|
||||
dy = float(x) * x_sigmoid * float(g);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user