Show one progress bar per chain when sampling #7634

Open
wants to merge 9 commits into
base: main

Conversation

jessegrabowski
Member

@jessegrabowski jessegrabowski commented Jan 6, 2025

Description

I really like what nutpie gives you while sampling, so I tried to make something using rich that copies it. Example:

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-07.10-57-56.mp4

Features are:

  1. One progress bar per chain
  2. Sampling statistics per chain. I copied nutpie, but we can haggle over what these should be (or give the user more control)
  3. Color change based on status. Blue when sampling, turns red after a divergence. Finished bar is either green (no divergences) or purple (with divergences).
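As a rough illustration of item 3, the color rule could be written as a small pure function. The function name `bar_color` and the plain color strings are assumptions made for this sketch, not code from the PR.

```python
def bar_color(finished: bool, divergences: int) -> str:
    """Map a chain's state to a bar color, following the rule in item 3.

    Hypothetical helper: the name and signature are illustrative only.
    """
    if finished:
        # Finished bars: green when clean, purple when any divergence occurred.
        return "purple" if divergences > 0 else "green"
    # Running bars: blue while clean, red once a divergence has been seen.
    return "red" if divergences > 0 else "blue"
```

Keeping the rule in one function like this would make it easy to swap the palette later without touching the display code.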

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7634.org.readthedocs.build/en/7634/

@jessegrabowski jessegrabowski changed the title Show one progress bars per chain when sampling Show one progress bar per chain when sampling Jan 6, 2025
Member

@ricardoV94 ricardoV94 left a comment

What does this look like when 1) you have another step sampler in the mix, 2) there's no NUTS at all, or 3) there is more than one NUTS step sampler?

@@ -1,4 +1,4 @@
-# Copyright 2024 The PyMC Developers
+# Copyright 2025 The PyMC Developers
Member

According to Oriol this should just be 20xx-present

@ricardoV94
Member

Doesn't need to be this PR, but it would be nice to show a relevant statistic for each sampler (or at least when a single non-NUTS sampler is being used).

Conversely, we should not show these columns when there's no NUTS, as they give a false sense that everything is going great.

@jessegrabowski
Member Author

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-07.22-09-52.mp4

Here's a comparison between a NUTS and a non-NUTS sampler.

Ideally we'd add a method to the step samplers themselves that would return the rich columns that sampler wants to use, then we just gather them and display. In that case you could even show different sampler stats from different steps in the same run. Maybe it's worth doing. The actual code for this PR is pretty gnarly.

@jessegrabowski
Member Author

I moved the responsibility for setting up the progressbars and updating stats to the step samplers. This means each step method can choose what stats are to be shown on the progress bars, and we can also combine them. Example vid attached.

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-08.18-16-37.mp4

This is a pretty big scope creep for this PR, so I'm not against reverting these changes and going with something more basic. If we like it though I can lean into it.

I will say it's broken right now, because when you have e.g. multiple Metropolis steps (one per variable), the only stats that get reported are those of the last one. It needs some logic for aggregating the stats across samplers that report the same stats.

@ricardoV94
Member

The step sampler specifics look amazing 😍 Gonna give it a try today.

I'll test it but I assume things behave gracefully if the step samplers don't specify the display columns info?

@jessegrabowski
Member Author

jessegrabowski commented Jan 8, 2025

No it will break. I need to put in a default for the base class. It just needs to return empty stuff.
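A minimal sketch of what such a base-class default could look like. Every name here (`StepMethod`, `progress_columns`, `NUTSLike`, `GibbsLike`, `gather_columns`, and the column labels) is hypothetical and chosen for illustration; none of it is taken from the PR or from PyMC's API.

```python
class StepMethod:
    """Hypothetical stand-in for a step-method base class."""

    def progress_columns(self) -> list[str]:
        # Default: contribute no extra columns to the progress bar.
        return []


class NUTSLike(StepMethod):
    """Hypothetical sampler that opts in to progress-bar statistics."""

    def progress_columns(self) -> list[str]:
        # Illustrative column labels only.
        return ["Divergences", "Step size", "Grad evals"]


class GibbsLike(StepMethod):
    """Hypothetical sampler that relies on the base-class default."""


def gather_columns(steps: list[StepMethod]) -> list[str]:
    """Collect columns from every step, in first-seen order, without duplicates."""
    columns: list[str] = []
    for step in steps:
        for name in step.progress_columns():
            if name not in columns:
                columns.append(name)
    return columns
```

With a default like this, a step that defines nothing simply adds no columns instead of raising.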

@ricardoV94
Member

I would still like to see the global runtime and ETA like we had before. Is that feasible or too ugly?

Re: repeated samplers, show the mean? Or maybe only display specialized info when a single step sampler is being used?
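One way the "show the mean" option could work, sketched with plain dictionaries. The function name `aggregate_stats` and the stat keys used in the usage note are assumptions for illustration, not the PR's implementation.

```python
from collections import defaultdict
from statistics import fmean


def aggregate_stats(per_step_stats: list[dict[str, float]]) -> dict[str, float]:
    """Average each statistic over all step samplers that report it.

    Hypothetical helper: each dict holds the latest stats from one step sampler.
    """
    grouped: dict[str, list[float]] = defaultdict(list)
    for stats in per_step_stats:
        for name, value in stats.items():
            grouped[name].append(value)
    # A stat reported by a single sampler is passed through unchanged.
    return {name: fmean(values) for name, values in grouped.items()}
```

For example, `aggregate_stats([{"accept_rate": 0.2}, {"accept_rate": 0.4}])` would give one averaged `accept_rate` entry rather than only the last sampler's value.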

@jessegrabowski
Member Author

jessegrabowski commented Jan 8, 2025

I added the base implementation, so things degrade gracefully when a step sampler doesn't define its own. This would only show the NUTS stats, for example, because there's no implementation for BinaryGibbsMetropolis:

import pymc as pm

with pm.Model() as m:
    x = pm.Bernoulli('x', p=0.5)
    y = pm.Normal('y', mu=pm.math.switch(x, -3, 3), sigma=10, shape=(10,))
    
    idata = pm.sample(
        step=[pm.BinaryGibbsMetropolis(x), pm.NUTS(y)],
        tune=2000,
        draws=2000,
        chains=8,
        cores=8,
        compile_kwargs={'mode': 'NUMBA'},
    )

Re: global, yes we can keep it. But we can't have it as a single long bar that breaks the columns, because there's no colspan operator for rich tables (see Textualize/rich#164).

We could make a separate table though. It just won't be as pretty as nutpie.

I was thinking about the mean as well. If it only shows up when there's a single step sampler it would be pretty rare that anyone would use it, because the non-NUTS samplers pretty much always show up as one per variable.

We might also need some priority logic to decide what to show if too many stats get involved. You can see that just NUTS + Metropolis already breaks the table. We could assign a LOW/MEDIUM/HIGH priority to each stat and display at most 5 at any time?
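A possible shape for that priority logic, as a sketch only. The `Priority` enum, the `select_stats` function, and the cap of 5 as a default argument are all assumptions for illustration.

```python
from enum import IntEnum


class Priority(IntEnum):
    """Hypothetical display priority for a progress-bar statistic."""

    LOW = 0
    MEDIUM = 1
    HIGH = 2


def select_stats(candidates: list[tuple[str, Priority]], max_shown: int = 5) -> list[str]:
    """Keep at most `max_shown` stats, preferring higher priority.

    Ties are broken by original position, and the survivors are returned
    in their original order so the table layout stays stable.
    """
    ranked = sorted(enumerate(candidates), key=lambda item: (-item[1][1], item[0]))
    kept = sorted(ranked[:max_shown], key=lambda item: item[0])
    return [name for _, (name, _) in kept]
```

Returning the survivors in their original order means columns would not jump around when a low-priority stat is dropped.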

@ricardoV94
Member

Sequential sampling (cores=1) still has the old approach. It has one bar per chain but not the stats

@ricardoV94
Member

Re: global, yes we can keep it. But we can't have it as a single long bar that breaks the columns, because there's no colspan operator for rich tables (see Textualize/rich#164).

What if we show as a column per chain then? elapsed/left?
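For reference, rich's progress module ships `TimeElapsedColumn` and `TimeRemainingColumn`, which could serve as per-chain columns. The arithmetic behind a simple elapsed/left cell can be sketched as below; the helper names and the M:SS format are assumptions for illustration, and rich's own estimate may be computed differently.

```python
from typing import Optional


def estimate_remaining(elapsed_s: float, done: int, total: int) -> Optional[float]:
    """Estimate seconds left for one chain from its average speed so far."""
    if done <= 0:
        return None  # No draws yet, so there is no speed to extrapolate from.
    return elapsed_s * (total - done) / done


def format_timing(elapsed_s: float, remaining_s: Optional[float]) -> str:
    """Render an 'elapsed / left' cell as M:SS, with a placeholder when unknown."""

    def mmss(seconds: float) -> str:
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes}:{secs:02d}"

    left = "-:--" if remaining_s is None else mmss(remaining_s)
    return f"{mmss(elapsed_s)} / {left}"
```

Because each chain gets its own cell, no row has to span the whole table.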

@jessegrabowski
Member Author

New version with timing info:

test_lr_scheduler.-.Jupyter.Notebook.Mozilla.Firefox.2025-01-12.19-44-52.mp4

@ricardoV94
Member

Sequential sampling (cores=1) still has the old approach. It has one bar per chain but not the stats

Did you address this?

@jessegrabowski
Member Author

Not yet, but it will be an easy fix.

@ricardoV94
Member

There are some failing tests as well. Otherwise I'm happy with the changes. I'll post it in the Discord to see if anybody has big complaints.

@aloctavodia
Member

This looks really nice and modern, and it's very informative, but do we have an option for a single progress bar with less information?

@jessegrabowski
Member Author

We can do that painlessly, yeah.

@twiecki
Member

twiecki commented Jan 15, 2025

This looks great. I assume blue means tuning and red means post-tuning? If so, I wonder if red is the best color choice, as it suggests something has gone wrong. Maybe replace red with green? Or make tuning red and sampling blue?

@ricardoV94
Member

This looks great. I assume blue means tuning and red means post-tuning? If so, I wonder if red is the best color choice, as it suggests something has gone wrong. Maybe replace red with green? Or make tuning red and sampling blue?

It turns red if there's any divergence

@twiecki
Member

twiecki commented Jan 15, 2025

I see, maybe then green post-tuning without divergences? Or maybe a colorblind-friendly color.

@aloctavodia
Member

If you want some colorblind-friendly palettes: https://github.com/arviz-devs/arviz-plots/tree/main/src/arviz_plots/styles

@fonnesbeck
Member

fonnesbeck commented Jan 15, 2025

I like it a lot.

I don't know if we need different colors for pre-/post-tuning.

Can we get red for any warning, not just divergences (so that it makes users read the warning)?

Definitely no green if we are using red. I like blue/red for clean/warning.

@ricardoV94
Member

I think 2 colors is enough. It will be clear when you use it, compared to a GIF. Otherwise a single color.

@fonnesbeck what sort of warnings are you thinking about? Are they emitted during sampling or only at the end?

@fonnesbeck
Member

Things like max tree depth, missing the target acceptance rate and r-hat. Those that are calculated at the end could flip the line to red upon completion.

Not a big deal, though. What's here is a big improvement!

5 participants