Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def model(X, y):
drawing method, hence allowing us to collect samples in parallel on a single device.
:param bool progress_bar: Whether to enable progress bar updates. Defaults to
``True``.
:param int print_rate: Number of iterations per progress bar update. Defaults to None, which is
5% of total iterations when there are more than 20 iterations, otherwise every iteration.
:param bool jit_model_args: If set to `True`, this will compile the potential energy
computation as a function of model arguments. As such, calling `MCMC.run` again
on a same sized but different dataset will not result in additional compilation cost.
Expand Down Expand Up @@ -331,6 +333,7 @@ def __init__(
postprocess_fn=None,
chain_method="parallel",
progress_bar=True,
print_rate=None,
jit_model_args=False,
):
self.sampler = sampler
Expand Down Expand Up @@ -374,6 +377,7 @@ def __init__(
self.progress_bar = progress_bar
if "CI" in os.environ or "PYTEST_XDIST_WORKER" in os.environ:
self.progress_bar = False
self.print_rate = print_rate
self._jit_model_args = jit_model_args
self._states = None
self._states_flat = None
Expand Down Expand Up @@ -497,6 +501,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
postprocess_fn, collect_fields, remove_sites
),
progbar=self.progress_bar,
print_rate=self.print_rate,
return_last_val=True,
thinning=self.thinning,
collection_size=collection_size,
Expand Down
21 changes: 13 additions & 8 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,13 @@ def _wrapped(fn):
return _wrapped


def progress_bar_factory(num_samples: int, num_chains: int) -> Callable:
def progress_bar_factory(
num_samples: int, num_chains: int, print_rate: int
) -> Callable:
"""Factory that builds a progress bar decorator along
with the `set_tqdm_description` and `close_tqdm` functions
"""

if num_samples > 20:
print_rate = int(num_samples / 20)
else:
print_rate = 1

remainder = num_samples % print_rate

idx_counter = 0 # resource counter to assign chains to progress bars
Expand Down Expand Up @@ -303,6 +300,7 @@ def fori_collect(
init_val: Any,
transform: Callable = identity,
progbar: bool = True,
print_rate: None | int = None,
return_last_val: bool = False,
collection_size=None,
thinning=1,
Expand All @@ -325,6 +323,7 @@ def fori_collect(
be any Python collection type containing `np.ndarray` objects.
:param transform: a callable to post-process the values returned by `body_fn`.
:param progbar: whether to post progress bar updates.
:param print_rate: number of iterations per progress bar update.
:param bool return_last_val: If `True`, the last value is also returned.
This has the same type as `init_val`.
:param thinning: Positive integer that controls the thinning ratio for retained
Expand All @@ -351,6 +350,12 @@ def fori_collect(
start_idx = lower + (upper - lower) % thinning
num_chains = progbar_opts.pop("num_chains", 1)

if print_rate is None:
if upper > 20:
print_rate = int(upper / 20)
else:
print_rate = 1

@partial(maybe_jit, donate_argnums=2)
@cached_by(fori_collect, body_fun, transform)
def _body_fn(i, val, collection, start_idx, thinning):
Expand Down Expand Up @@ -391,7 +396,7 @@ def loop_fn(collection):
last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)

elif num_chains > 1:
progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
progress_bar_fori_loop = progress_bar_factory(upper, num_chains, print_rate)
_body_fn_pbar = progress_bar_fori_loop(lambda i, vals: _body_fn(i, *vals))

def loop_fn(collection):
Expand All @@ -416,7 +421,7 @@ def loop_fn(collection):
_, collection, _, _ = _body_fn(-1, val, collection, start_idx, thinning)
vals = (val, collection, start_idx, thinning)
else:
with tqdm.trange(upper) as t:
with tqdm.trange(upper, miniters=print_rate) as t:
for i in t:
vals = _body_fn(i, *vals)

Expand Down