We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 87f439d + c4081d8 commit 20f167dCopy full SHA for 20f167d
jax/nn/initializers.py
@@ -94,5 +94,5 @@ def init(key, shape, dtype=np.float32):
94
if n_rows < n_cols: Q = Q.T
95
Q = np.reshape(Q, np.delete(shape, column_axis) + (shape[column_axis],))
96
Q = np.moveaxis(Q, -1, column_axis)
97
- return stddev * Q
+ return scale * Q
98
return init
0 commit comments