Skip to content

Commit 20f167d

Browse files
authored
Merge pull request jax-ml#1323 from google/ortho-typo
fix typo in orthogonal init
2 parents 87f439d + c4081d8 commit 20f167d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/nn/initializers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,5 @@ def init(key, shape, dtype=np.float32):
9494
if n_rows < n_cols: Q = Q.T
9595
Q = np.reshape(Q, np.delete(shape, column_axis) + (shape[column_axis],))
9696
Q = np.moveaxis(Q, -1, column_axis)
97-
return stddev * Q
97+
return scale * Q
9898
return init

0 commit comments

Comments
 (0)