Skip to content

Commit c4081d8

Browse files
authored
fix typo in orthogonal init
1 parent 87f439d commit c4081d8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/nn/initializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,5 @@ def init(key, shape, dtype=np.float32):
9494
if n_rows < n_cols: Q = Q.T
9595
Q = np.reshape(Q, np.delete(shape, column_axis) + (shape[column_axis],))
9696
Q = np.moveaxis(Q, -1, column_axis)
97-
return stddev * Q
97+
return scale * Q
9898
return init

0 commit comments

Comments
 (0)