diff --git a/docs/flip/1009-optimizer-api.md b/docs/flip/1009-optimizer-api.md
index f658db5c..ec8d1573 100644
--- a/docs/flip/1009-optimizer-api.md
+++ b/docs/flip/1009-optimizer-api.md
@@ -105,7 +105,7 @@ def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn):
       params)
   updates, new_opt_state = tx_update_fn(grads, opt_state, params)
   new_params = optax.apply_updates(params, updates)
-  new_variables = {**variables, **new_model_state, 'params': params}
+  new_variables = {**variables, **new_model_state, 'params': new_params}
   return new_opt_state, new_variables, loss
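
For context, here is a minimal sketch of how the corrected `train_step` could fit together end to end. Only the lines shown in the hunk come from the doc; the loss function body, the cross-entropy loss, the variable splitting, and the assumed `apply_fn` signature (a Flax-style apply returning `(logits, mutated_collections)`) are illustrative assumptions, not the FLIP's actual code.

```python
import jax
import optax


def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn):
  params = variables['params']
  other_variables = {k: v for k, v in variables.items() if k != 'params'}

  def loss_fn(params):
    # Assumed shape of apply_fn: a Flax Module.apply that returns the logits
    # plus the mutated non-parameter collections (e.g. batch statistics).
    logits, new_model_state = apply_fn(
        {**other_variables, 'params': params}, inputs,
        mutable=['batch_stats'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits, labels).mean()
    return loss, new_model_state

  (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      params)
  updates, new_opt_state = tx_update_fn(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  # Store the *updated* parameters; writing back `params` here (the bug the
  # diff fixes) would silently discard the optimizer step.
  new_variables = {**variables, **new_model_state, 'params': new_params}
  return new_opt_state, new_variables, loss
```

The point of the fix is the last dictionary merge: because `'params'` is written last, the old `params` would have overridden the freshly computed `new_params`, so the model would never train even though the optimizer state advanced.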