Keras NNX demo guide #21557
Conversation
Summary of Changes
Hello @divyashreepathihalli, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've added a new guide titled "How to use Keras with NNX backend". This tutorial demonstrates the integration of Keras with Flax's NNX module system, highlighting how it enhances variable handling and enables advanced training capabilities within the JAX ecosystem. It covers both the familiar model.fit() workflow and custom training loops, showcasing the flexibility and power of combining Keras's user-friendliness with NNX's explicit state management.
Highlights
- Keras-NNX Integration Guide: I've introduced a new comprehensive guide demonstrating how to effectively use Keras with the Flax NNX backend, emphasizing enhanced variable handling and advanced training capabilities within JAX.
- Seamless Keras Variable Integration: The guide illustrates how `keras.Variable` acts as a native component within the NNX ecosystem, allowing for the free mixing of Keras and NNX components and ensuring proper state management.
- Flexible Training Workflows: I've showcased two primary training workflows: the classic Keras `model.fit()` experience for high-level productivity, and custom training loops leveraging NNX and Optax for fine-grained control.
- Model Serialization Compatibility: The guide confirms that standard Keras model saving and loading functionalities work seamlessly with the NNX integration, ensuring investment in the Keras ecosystem is preserved.
- Real-World Application with Gemma: I've included an example demonstrating the fine-tuning of a Gemma model from KerasHub, illustrating the practical application of the Keras-NNX integration in a real-world scenario.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces a new guide for using Keras with the JAX backend and NNX. The guide is comprehensive, covering setup, core concepts, different training workflows, serialization, and a real-world example with Gemma. My review focuses on improving the guide's robustness, maintainability, and correctness. Key suggestions include using official package sources instead of personal forks, leveraging public APIs over internal attributes, adding verification steps to examples, ensuring all dependencies are installed, and fixing some formatting issues in the text.
# Getting Started: Setting Up Your Environment
"""

!pip install -q git+https://github.com/hertschuh/keras.git@saving_op
The guide installs Keras from a personal fork and branch (hertschuh/keras.git@saving_op). For a public guide, this is a potential security risk and a maintenance issue, as the branch could be deleted or modified with breaking changes. It's recommended to use an official release from PyPI or, if a development version is absolutely necessary, pin it to a specific commit hash from the official repository for stability and security.
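As a hedged illustration of this suggestion (the commit hash below is a placeholder, not a real pin):

```shell
# Preferred: official release from PyPI
pip install -q keras

# If a development version is truly needed, pin to a specific commit
# on the official repository (<commit-sha> is a placeholder):
pip install -q git+https://github.com/keras-team/keras.git@<commit-sha>
```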
+1
What is this branch of mine?? (that doesn't exist anymore)
Before trying out this KerasHub model, please make sure you have set up your Kaggle credentials in colab secrets. The colab pulls in `KAGGLE_KEY` and `KAGGLE_USERNAME` to authenticate and download the models.
"""

import keras_hub
model(X)

tx = optax.sgd(1e-3)
trainable_var = nnx.All(keras.Variable, lambda path, x: getattr(x, '_trainable', False))
The lambda function uses getattr(x, '_trainable', False) to filter for trainable variables. _trainable is an internal attribute, and its use is discouraged as it may change without notice in future versions. Please use the public property x.trainable instead for better code stability and maintainability.
trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
+1, this ought to work and is more readable.
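The filtering idea under discussion can be sketched without Flax or Keras installed. Here `all_of` and `FakeVariable` are simplified stand-ins for `nnx.All` and `keras.Variable` (not the real APIs), used only to show how a type check combines with a public `trainable` predicate:

```python
# Simplified stand-ins: FakeVariable mimics a keras.Variable with a
# public `trainable` property; all_of mimics nnx.All by requiring
# every predicate to accept the (path, value) pair.
class FakeVariable:
    def __init__(self, name, trainable):
        self.name = name
        self.trainable = trainable

def all_of(*predicates):
    # Each predicate is either a type (isinstance check) or a callable
    # taking (path, value), matching the shape of nnx filter lambdas.
    def match(path, value):
        for p in predicates:
            if isinstance(p, type):
                if not isinstance(value, p):
                    return False
            elif not p(path, value):
                return False
        return True
    return match

trainable_filter = all_of(FakeVariable, lambda path, x: x.trainable)

variables = {
    "dense/kernel": FakeVariable("kernel", trainable=True),
    "dense/bias": FakeVariable("bias", trainable=True),
    "bn/moving_mean": FakeVariable("moving_mean", trainable=False),
    "not_a_variable": object(),
}
selected = [p for p, v in variables.items() if trainable_filter(p, v)]
print(selected)  # ['dense/kernel', 'dense/bias']
```

The non-Variable entry is rejected by the type check, and the non-trainable statistic by the predicate, which is exactly the split the reviewer's `x.trainable` version expresses.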
model.save('my_nnx_model.keras')
restored_model = keras.models.load_model('my_nnx_model.keras')

print("Restored model output:", restored_model(dummy_input))
The example for saving and loading a model demonstrates the API calls but doesn't programmatically verify that the restored model behaves identically to the original. For a more robust and self-contained example, it would be beneficial to add an assertion that compares the outputs of the original and restored models. This provides a clear confirmation that serialization was successful.
print("Restored model output:", restored_model(dummy_input))

# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")
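The round-trip verification pattern suggested here can be sketched generically with the standard library; `pickle` stands in for Keras model serialization, and the weight names and forward function are illustrative, not from the guide:

```python
import math
import pickle

# Stand-in "model": a dict of weights plus a forward function.
weights = {"w": 0.5, "b": -1.25}

def forward(params, x):
    return params["w"] * x + params["b"]

# "Save" and "load" via pickle, standing in for
# model.save() / keras.models.load_model().
blob = pickle.dumps(weights)
restored = pickle.loads(blob)

# Verification: the restored model must reproduce the original output.
dummy_input = 3.0
original_out = forward(weights, dummy_input)
restored_out = forward(restored, dummy_input)
assert math.isclose(original_out, restored_out)
print("Restored output matches:", restored_out)  # 0.25
```

The point of the assertion is that the example fails loudly if serialization silently drops or corrupts state, rather than merely printing two outputs for the reader to eyeball.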
The Keras-NNX integration represents a significant step forward, offering a unified framework for both rapid prototyping and high-performance, customizable research. You can now:
Use familiar Keras APIs (Sequential, Model, fit, save) on a JAX backend.
Integrate Keras layers and models directly into Flax NNX modules and training loops.Integrate keras code/model with NNX ecosytem like Qwix, Tunix, etc.
There appears to be a formatting issue in this line. The sentence "Integrate keras code/model with NNX ecosytem like Qwix, Tunix, etc." is concatenated with the preceding bullet point. It should likely be on a new line to be its own bullet point for clarity. Also, there is a typo in "ecosytem" (should be "ecosystem").
Codecov Report

✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff            @@
##           master   #21557    +/- ##
=======================================
  Coverage   62.02%   62.02%
=======================================
  Files         567      567
  Lines       56464    56464
  Branches     8825     8825
=======================================
+ Hits        35021    35023     +2
+ Misses      19263    19262     -1
+ Partials     2180     2179     -1

Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry.
hertschuh
left a comment
Thanks for the PR!
Can you run autogen to generate the ipynb and md files?
Accelerator: CPU
"""

# -*- coding: utf-8 -*-
This is an emacs thing. autogen won't honor it anyway, so remove.
# A Guide to the Keras & Flax NNX Integration

This tutorial will guide you through the integration of Keras with Flax's NNX (Neural Networks JAX) module system, demonstrating how it significantly enhances variable handling and opens up advanced training capabilities within the JAX ecosystem. Whether you love the simplicity of model.fit() or the fine-grained control of a custom training loop, this integration lets you have the best of both worlds. Let's dive in!
I think you need to run black on this and make sure lines are only 80 characters long.
Keras is known for its user-friendliness and high-level API, making deep learning accessible. JAX, on the other hand, provides high-performance numerical computation, especially suited for machine learning research due to its JIT compilation and automatic differentiation capabilities. NNX is Flax's functional module system built on JAX, offering explicit state management and powerful functional programming paradigms

NNX is designed for simplicity. It is characterized by its Pythonic approach, where modules are standard Python classes, promoting ease of use and familiarity. NNX prioritizes user-friendliness and offers fine-grained control over JAX transformations through typed Variable collections
nit: dot at the end of the sentence.
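The "modules are standard Python classes" point from the quoted passage can be illustrated with a simplified stand-in (plain Python, not the actual flax.nnx API; `Param` here is a hypothetical tag class playing the role of a typed variable collection like nnx.Param):

```python
# Simplified sketch of the NNX-style idea: state lives as typed
# attributes on an ordinary Python class, so it can be inspected and
# filtered by type. `Param` is a hypothetical stand-in for a typed
# variable collection.
class Param:
    def __init__(self, value):
        self.value = value

class Linear:
    def __init__(self, in_features, out_features):
        # Explicit state: every parameter is a visible attribute.
        self.kernel = Param([[0.0] * out_features for _ in range(in_features)])
        self.bias = Param([0.0] * out_features)

    def params(self):
        # Fine-grained access by type, mimicking NNX's typed filtering.
        return {k: v for k, v in vars(self).items() if isinstance(v, Param)}

layer = Linear(3, 2)
print(sorted(layer.params()))  # ['bias', 'kernel']
```

Because the module is just a Python object, there is no hidden variable store: everything a transformation needs to track is discoverable by walking the instance's attributes.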
"""

!pip install -q git+https://github.com/hertschuh/keras.git@saving_op
!pip uninstall -y flax
Just do !pip install -U -q flax==0.11.0 on the next line?
    keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")
])

print("--- Initial Model Weights ---")
Once you render the notebook, I think it looks better to just do:

"""
1. Create a Keras Model
"""

The comments pop out better than if they look like the outputs.
model = MySimpleKerasModel()
model(X)

tx = optax.sgd(1e-3)
Just curious, what does tx stand for?
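For context on the question (an aside, not from the guide): in Flax and Optax examples, `tx` conventionally names the optax gradient transformation, which is an (init, update) pair of functions. A simplified pure-Python stand-in for that shape, not the real optax API:

```python
from collections import namedtuple

# A gradient transformation is a pair of functions: init(params) builds
# optimizer state, and update(grads, state) returns parameter updates.
# This mirrors the shape of optax.sgd without depending on optax.
GradientTransformation = namedtuple("GradientTransformation", ["init", "update"])

def sgd(learning_rate):
    def init(params):
        return ()  # plain SGD keeps no optimizer state

    def update(grads, state):
        updates = {k: -learning_rate * g for k, g in grads.items()}
        return updates, state

    return GradientTransformation(init, update)

tx = sgd(1e-1)
params = {"w": 1.0}
grads = {"w": 2.0}
state = tx.init(params)
updates, state = tx.update(grads, state)
params = {k: v + updates[k] for k, v in params.items()}
print(params)  # {'w': 0.8}
```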
moving to keras-team/keras-io#2159
No description provided.