ADD RWKV7 #2421
Conversation
Summary of Changes: Hello @pass-lin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.
Code Review
This PR introduces the RWKV-7 model, a powerful RNN architecture, to keras_hub. The contribution is significant and includes the backbone, tokenizer, preprocessor, an incomplete task model, and a checkpoint conversion script. The implementation follows the modular structure of keras_hub.
However, there are several critical issues that must be addressed before this PR can be merged:
- Missing tests: The PR lacks unit tests for all new components. According to the contribution guidelines, testing is a mandatory requirement.
- Incomplete `CausalLM` task: The `RWKV7CausalLM` task model is a stub with `TODO`s, making it non-functional for generation.
- Critical bugs: There are critical bugs in the tokenizer and preprocessor implementations that will cause runtime errors.
- Style guide violations: There are numerous style guide violations, including a filename typo, missing docstrings, and inconsistencies with the recommended model input structure.
I've left detailed comments on these issues. Once these are resolved, this will be a great addition to the library.
```python
from modelscope import snapshot_download

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.models.rwkv7.rwkv7_casual_lm import RWKV7CausalLM
```
```python
        "keras_hub.models.RWKVTokenizer",
    ]
)
class RWKVTokenizer(tokenizer.Tokenizer):
```
The `RWKVTokenizer` class is missing a docstring. Please add a comprehensive docstring that explains the tokenizer, its arguments, and includes usage examples, as per the style guide.

Style Guide References:
1. All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters.
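For concreteness, a minimal sketch of the docstring shape the style guide asks for. The class body and argument names below are hypothetical placeholders, not the PR's actual API:

```python
class RWKVTokenizer:
    """RWKV trie-based tokenizer.

    A sketch of a Google-style docstring: a one-line summary,
    documented arguments, and a short usage example.

    Args:
        vocabulary: A list of token byte strings, ordered by token id.

    Example:
        tokenizer = RWKVTokenizer(vocabulary=vocab)
        token_ids = tokenizer("hello world")
    """

    def __init__(self, vocabulary=None):
        self.vocabulary = vocabulary
```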
```python
    - 0.5
)  # soft-clamp to (-inf, -0.5)
```
```python
@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
```
The `RWKV7CausalLMPreprocessor` class is missing a docstring. Please add a Google-style docstring explaining its purpose, parameters, and include a usage example, as specified in the style guide.

Style Guide References:
1. All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters.
@divyashreepathihalli @mattdangerw Could anybody review my code?
Code Review
This pull request introduces a comprehensive implementation of the RWKV-7 model, a modern RNN architecture, into keras_hub. The contribution is well-structured, following the repository's modular design with a backbone, causal LM task, preprocessor, tokenizer, and a checkpoint conversion script. The code is generally of high quality.
My review highlights a few areas for improvement. There are critical bugs in the tokenizer's asset loading and saving logic that need to be addressed. The custom layers in rwkv7_layer.py have inconsistent return signatures, which could lead to runtime errors and makes the code harder to maintain. Additionally, the test files do not use the standardized helper methods from TestCase as required by the repository's style guide, which is a significant deviation. I have also pointed out some minor typos and opportunities for code clarification.
Overall, this is an excellent and valuable addition. Addressing the feedback will improve the robustness, correctness, and maintainability of the new model.
```python
        """
        path = os.path.join(dir_path, VOCAB_FILENAME)
        with open(path, "r", encoding="utf-8") as f:
            vocabulary = f.readlines()
```
The `load_assets` method uses `f.readlines()`, which keeps trailing newline characters (`\n`) on every vocabulary entry. These stray newlines can corrupt the parsed vocabulary entries and break downstream token lookups. Strip the newlines when reading the file:

```diff
-            vocabulary = f.readlines()
+            vocabulary = f.read().splitlines()
```
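To illustrate the flagged behavior, here is a self-contained sketch using a hypothetical temp file (not the real vocab format):

```python
import os
import tempfile

# Write a tiny two-line stand-in for a vocabulary file.
path = os.path.join(tempfile.mkdtemp(), "vocab.txt")
with open(path, "w", encoding="utf-8") as f:
    f.write("hello\nworld\n")

# readlines() keeps the trailing newlines...
with open(path, "r", encoding="utf-8") as f:
    with_newlines = f.readlines()

# ...while read().splitlines() strips them.
with open(path, "r", encoding="utf-8") as f:
    clean = f.read().splitlines()

print(with_newlines)  # ['hello\n', 'world\n']
print(clean)          # ['hello', 'world']
```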
```python
        if train_mode:
            return x, v_first
        return x, v_first, last_cache_x, finnal_state
```
The `call` method in `RWKV7_TimeMix` has two issues:
- Inconsistent return signature: It returns a different number of values depending on `train_mode`, which can cause runtime errors.
- Typo: The variable `finnal_state` is misspelled. It should be `final_state`.

To fix this, the method should always return the same number of values, and the typo should be corrected. This change should also be propagated to where `finnal_state` is computed on line 448.

```diff
-        if train_mode:
-            return x, v_first
-        return x, v_first, last_cache_x, finnal_state
+        if train_mode:
+            return x, v_first, None, None
+        return x, v_first, last_cache_x, finnal_state
```
Rules References:
1. Methods should have a consistent return signature to avoid errors and improve predictability. Returning `None` for values not computed in a certain mode is a good practice.
This flagged issue is in error.
```python
        if train_mode:
            return output
        return output, last_cache_x
```
The `call` method in `RWKV7_ChannelMix` has an inconsistent return signature. It returns one value in training mode and two in inference mode. This can lead to errors and makes the layer harder to use. The method should always return the same number of values, using `None` for outputs that are not computed in a particular mode.

```diff
-        if train_mode:
-            return output
-        return output, last_cache_x
+        if train_mode:
+            return output, None
+        return output, last_cache_x
```
Rules References:
1. Methods should have a consistent return signature to avoid errors and improve predictability. Returning `None` for values not computed in a certain mode is a good practice.
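The pattern the reviewer recommends can be sketched generically. This is a plain-Python stand-in, not the actual layer:

```python
def channel_mix_call(x, train_mode=True):
    # Stand-in computation for the real channel-mix math.
    output = [v * 2 for v in x]
    if train_mode:
        # Same arity in both modes: return None for the unused cache.
        return output, None
    last_cache_x = x[-1]
    return output, last_cache_x

# Callers can always unpack two values, whatever the mode.
out, cache = channel_mix_call([1, 2], train_mode=True)
print(out, cache)
out, cache = channel_mix_call([1, 2], train_mode=False)
print(out, cache)
```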
This flagged issue is in error.
```python
    Args:
        x: Input tensor.
        last_cache_x: Cached previous values.
```
```python
        """Initialize RWKV7 backbone.

        Args:
            hidden_size: Hidden dimension size.
            head_size: Attention head size.
            num_layers: Number of RWKV blocks.
            vocabulary_size: Size of vocabulary.
            intermediate_dim: Intermediate dimension for FFN.
            gate_lora: LoRA dimension for gating.
            mv_lora: LoRA dimension for value mixing.
            aaa_lora: LoRA dimension for alpha parameters.
            decay_lora: LoRA dimension for decay parameters.
            dtype: Data type for the layer.
            dropout_rate: Dropout rate for regularization.
            **kwargs: Additional arguments.
        """
```
The `__init__` method's docstring duplicates information already present in the class-level docstring. According to the style guide's example for backbones, the `__init__` method should not have a separate docstring. Removing this will make the code more concise and align it with the repository's conventions.

Style Guide References:
1. The style guide example for a backbone class shows arguments documented in the class docstring, not in the `__init__` method.
```python
        the final hidden representation of the input tokens, and `cache` is
        the decoding cache.
        """
        state_cachce, last_token_cache = cache
```
@mattdangerw @divyashreepathihalli @sachinprasadhs

Apologies for the delay in review, taking a look into this. Will add my comments.

/gemini review
Code Review
This pull request introduces the RWKV-7 model, a modern RNN architecture, to keras_hub. The implementation is comprehensive, covering the backbone, causal LM task, preprocessor, tokenizer, and a checkpoint conversion script. The code is well-structured and follows the modular design principles of the repository.
My review focuses on ensuring adherence to the repository's style guide, particularly regarding testing practices and code style conventions. I've identified several areas for improvement:
- The testing for the new components should be updated to use the standardized helper methods from the base `TestCase`. Some tests also contain incorrect assertions.
- There are a few deviations from the coding style, such as the use of type hints in function signatures and a few hardcoded values that could be made more flexible.
- The backbone implementation should be updated to accept a `padding_mask` as input, aligning with the repository's conventions.

Addressing these points will improve the consistency, correctness, and maintainability of the new model. Overall, this is a great contribution, adding a powerful and interesting architecture to the library.
```python
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )

        padding_mask = ops.not_equal(token_id_input, 0)

        x = self.token_embedding(token_id_input)
        padding_mask = ops.cast(padding_mask, dtype=x.dtype)
```
The model should accept `padding_mask` as an input instead of deriving it from `token_ids`. This hardcodes the assumption that the padding token ID is 0 and deviates from the repository's style guide for backbone models.

Please update the model to accept `padding_mask` as a `keras.Input` and also update the `super().__init__` call to include it in the inputs dictionary.

```diff
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
-        padding_mask = ops.not_equal(token_id_input, 0)
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
         x = self.token_embedding(token_id_input)
-        padding_mask = ops.cast(padding_mask, dtype=x.dtype)
+        padding_mask = ops.cast(padding_mask_input, dtype=x.dtype)
```
Style Guide References:
1. Backbone models should accept standardized input names like `token_ids` and `padding_mask` to ensure interoperability.
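The hardcoding concern can be shown with a plain-Python sketch; the pad id of 3 below is a hypothetical example, not this tokenizer's actual value:

```python
def derive_mask(token_ids, pad_id=0):
    # 1 for real tokens, 0 for padding positions.
    return [int(t != pad_id) for t in token_ids]

# Suppose a tokenizer pads with id 3 instead of 0.
token_ids = [5, 9, 3, 3]

hardcoded = derive_mask(token_ids)           # silently assumes pad_id == 0
explicit = derive_mask(token_ids, pad_id=3)  # pad id passed in explicitly

print(hardcoded)  # [1, 1, 1, 1] -- padding wrongly kept
print(explicit)   # [1, 1, 0, 0]
```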
No problem; I asked the author, and it can be hardcoded.
```python
    def test_backbone_basics(self):
        """
        Test basic functionality of the RWKV7 backbone.
        """
        y = self.backbone(self.input_data)
        self.assertEqual(y.shape, (2, 5, 10))

    def test_num_parameters(self):
        """
        Test that the model has the expected number of parameters.
        """
        self.assertEqual(self.backbone.count_params(), 10208)
```
The tests should use the standardized test routines provided by `TestCase` as required by the style guide. Please replace the custom test logic with calls to `self.run_backbone_test()` and `self.run_model_saving_test()`. This ensures consistency and covers more test cases automatically, such as variable input shapes and serialization. You will also need to add `import pytest`.

Proposed replacement:

```python
    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=RWKV7Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 10),
        )

    def test_saved_model(self):
        self.run_model_saving_test(
            cls=RWKV7Backbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
```
Style Guide References:
1. The style guide requires using helper methods like `self.run_backbone_test()` and `self.run_model_saving_test()` for standardized testing of backbones.
After modifying it this way, the fp16 test fails, but I cannot reproduce the error.
```python
    def test_preprocessor_basics(self):
        result = self.preprocessor(x=["hello world hello world hello world"])
        self.assertAllEqual(
            result[0], [[0, 0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1]]
        )
        self.assertAllEqual(
            result[1], [[0, 0, 0, 0, 0, 4, 1, 5, 1, 4, 1, 5, 1, 4, 1, 5]]
        )
        self.assertAllEqual(
            result[2],
            [
                [
                    False,
                    False,
                    False,
                    False,
                    False,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                    True,
                ]
            ],
        )
```
There are a couple of issues here:
- The test assertions for the output shapes appear to be incorrect. For a `sequence_length` of 15, the `call` method adjusts it to 17. The resulting `sample_weight` (`result[2]`) should have a length of 16, but the test asserts a length of 17.
- The style guide recommends using the `self.run_preprocessor_test()` helper for testing preprocessors, which is not used here.

Please correct the assertions and refactor the test to use the standard test helper.
Style Guide References:
1. The style guide requires using `self.run_preprocessor_test()` for testing preprocessors.
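For context on the off-by-one arithmetic: a causal-LM preprocessor typically shifts tokens by one position, so n+1 tokens yield n (feature, label) positions. A toy sketch, with pad-id-0 masking as a simplification:

```python
# Toy token stream; a sequence of 17 tokens would yield 16 positions.
tokens = [0, 4, 1, 5, 1, 4]

features = tokens[:-1]  # model input
labels = tokens[1:]     # next-token targets, shifted by one
sample_weight = [int(t != 0) for t in features]  # mask pad positions

print(len(features), len(labels), len(sample_weight))
```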
The number Gemini cites is incorrect.
```python
    def test_generate(self):
        """
        Test text generation functionality.
        """
        causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
        prompt = ["hello world"]
        output = causal_lm.generate(prompt, 16)
        self.assertTrue(isinstance(output[0], str))
        self.assertTrue(isinstance(output, list))

        prompt = "hello world"
        output = causal_lm.generate(prompt, 16)
        self.assertTrue(isinstance(output, str))

    def test_generate_strip_prompt(self):
        """
        Test that generated text can strip the prompt from output.
        """
        prompt = ["hello world"]
        causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
        output = causal_lm.generate(prompt, 16, strip_prompt=True)
        self.assertFalse(output[0].startswith(prompt[0]))

    def test_generate_compilation(self):
        """
        Test that the generate function compiles correctly and
        reuses compiled functions.
        """
        causal_lm = RWKV7CausalLM(self.backbone, self.preprocessor)
        causal_lm.generate(["hello world"], 16)
        first_fn = causal_lm.generate_function
        causal_lm.generate(["hello world"], 16)
        second_fn = causal_lm.generate_function
        self.assertEqual(first_fn, second_fn)

        causal_lm.compile(sampler="greedy")
        self.assertIsNone(causal_lm.generate_function)
```
According to the style guide, task model tests should use the `self.run_task_test()` helper method. This ensures comprehensive testing, including the attached preprocessor functionality. Please refactor the tests to use this helper.

Proposed replacement:

```python
    def test_task_basics(self):
        """
        Test CausalLM task with the preprocessor.
        """
        self.run_task_test(
            cls=RWKV7CausalLM,
            init_kwargs=self.init_kwargs,
            train_data=(
                ["hello world", "the quick brown fox"],
            ),
            expected_output_shape=(2, 16, 5),
        )
```
Style Guide References:
1. The style guide requires using `self.run_task_test()` for testing task models with preprocessors.
These tests were adapted from the LLaMA CausalLM tests. I think Gemini's review is incorrect.
````python
    Examples:
    ```python
    # Initialize the tokenizer and load assets from a local path.
    tokenizer = RWKVTokenizer()
    tokenizer.load_assets(rwkv_path)

    # Create a preprocessor with a sequence length of 8.
    preprocessor = RWKV7CausalLMPreprocessor(tokenizer, sequence_length=8)

    # Initialize the model with a backbone and preprocessor.
    causal_lm = RWKV7CausalLM(backbone, preprocessor)

    prompts = ["Bubble sort\n```python", "Hello World\n```python\n"]

    causal_lm.compile(sampler="greedy")

    outputs = causal_lm.generate(prompts, max_length=128)
    for out in outputs:
        print(out)
        print("-" * 100)
    ```
````
The example in the docstring is not self-contained and cannot be run as-is, as it references undefined variables like `RWKVTokenizer`, `rwkv_path`, and `backbone`. Per the style guide, examples should be comprehensive and runnable. Please update it to be a complete code snippet.

Style Guide References:
1. Docstrings must include comprehensive examples showing usage patterns.
Hello, I have modified the code according to Gemini's suggestions. Please do a manual review.
RWKV7 is one of the strongest RNN models available today, and we now provide a full implementation for it in keras_hub.
📚 References
🔗 Pre-trained Checkpoints (ModelScope)
Numerical-verification and Inference Example notebook
This is the first modern RNN architecture in keras_hub. With the resurgence of recurrent models, more pre-trained RNN backbones will follow; hence this PR also serves as a reference implementation for future work.
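For readers new to the family: RWKV replaces softmax attention with a recurrent state that is updated once per token, so per-token inference cost is constant in sequence length. Below is a heavily simplified linear-recurrence sketch of that idea, not the actual RWKV-7 kernel (which adds data-dependent decay and a delta-rule state update):

```python
import numpy as np

def recurrent_step(state, k, v, decay):
    # state: (d_k, d_v) matrix memory; decay fades old content,
    # and the outer product writes the new key/value pair.
    return state * decay[:, None] + np.outer(k, v)

d_k, d_v = 4, 4
state = np.zeros((d_k, d_v))
decay = np.full(d_k, 0.9)

for _ in range(3):  # three identical toy tokens
    k = np.ones(d_k)
    v = np.ones(d_v)
    state = recurrent_step(state, k, v, decay)

q = np.ones(d_k)
out = q @ state  # per-token read-out from the fixed-size state
print(out.shape)  # (4,)
```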
Current progress