Skip to content

Commit

Permalink
Making pmap axis names consistent in examples code to support things …
Browse files Browse the repository at this point in the history
…like cross-replica batch norm layers.

PiperOrigin-RevId: 700687393
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 27, 2024
1 parent 5e135a6 commit 4de99f5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(
),
)

self.params_init = jax.pmap(init_parameters_func)
self.params_init = jax.pmap(init_parameters_func, axis_name="kfac_axis")
self.model_loss_func = model_loss_func
self.model_func_for_estimator = model_func_for_estimator

Expand All @@ -223,10 +223,10 @@ def __init__(
)

self.train_batch_pmap = jax.pmap(
self._train_batch, axis_name="train_axis"
self._train_batch, axis_name="kfac_axis"
)
self.eval_batch_pmap = jax.pmap(
self._eval_batch, axis_name="eval_axis"
self._eval_batch, axis_name="kfac_axis"
)

# Log some useful information
Expand Down

0 comments on commit 4de99f5

Please sign in to comment.