@pass-lin
Contributor

@pass-lin pass-lin commented Sep 28, 2025

RWKV7 is one of the strongest RNN models available today, and we now provide a full implementation for it in keras_hub.

📚 References


🔗 Pre-trained Checkpoints (ModelScope)

Numerical-verification and Inference Example notebook

This is the first modern RNN architecture in keras_hub. With the resurgence of recurrent models, more pre-trained RNN backbones will follow; hence this PR also serves as a reference implementation for future work.

Current progress

  • [✅] backbone implementation
  • [✅] checkpoint-conversion script
  • [✅] tokenizer implementation
  • [✅] unit tests / examples
  • [✅] complete CausalLM task wrapper
  • [✅] documentation

@gemini-code-assist
Contributor

Summary of Changes

Hello @pass-lin, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the keras_hub library by integrating the RWKV7 model, a cutting-edge recurrent neural network. This addition not only provides a robust new model for users but also serves as a foundational reference implementation, encouraging the future inclusion of more modern RNN architectures within the library.

Highlights

  • RWKV7 Model Integration: Introduced the RWKV7 model, a powerful RNN architecture, into keras_hub, marking a significant expansion of the library's capabilities.
  • Comprehensive Implementation: The pull request includes a full implementation of the RWKV7 backbone, its dedicated tokenizer, and a causal language model wrapper.
  • Foundation for Modern RNNs: This is the first modern RNN architecture added to keras_hub, serving as a reference implementation and paving the way for future recurrent model additions.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This PR introduces the RWKV-7 model, a powerful RNN architecture, to keras_hub. The contribution is significant and includes the backbone, tokenizer, preprocessor, an incomplete task model, and a checkpoint conversion script. The implementation follows the modular structure of keras_hub.

However, there are several critical issues that must be addressed before this PR can be merged:

  1. Missing Tests: The PR lacks unit tests for all new components. According to the contribution guidelines, testing is a mandatory requirement.[^1]
  2. Incomplete CausalLM Task: The RWKV7CausalLM task model is a stub with TODOs, making it non-functional for generation.
  3. Critical Bugs: There are critical bugs in the tokenizer and preprocessor implementations that will cause runtime errors.
  4. Style Guide Violations: There are numerous style guide violations, including a filename typo, missing docstrings, and inconsistencies with the recommended model input structure.

I've left detailed comments on these issues. Once these are resolved, this will be a great addition to the library.

from modelscope import snapshot_download

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.models.rwkv7.rwkv7_casual_lm import RWKV7CausalLM
Contributor

critical

There is a typo in the import path: rwkv7_casual_lm should be rwkv7_causal_lm. This will cause an ImportError.

Suggested change
- from keras_hub.src.models.rwkv7.rwkv7_casual_lm import RWKV7CausalLM
+ from keras_hub.src.models.rwkv7.rwkv7_causal_lm import RWKV7CausalLM

"keras_hub.models.RWKVTokenizer",
]
)
class RWKVTokenizer(tokenizer.Tokenizer):
Contributor

medium

The RWKVTokenizer class is missing a docstring. Please add a comprehensive docstring that explains the tokenizer, its arguments, and includes usage examples, as per the style guide.[^1]

Style Guide References

Footnotes

  1. All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters.

Comment on lines +394 to +395
- 0.5
) # soft-clamp to (-inf, -0.5)
Contributor

medium

This hardcoded value - 0.5 is a magic number. It would be better to define it as a named constant at the top of the file or as a class attribute to improve readability and maintainability.
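
One way to apply this suggestion is sketched below. The quoted fragment shows only the trailing `- 0.5` and its comment, so the sketch assumes the full expression has the form `-softplus(-w) - 0.5`; the names `DECAY_CLAMP_OFFSET`, `softplus`, and `soft_clamp_decay` are hypothetical and serve only to show the literal being given a name.

```python
import math

# Hypothetical name for the literal 0.5 flagged in the review.
DECAY_CLAMP_OFFSET = 0.5


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def soft_clamp_decay(w):
    """Map any real w to a value below -DECAY_CLAMP_OFFSET.

    Assumes the clamp is -softplus(-w) - offset; the PR's full
    expression is not visible in the quoted lines. For very large w
    the result can round to exactly -DECAY_CLAMP_OFFSET in floats.
    """
    return -softplus(-w) - DECAY_CLAMP_OFFSET


for w in (-10.0, 0.0, 10.0):
    print(w, soft_clamp_decay(w))
```

With a named constant, the bound in the comment and the value in the code cannot drift apart unnoticed.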



@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
Contributor

medium

The RWKV7CausalLMPreprocessor class is missing a docstring. Please add a Google-style docstring explaining its purpose, parameters, and include a usage example, as specified in the style guide.[^1]

Style Guide References

Footnotes

  1. All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters.

@pass-lin
Contributor Author

@divyashreepathihalli @mattdangerw Could anybody review my code?

@pass-lin
Contributor Author

@sachinprasadhs

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a comprehensive implementation of the RWKV-7 model, a modern RNN architecture, into keras_hub. The contribution is well-structured, following the repository's modular design with a backbone, causal LM task, preprocessor, tokenizer, and a checkpoint conversion script. The code is generally of high quality.

My review highlights a few areas for improvement. There are critical bugs in the tokenizer's asset loading and saving logic that need to be addressed. The custom layers in rwkv7_layer.py have inconsistent return signatures, which could lead to runtime errors and makes the code harder to maintain. Additionally, the test files do not use the standardized helper methods from TestCase as required by the repository's style guide, which is a significant deviation. I have also pointed out some minor typos and opportunities for code clarification.

Overall, this is an excellent and valuable addition. Addressing the feedback will improve the robustness, correctness, and maintainability of the new model.

"""
path = os.path.join(dir_path, VOCAB_FILENAME)
with open(path, "r", encoding="utf-8") as f:
vocabulary = f.readlines()
Contributor

critical

The load_assets method uses f.readlines(), which keeps trailing newline characters (\n). This will cause the int() conversion to fail when parsing the vocabulary file because int() cannot parse a string with a trailing newline (e.g., int('3\n')). You should strip the newlines when reading the file.

Suggested change
- vocabulary = f.readlines()
+ vocabulary = f.read().splitlines()
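
The difference between the two calls can be checked with a small sketch. The sample text below is a hypothetical stand-in for the vocabulary file, whose real format is not shown in the quoted lines. Note that Python's built-in `int()` ignores surrounding whitespace, so `int("3\n")` on its own succeeds; whether the stray newline actually breaks parsing depends on how each line is split and consumed afterwards.

```python
import io

# Hypothetical stand-in for the vocabulary file contents.
sample = "1\n2\n3\n"

# readlines() keeps the trailing newline on every line.
with_newlines = io.StringIO(sample).readlines()

# read().splitlines() drops the line terminators.
without_newlines = io.StringIO(sample).read().splitlines()

print(with_newlines)
print(without_newlines)

# int() tolerates surrounding whitespace, including a trailing newline.
print(int("3\n"))
```

Stripping the terminators is still the safer choice, because any later string comparison or slicing on the raw lines would otherwise include the newline.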

Comment on lines 474 to 476
if train_mode:
return x, v_first
return x, v_first, last_cache_x, finnal_state
Contributor

high

The call method in RWKV7_TimeMix has two issues:

  1. Inconsistent return signature: It returns a different number of values depending on train_mode, which can cause runtime errors.[^1]
  2. Typo: The variable finnal_state is misspelled. It should be final_state.

To fix this, the method should always return the same number of values, and the typo should be corrected. This change should also be propagated to where finnal_state is computed on line 448.

Suggested change
- if train_mode:
-     return x, v_first
- return x, v_first, last_cache_x, finnal_state
+ if train_mode:
+     return x, v_first, None, None
+ return x, v_first, last_cache_x, finnal_state

Rules References

Footnotes

  1. Methods should have a consistent return signature to avoid errors and improve predictability. Returning None for values not computed in a certain mode is a good practice.
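
The pattern recommended here can be illustrated with a toy function. This is a sketch only: `time_mix_step` and its arithmetic are hypothetical placeholders, not the PR's RWKV7_TimeMix logic. The point is that the caller unpacks four values in both modes.

```python
def time_mix_step(x, train_mode):
    """Toy stand-in for a layer call that always returns four values."""
    out = [v * 2 for v in x]  # placeholder for the real computation
    v_first = x[0]
    if train_mode:
        # No cache is produced in training, so those slots are None.
        return out, v_first, None, None
    last_cache_x = x[-1]
    final_state = sum(x)
    return out, v_first, last_cache_x, final_state


# Callers can unpack the same way regardless of mode.
for mode in (True, False):
    out, v_first, last_cache_x, final_state = time_mix_step([1, 2, 3], mode)
    print(mode, out, v_first, last_cache_x, final_state)
```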

Contributor Author

This issue is incorrect.

Comment on lines +120 to +122
if train_mode:
return output
return output, last_cache_x
Contributor

high

The call method in RWKV7_ChannelMix has an inconsistent return signature. It returns one value in training mode and two in inference mode. This can lead to errors and makes the layer harder to use. The method should always return the same number of values, using None for outputs that are not computed in a particular mode.[^1]

Suggested change
- if train_mode:
-     return output
- return output, last_cache_x
+ if train_mode:
+     return output, None
+ return output, last_cache_x

Rules References

Footnotes

  1. Methods should have a consistent return signature to avoid errors and improve predictability. Returning None for values not computed in a certain mode is a good practice.

Contributor Author

> high
>
> The call method in RWKV7_ChannelMix has an inconsistent return signature. It returns one value in training mode and two in inference mode. This can lead to errors and makes the layer harder to use. The method should always return the same number of values, using None for outputs that are not computed in a particular mode.[^1]
>
> Rules References
>
> Footnotes
>
>   1. Methods should have a consistent return signature to avoid errors and improve predictability. Returning None for values not computed in a certain mode is a good practice.

This issue is incorrect.


Args:
x: Input tensor.
last_cache_x: Cached previous values.
Contributor

medium

The sorted variable is initialized and appended to but never used. It should be removed to avoid confusion and unnecessary code.

Suggested change
- last_cache_x: Cached previous values.
+ # sorted = [] # must be already sorted

Comment on lines 78 to 93
"""Initialize RWKV7 backbone.

Args:
hidden_size: Hidden dimension size.
head_size: Attention head size.
num_layers: Number of RWKV blocks.
vocabulary_size: Size of vocabulary.
intermediate_dim: Intermediate dimension for FFN.
gate_lora: LoRA dimension for gating.
mv_lora: LoRA dimension for value mixing.
aaa_lora: LoRA dimension for alpha parameters.
decay_lora: LoRA dimension for decay parameters.
dtype: Data type for the layer.
dropout_rate: Dropout rate for regularization.
**kwargs: Additional arguments.
"""
Contributor

medium

The __init__ method's docstring duplicates information already present in the class-level docstring. According to the style guide's example for backbones, the __init__ method should not have a separate docstring.[^1] Removing this will make the code more concise and align it with the repository's conventions.

Style Guide References

Footnotes

  1. The style guide example for a backbone class shows arguments documented in the class docstring, not in the __init__ method.
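
The layout described in this comment can be sketched with a toy class. `ToyBackbone` and its two arguments are hypothetical and unrelated to the real RWKV7Backbone signature; the sketch only shows arguments documented once, on the class, with no docstring on `__init__`.

```python
class ToyBackbone:
    """A toy backbone used only to illustrate docstring placement.

    Args:
        hidden_size: int. Hidden dimension size.
        num_layers: int. Number of blocks.

    Example:
        >>> backbone = ToyBackbone(hidden_size=8, num_layers=2)
        >>> backbone.hidden_size
        8
    """

    def __init__(self, hidden_size, num_layers):
        # No separate docstring here; arguments are documented on the class.
        self.hidden_size = hidden_size
        self.num_layers = num_layers


print(ToyBackbone.__doc__.splitlines()[0])
print(ToyBackbone.__init__.__doc__)
```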

the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
state_cachce, last_token_cache = cache
Contributor

medium

There is a typo in the variable name state_cachce. It should be state_cache. This typo appears in multiple places within call_with_cache and _build_cache.

Suggested change
- state_cachce, last_token_cache = cache
+ state_cache, last_token_cache = cache

@pass-lin
Contributor Author

@mattdangerw @divyashreepathihalli @sachinprasadhs
It's been almost a month. Is anyone willing to take a look at this?

@sachinprasadhs
Collaborator

Apologies for the delay in review, taking a look into this. Will add my comments.

@sachinprasadhs
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the RWKV-7 model, a modern RNN architecture, to keras_hub. The implementation is comprehensive, covering the backbone, causal LM task, preprocessor, tokenizer, and a checkpoint conversion script. The code is well-structured and follows the modular design principles of the repository.

My review focuses on ensuring adherence to the repository's style guide, particularly regarding testing practices and code style conventions. I've identified several areas for improvement:

  • The testing for the new components should be updated to use the standardized helper methods from the base TestCase. Some tests also contain incorrect assertions.
  • There are a few deviations from the coding style, such as the use of type hints in function signatures and a few hardcoded values that could be made more flexible.
  • The backbone implementation should be updated to accept a padding_mask as input, aligning with the repository's conventions.

Addressing these points will improve the consistency, correctness, and maintainability of the new model. Overall, this is a great contribution, adding a powerful and interesting architecture to the library.

Comment on lines +121 to +128
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

padding_mask = ops.not_equal(token_id_input, 0)

x = self.token_embedding(token_id_input)
padding_mask = ops.cast(padding_mask, dtype=x.dtype)
Contributor

high

The model should accept padding_mask as an input instead of deriving it from token_ids. This hardcodes the assumption that the padding token ID is 0 and deviates from the repository's style guide for backbone models.[^1]

Please update the model to accept padding_mask as a keras.Input and also update the super().__init__ call to include it in the inputs dictionary.

Suggested change
- token_id_input = keras.Input(
-     shape=(None,), dtype="int32", name="token_ids"
- )
- padding_mask = ops.not_equal(token_id_input, 0)
- x = self.token_embedding(token_id_input)
- padding_mask = ops.cast(padding_mask, dtype=x.dtype)
+ token_id_input = keras.Input(
+     shape=(None,), dtype="int32", name="token_ids"
+ )
+ padding_mask_input = keras.Input(
+     shape=(None,), dtype="int32", name="padding_mask"
+ )
+ x = self.token_embedding(token_id_input)
+ padding_mask = ops.cast(padding_mask_input, dtype=x.dtype)

Style Guide References

Footnotes

  1. Backbone models should accept standardized input names like token_ids and padding_mask to ensure interoperability.
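
If the mask becomes a model input, something upstream has to compute it. A minimal pure-Python sketch of that step follows; `make_padding_mask` and its `pad_token_id` parameter are hypothetical names, and in practice the preprocessor would be the natural place for this logic.

```python
def make_padding_mask(token_ids, pad_token_id=0):
    """Return 1 for real tokens and 0 for padding, per sequence."""
    return [[int(tok != pad_token_id) for tok in row] for row in token_ids]


# Left-padded toy batch, in the style of the preprocessor test in this PR.
token_ids = [[0, 0, 4, 1, 5], [4, 1, 5, 1, 4]]
inputs = {
    "token_ids": token_ids,
    "padding_mask": make_padding_mask(token_ids),
}
print(inputs["padding_mask"])
```

Keeping the pad ID as a parameter outside the backbone means the model itself makes no assumption about which token marks padding.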

Contributor Author

This is not a problem. I asked the author, and it can be hardcoded.

Comment on lines +26 to +37
def test_backbone_basics(self):
"""
Test basic functionality of the RWKV7 backbone.
"""
y = self.backbone(self.input_data)
self.assertEqual(y.shape, (2, 5, 10))

def test_num_parameters(self):
"""
Test that the model has the expected number of parameters.
"""
self.assertEqual(self.backbone.count_params(), 10208)
Contributor

high

The tests should use the standardized test routines provided by TestCase as required by the style guide.[^1] Please replace the custom test logic with calls to self.run_backbone_test() and self.run_model_saving_test(). This ensures consistency and covers more test cases automatically, such as variable input shapes and serialization. You will also need to add import pytest.

Suggested change
- def test_backbone_basics(self):
-     """
-     Test basic functionality of the RWKV7 backbone.
-     """
-     y = self.backbone(self.input_data)
-     self.assertEqual(y.shape, (2, 5, 10))
- def test_num_parameters(self):
-     """
-     Test that the model has the expected number of parameters.
-     """
-     self.assertEqual(self.backbone.count_params(), 10208)
+ def test_backbone_basics(self):
+     self.run_backbone_test(
+         cls=RWKV7Backbone,
+         init_kwargs=self.init_kwargs,
+         input_data=self.input_data,
+         expected_output_shape=(2, 5, 10),
+     )
+ def test_saved_model(self):
+     self.run_model_saving_test(
+         cls=RWKV7Backbone,
+         init_kwargs=self.init_kwargs,
+         input_data=self.input_data,
+     )

Style Guide References

Footnotes

  1. The style guide requires using helper methods like self.run_backbone_test() and self.run_model_saving_test() for standardized testing of backbones.

Contributor Author

After modifying it this way, the fp16 test fails, but I cannot reproduce this error.

Comment on lines +20 to +50
def test_preprocessor_basics(self):
result = self.preprocessor(x=["hello world hello world hello world"])
self.assertAllEqual(
result[0], [[0, 0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1]]
)
self.assertAllEqual(
result[1], [[0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1, 5]]
)
self.assertAllEqual(
result[2],
[
[
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
],
)
Contributor

high

There are a couple of issues here:

  1. The test assertions for the output shapes appear to be incorrect. For a sequence_length of 15, the call method adjusts it to 17. The resulting sample_weight (result[2]) should have a length of 16, but the test asserts a length of 17.
  2. The style guide recommends using the self.run_preprocessor_test() helper for testing preprocessors, which is not used here.[^1]

Please correct the assertions and refactor the test to use the standard test helper.

Style Guide References

Footnotes

  1. The style guide requires using self.run_preprocessor_test() for testing preprocessors.

Contributor Author

Gemini's count is incorrect.
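
The lengths in question can be counted directly from the expected arrays quoted in the test above. The sketch below copies those values; the variable names `result_0`, `result_1`, and `result_2` simply mirror the indices used in the test and are not names from the PR.

```python
# Expected arrays copied from test_preprocessor_basics as quoted above.
result_0 = [0, 0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1]
result_1 = [0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1, 5]
result_2 = [False] * 5 + [True] * 11

# Each asserted array has 16 entries.
print(len(result_0), len(result_1), len(result_2))

# result_1 is result_0 shifted left by one position.
print(result_0[1:] == result_1[:-1])

# result_2 marks exactly the non-zero positions of result_1.
print(result_2 == [tok != 0 for tok in result_1])
```

All three arrays in the test have 16 entries, so the test does not assert a length of 17.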

Comment on lines +54 to +90
def test_generate(self):
"""
Test text generation functionality.
"""
causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
prompt = ["hello world"]
output = causal_lm.generate(prompt, 16)
self.assertTrue(isinstance(output[0], str))
self.assertTrue(isinstance(output, list))

prompt = "hello world"
output = causal_lm.generate(prompt, 16)
self.assertTrue(isinstance(output, str))

def test_generate_strip_prompt(self):
"""
Test that generated text can strip the prompt from output.
"""
prompt = ["hello world"]
causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
output = causal_lm.generate(prompt, 16, strip_prompt=True)
self.assertFalse(output[0].startswith(prompt[0]))

def test_generate_compilation(self):
"""
Test that the generate function compiles correctly and
reuses compiled functions.
"""
causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
causal_lm.generate(["hello world"], 16)
first_fn = causal_lm.generate_function
causal_lm.generate(["hello world"], 16)
second_fn = causal_lm.generate_function
self.assertEqual(first_fn, second_fn)

causal_lm.compile(sampler="greedy")
self.assertIsNone(causal_lm.generate_function)
Contributor

high

According to the style guide, task model tests should use the self.run_task_test() helper method.[^1] This ensures comprehensive testing, including the attached preprocessor functionality. Please refactor the tests to use this helper.

Suggested change
- def test_generate(self):
-     """
-     Test text generation functionality.
-     """
-     causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
-     prompt = ["hello world"]
-     output = causal_lm.generate(prompt, 16)
-     self.assertTrue(isinstance(output[0], str))
-     self.assertTrue(isinstance(output, list))
-     prompt = "hello world"
-     output = causal_lm.generate(prompt, 16)
-     self.assertTrue(isinstance(output, str))
- def test_generate_strip_prompt(self):
-     """
-     Test that generated text can strip the prompt from output.
-     """
-     prompt = ["hello world"]
-     causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
-     output = causal_lm.generate(prompt, 16, strip_prompt=True)
-     self.assertFalse(output[0].startswith(prompt[0]))
- def test_generate_compilation(self):
-     """
-     Test that the generate function compiles correctly and
-     reuses compiled functions.
-     """
-     causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
-     causal_lm.generate(["hello world"], 16)
-     first_fn = causal_lm.generate_function
-     causal_lm.generate(["hello world"], 16)
-     second_fn = causal_lm.generate_function
-     self.assertEqual(first_fn, second_fn)
-     causal_lm.compile(sampler="greedy")
-     self.assertIsNone(causal_lm.generate_function)
+ def test_task_basics(self):
+     """
+     Test CausalLM task with the preprocessor.
+     """
+     self.run_task_test(
+         cls=RWKV7CausalLM,
+         init_kwargs=self.init_kwargs,
+         train_data=(
+             ["hello world", "the quick brown fox"],
+         ),
+         expected_output_shape=(2, 16, 5),
+     )

Style Guide References

Footnotes

  1. The style guide requires using self.run_task_test() for testing task models with preprocessors.

Contributor Author

These tests are adapted from the LLaMA Causal LM tests. I think Gemini's review is incorrect.

Comment on lines +34 to +54
Examples:
```python
# Initialize the tokenizer and load assets from a local path.
tokenizer = RWKVTokenizer()
tokenizer.load_assets(rwkv_path)

# Create a preprocessor with a sequence length of 8.
preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8)

# Initialize the model with a backbone and preprocessor.
causal_lm = RWKV7CausalLM(backbone, preprocessor)

prompts = ["Bubble sort\n```python", "Hello World\n```python\n"]

causal_lm.compile(sampler="greedy")

outputs = causal_lm.generate(prompts, max_length=128)
for out in outputs:
print(out)
print("-" * 100)
```
Contributor

medium

The example in the docstring is not self-contained and cannot be run as-is, as it references undefined variables like RWKVTokenizer, rwkv_path, and backbone.[^1] Per the style guide, examples should be comprehensive and runnable. Please update it to be a complete code snippet.

Style Guide References

Footnotes

  1. Docstrings must include comprehensive examples showing usage patterns.

@pass-lin
Contributor Author

pass-lin commented Nov 5, 2025

> Apologies for the delay in review, taking a look into this. Will add my comments.

Hello, I have modified the code according to Gemini's suggestions. Could you please do a manual review?
