Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tuple of axis in softmax_cross_entropy_with_integer_labels #1165

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

daskol
Copy link
Contributor

@daskol daskol commented Jan 2, 2025

Fix #1162 (see #1164 for concurrent fix).

@daskol daskol force-pushed the fix/softmax_cross_entropy_with_integer_labels branch from 3c5e04e to 9b7dde2 Compare January 2, 2025 21:34
@daskol daskol marked this pull request as ready for review January 2, 2025 21:39
@rdyro
Copy link
Collaborator

rdyro commented Jan 3, 2025

Hey, thanks! These PR looks great, I left some comments!

@rdyro
Copy link
Collaborator

rdyro commented Jan 3, 2025

JAX has an example of canonicalize_axis, we could vendor it from there (but it's a private API, so let's not import it from JAX directly)

@daskol
Copy link
Contributor Author

daskol commented Jan 3, 2025

Yeah, I know of canonicalize_axis and decided to get normalize_axis_* from origin for multiple reasons. Firstly, normilize_axis_* are proper implementation. Secondly, it basically consists of list/tuple transformation and jit-compilable by design. Thirdly, NumPy is already among dependencies and there is an explicit notice about deprecation in NumPy 2. But vendoring these routines is nice too.

optax/losses/_classification.py Show resolved Hide resolved
@rdyro
Copy link
Collaborator

rdyro commented Jan 3, 2025

Yeah, I know of canonicalize_axis and decided to get normalize_axis_* from origin for multiple reasons. Firstly, normilize_axis_* are proper implementation. Secondly, it basically consists of list/tuple transformation and jit-compilable by design. Thirdly, NumPy is already among dependencies and there is an explicit notice about deprecation in NumPy 2. But vendoring these routines is nice too.

Makes sense, numpy's canonicalize are nice, but the import discrepancy between versions is a bit of a downside given potential API instability. Thanks for vendoring them!

@daskol daskol force-pushed the fix/softmax_cross_entropy_with_integer_labels branch from 68ebabd to 80a1270 Compare January 8, 2025 17:33
@daskol
Copy link
Contributor Author

daskol commented Jan 8, 2025

Should I go ahead and close this pull request? I noticed a merge conflict and rebased the branch on top of main. However, it seems that the concurrent fix #1164 has already been merged, making this PR redundant.

copybara-service bot pushed a commit that referenced this pull request Jan 11, 2025
--
9b7dde2 by Daniel Bershatsky <[email protected]>:

Support tuple of axis in `softmax_cross_entropy_with_integer_labels`

--
68ebabd by Daniel Bershatsky <[email protected]>:

Adjust according to review comments

COPYBARA_INTEGRATE_REVIEW=#1165 from daskol:fix/softmax_cross_entropy_with_integer_labels 68ebabd
PiperOrigin-RevId: 714273561
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Axis tuple is not handled properly in softmax_cross_entropy_with_integer_labels loss
2 participants