Skip to content

Commit

Permalink
minor change to make backtracking init and update return the same sta…
Browse files Browse the repository at this point in the history
…te types, this avoids recompilation when calling update
  • Loading branch information
bafflingbits committed Jan 15, 2025
1 parent 98c73c5 commit 7d090b7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optax/_src/linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ def init_fn(params: base.Params) -> ScaleByBacktrackingLinesearchState:
grad = None
return ScaleByBacktrackingLinesearchState(
learning_rate=jnp.array(1.0),
value=jnp.array(jnp.inf),
value=jnp.array(jnp.inf, dtype=params.dtype),
grad=grad,
info=BacktrackingLinesearchInfo(
num_linesearch_steps=0,
decrease_error=jnp.array(jnp.inf),
decrease_error=jnp.array(jnp.inf, dtype=params.dtype),
),
)

Expand Down

0 comments on commit 7d090b7

Please sign in to comment.