-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Don't use deprecated batched_dot #7951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Don't use deprecated batched_dot #7951
Conversation
- Fixes Issue pymc-devs#7878 - Replace pt.batched_dot(sqrt_quad.T, sqrt_quad.T) with pt.sum(sqrt_quad.T ** 2, axis=-1) - Computes squared norm per sample using modern PyTensor operations - Eliminates deprecation warnings and ensures future compatibility
|
pre-commit.ci autofix |
for more information, see https://pre-commit.ci
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #7951 +/- ##
==========================================
+ Coverage 90.62% 91.48% +0.86%
==========================================
Files 116 116
Lines 18947 18947
==========================================
+ Hits 17170 17333 +163
+ Misses 1777 1614 -163
🚀 New features to boost your workflow:
|
| # Square each sample | ||
| quad = pt.batched_dot(sqrt_quad.T, sqrt_quad.T) | ||
| # Square each sample - compute squared norm for each sample | ||
| quad = pt.sum(sqrt_quad.T**2, axis=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| quad = pt.sum(sqrt_quad.T**2, axis=-1) | |
| quad = pt.sum(sqrt_quad**2, axis=-2) |
More explicit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment is also incorrect, it's not a "squared norm".
Description
This PR replaces the deprecated
pt.batched_dotfunction with the preferredpt.sumoperation in the KroneckerNormal distribution's logprob calculation, addressing issue #7878.Problem
The current implementation uses
batched_dot, which is deprecated in PyTensor and triggers warnings. Deprecated functions may lead to future breakage and lower performance.Solution
Refactored the
KroneckerNormallogprob code to usept.sumwith explicit axis parameters, achieving the same functionality without relying on deprecated APIs.Tests
Related Issue
Fixes #7878
Checklist
📚 Documentation preview 📚: https://pymc--7951.org.readthedocs.build/en/7951/