Conversation

@divyashreepathihalli
Collaborator

No description provided.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @divyashreepathihalli, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've added a new guide titled "How to use Keras with NNX backend". This tutorial demonstrates the integration of Keras with Flax's NNX module system, highlighting how it enhances variable handling and enables advanced training capabilities within the JAX ecosystem. It covers both the familiar model.fit() workflow and custom training loops, showcasing the flexibility and power of combining Keras's user-friendliness with NNX's explicit state management.

Highlights

  • Keras-NNX Integration Guide: I've introduced a new comprehensive guide demonstrating how to effectively use Keras with the Flax NNX backend, emphasizing enhanced variable handling and advanced training capabilities within JAX.
  • Seamless Keras Variable Integration: The guide illustrates how keras.Variable acts as a native component within the NNX ecosystem, allowing for the free mixing of Keras and NNX components and ensuring proper state management.
  • Flexible Training Workflows: I've showcased two primary training workflows: the classic Keras model.fit() experience for high-level productivity and custom training loops leveraging NNX and Optax for fine-grained control.
  • Model Serialization Compatibility: The guide confirms that standard Keras model saving and loading functionalities work seamlessly with the NNX integration, ensuring investment in the Keras ecosystem is preserved.
  • Real-World Application with Gemma: I've included an example demonstrating the fine-tuning of a Gemma model from KerasHub, illustrating the practical application of the Keras-NNX integration in a real-world scenario.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new guide for using Keras with the JAX backend and NNX. The guide is comprehensive, covering setup, core concepts, different training workflows, serialization, and a real-world example with Gemma. My review focuses on improving the guide's robustness, maintainability, and correctness. Key suggestions include using official package sources instead of personal forks, leveraging public APIs over internal attributes, adding verification steps to examples, ensuring all dependencies are installed, and fixing some formatting issues in the text.

# Getting Started: Setting Up Your Environment
"""

!pip install -q git+https://github.com/hertschuh/keras.git@saving_op
Contributor


high

The guide installs Keras from a personal fork and branch (hertschuh/keras.git@saving_op). For a public guide, this is a potential security risk and a maintenance issue, as the branch could be deleted or modified with breaking changes. It's recommended to use an official release from PyPI or, if a development version is absolutely necessary, pin it to a specific commit hash from the official repository for stability and security.
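For illustration, the safer install options described above could be sketched as follows (the commit SHA is a placeholder, not a real reference):

```shell
# Option 1: install the latest official Keras release from PyPI.
pip install -q keras

# Option 2 (only if a development build is truly needed): pin to a
# specific commit on the official repository for reproducibility.
# <commit-sha> is a placeholder to be replaced with a real hash.
# pip install -q "git+https://github.com/keras-team/keras.git@<commit-sha>"
```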

Collaborator


+1

What is this branch of mine?? (that doesn't exist anymore)

Before trying out this KerasHub model, please make sure you have set up your Kaggle credentials in colab secrets. The colab pulls in `KAGGLE_KEY` and `KAGGLE_USERNAME` to authenticate and download the models.
"""

import keras_hub
Contributor


high

The keras_hub library is used here to load the Gemma model, but it is not installed in the setup section of the guide. This will cause a ModuleNotFoundError. Please add !pip install -q keras_hub to the setup cell at the beginning of the notebook (e.g., after line 30).
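A sketch of the amended setup cell with the missing dependency added (package names as assumed from the guide):

```shell
# Notebook setup cell: install the guide's dependencies up front so the
# later `import keras_hub` succeeds.
pip install -q keras_hub
```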

model(X)

tx = optax.sgd(1e-3)
trainable_var = nnx.All(keras.Variable, lambda path, x: getattr(x, '_trainable', False))
Contributor


medium

The lambda function uses getattr(x, '_trainable', False) to filter for trainable variables. _trainable is an internal attribute, and its use is discouraged as it may change without notice in future versions. Please use the public property x.trainable instead for better code stability and maintainability.

trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
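For intuition, here is a minimal pure-Python stand-in for what this composite filter does: nnx.All accepts a leaf only when every condition holds, i.e. the leaf is a keras.Variable and the predicate returns True. The Var class and all_of helper below are made-up stand-ins, not Flax or Keras APIs.

```python
# Hypothetical stand-ins illustrating the composite-predicate idea only.
class Var:
    """Stand-in for keras.Variable with a public `trainable` property."""
    def __init__(self, trainable):
        self.trainable = trainable

def all_of(*predicates):
    # A leaf matches only if every predicate accepts it, like nnx.All.
    # all() short-circuits, so later predicates never see non-Var leaves.
    return lambda path, x: all(p(path, x) for p in predicates)

is_var = lambda path, x: isinstance(x, Var)
is_trainable = lambda path, x: x.trainable

trainable_filter = all_of(is_var, is_trainable)

assert trainable_filter("kernel", Var(True)) is True        # kept
assert trainable_filter("moving_mean", Var(False)) is False  # filtered out
assert trainable_filter("step", 3) is False                  # not a Var
```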

Collaborator


+1, this ought to work and is more readable.

model.save('my_nnx_model.keras')
restored_model = keras.models.load_model('my_nnx_model.keras')

print("Restored model output:", restored_model(dummy_input))
Contributor


medium

The example for saving and loading a model demonstrates the API calls but doesn't programmatically verify that the restored model behaves identically to the original. For a more robust and self-contained example, it would be beneficial to add an assertion that compares the outputs of the original and restored models. This provides a clear confirmation that serialization was successful.

print("Restored model output:", restored_model(dummy_input))

# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")
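As a stdlib-only illustration of why this round-trip check is worth adding (the linear "model" and the predict helper are made up for the sketch; this is not the Keras saving API):

```python
import json
import math

# Hypothetical stand-in "model": a dict of weights applied as y = w*x + b.
weights = {"w": 2.0, "b": 1.0}

def predict(params, x):
    return params["w"] * x + params["b"]

# Round-trip through serialization, standing in for model.save() /
# keras.models.load_model().
restored = json.loads(json.dumps(weights))

# Verification: the restored weights must reproduce the original outputs,
# which is exactly what np.testing.assert_allclose confirms above.
for x in (0.0, 1.5, -3.0):
    assert math.isclose(predict(weights, x), predict(restored, x))
```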


The Keras-NNX integration represents a significant step forward, offering a unified framework for both rapid prototyping and high-performance, customizable research. You can now:
Use familiar Keras APIs (Sequential, Model, fit, save) on a JAX backend.
Integrate Keras layers and models directly into Flax NNX modules and training loops.Integrate keras code/model with NNX ecosytem like Qwix, Tunix, etc.
Contributor


medium

There appears to be a formatting issue in this line. The sentence "Integrate keras code/model with NNX ecosytem like Qwix, Tunix, etc." is concatenated with the preceding bullet point. It should likely be on a new line to be its own bullet point for clarity. Also, there is a typo in "ecosytem" (should be "ecosystem").

@codecov-commenter

codecov-commenter commented Aug 7, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 62.02%. Comparing base (0b905e9) to head (6c71273).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21557   +/-   ##
=======================================
  Coverage   62.02%   62.02%           
=======================================
  Files         567      567           
  Lines       56464    56464           
  Branches     8825     8825           
=======================================
+ Hits        35021    35023    +2     
+ Misses      19263    19262    -1     
+ Partials     2180     2179    -1     
Flag Coverage Δ
keras 62.01% <ø> (+<0.01%) ⬆️
keras-numpy 58.30% <ø> (+<0.01%) ⬆️
keras-openvino 34.63% <ø> (ø)


Collaborator

@hertschuh hertschuh left a comment


Thanks for the PR!

Can you run autogen to generate the ipynb and md files?

Accelerator: CPU
"""

# -*- coding: utf-8 -*-
Collaborator


This is an emacs thing. autogen won't honor it anyway, so remove.


# A Guide to the Keras & Flax NNX Integration

This tutorial will guide you through the integration of Keras with Flax's NNX (Neural Networks JAX) module system, demonstrating how it significantly enhances variable handling and opens up advanced training capabilities within the JAX ecosystem. Whether you love the simplicity of model.fit() or the fine-grained control of a custom training loop, this integration lets you have the best of both worlds. Let's dive in!
Collaborator


I think you need to run black on this and make sure lines are only 80 characters long.


Keras is known for its user-friendliness and high-level API, making deep learning accessible. JAX, on the other hand, provides high-performance numerical computation, especially suited for machine learning research due to its JIT compilation and automatic differentiation capabilities. NNX is Flax's functional module system built on JAX, offering explicit state management and powerful functional programming paradigms

NNX is designed for simplicity. It is characterized by its Pythonic approach, where modules are standard Python classes, promoting ease of use and familiarity. NNX prioritizes user-friendliness and offers fine-grained control over JAX transformations through typed Variable collections
Collaborator


nit: dot at the end of the sentence.

"""

!pip install -q git+https://github.com/hertschuh/keras.git@saving_op
!pip uninstall -y flax
Collaborator


Just do !pip install -U -q flax==0.11.0 on the next line?

keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")
])

print("--- Initial Model Weights ---")
Collaborator


Once you render the notebook, I think it looks better to just do:

"""
1. Create a Keras Model
"""

The comments pop out better than if they look like the outputs.

model = MySimpleKerasModel()
model(X)

tx = optax.sgd(1e-3)
Collaborator


Just curious, what does tx stand for?

@divyashreepathihalli
Collaborator Author

moving to keras-team/keras-io#2159
