DeepSeek MTP #913
Conversation
Hello @shihaobai, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request titled "DeepSeek MTP". Based on the title and the extensive code changes, this PR introduces support for DeepSeek MTP (Multi-Token Prediction), implemented as a speculative decoding approach. The changes add new components for MTP-specific logic, integrate them into the core model inference flow, and update the server backend and CLI to support the new model and speculative decoding parameters.
Highlights
- DeepSeek MTP Model Support: Adds a new model implementation (`Deepseek3MTPModel`) specifically for the DeepSeek MTP architecture, inheriting from the Deepseek2 model.
- Speculative Decoding Implementation: Introduces a framework for speculative decoding, including a new enum (`SpeculativeDecodeAlgorithm`), batch objects (`ModelInput`, `ModelOutput`), and pre-processing logic (`mtp_pre_process.py`). This allows the model to generate multiple tokens speculatively in a single step.
- MTP-Specific Components: Adds MTP-specific memory management (`Deepseek3MTPMemoryManager`) and pre/post layer inference logic (`Deepseek3MTPPreLayerInfer`, `Deepseek3MTPPreAndPostLayerWeight`) to handle the unique aspects of the DeepSeek MTP model.
- Integration into Inference Pipeline: Modifies the core `BaseModel` forward pass, `InferStateInfo`, and CUDA graph handling to accommodate the new `ModelInput`/`ModelOutput` structure and the speculative decoding algorithm (`spec_algo`). Updates the `copy_kv_index_to_req` kernel to support variable decode lengths.
- Backend and CLI Updates: Adds new backend implementations (`ContinuesBatchWithMTPBackend`, `ChunkedPrefillWithMTPBackend`) to manage the speculative decoding workflow. Introduces new CLI arguments (`--spec_algo`, `--spec_model_dir`, `--spec_step`) to control speculative decoding settings.
- Testing and Logging: Adds a new test script (`test/model/model_infer_mtp.py`) to test MTP model inference with speculative decoding. Updates logging in the HTTP server manager to include `avg_token_per_step` based on the MTP acceptance length.
Changelog
- `lightllm/common/basemodel/basemodel.py`
  - Imported `SpeculativeDecodeAlgorithm`, `ModelInput`, `ModelOutput` (L23-24).
  - Added a `spec_algo` attribute to `BaseModel` initialization (L77-78).
  - Refactored the `forward` method to accept a `ModelInput` object (L235).
  - Introduced the `_create_inferstate` helper method (L243-275).
  - Updated the `_prefill` and `_decode` methods to use `ModelInput` and `_create_inferstate` (L277-310).
  - Changed the return type of `_context_forward` and `_token_forward` to `ModelOutput`, including hidden states conditionally (L456-459, L484-487).
  - Renamed `predict_logics` to `predict_logits` in various methods (L360, L368, L372, L429, L433, L507, L512, L526, L530).
  - Updated `_check_max_len_infer` to use `ModelInput` and `ModelOutput` (L550-565).
- `lightllm/common/basemodel/batch_objs.py`
  - Added a new file defining the `ModelInput` and `ModelOutput` dataclasses (L1-23); a rough sketch follows this entry.
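For orientation, here is a minimal sketch of what such batch dataclasses could look like. Field names other than `logits` and `hidden_states` are assumptions for illustration, not the PR's actual definitions.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ModelInput:
    # flattened token ids for the whole batch, plus per-request bookkeeping
    batch_size: int
    total_token_num: int
    max_len_in_batch: int
    input_ids: torch.Tensor
    b_req_idx: torch.Tensor
    b_seq_len: torch.Tensor
    is_prefill: bool = False
    # hidden states handed from the main model to an MTP draft module
    draft_input_hiddens: Optional[torch.Tensor] = None


@dataclass
class ModelOutput:
    logits: torch.Tensor
    hidden_states: torch.Tensor
```

Grouping these parameters into dataclasses keeps the forward-pass signature stable as MTP adds extra fields such as the hidden states passed to draft modules.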
- `lightllm/common/basemodel/cuda_graph.py`
  - Imported `ModelInput`, `ModelOutput`, `SpeculativeDecodeAlgorithm` (L8-9).
  - Updated `_capture_decode` and `_replay` to handle `ModelOutput` (L51-54, L99-103).
  - Added a `decode_len` calculation in `warmup` based on `spec_algo` (L133).
  - Updated `warmup` to use `ModelInput` and `ModelOutput` and to handle potential hidden states for MTP modules (L145-166, L175-191); see the capture/replay sketch after this entry.
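For context, capturing a decode step that now returns two tensors (logits plus hidden states) instead of one can be done with stock PyTorch CUDA graphs. This is a generic sketch, not the project's actual CUDA graph class, and it omits the side-stream warmup a real implementation performs.

```python
import torch


class TinyDecodeGraph:
    """Capture a decode function whose output is a (logits, hidden_states) pair."""

    def __init__(self):
        self.graph = None
        self.static_input = None
        self.static_outputs = None  # tensors owned by the captured graph

    def capture(self, decode_fn, static_input: torch.Tensor):
        # real code warms up decode_fn on a side CUDA stream before capturing
        self.static_input = static_input
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_outputs = decode_fn(static_input)
        return self.static_outputs

    def replay(self, new_input: torch.Tensor):
        # write fresh inputs into the captured buffer, then replay the recorded kernels;
        # results land in the captured output tensors (e.g. logits and hidden states)
        self.static_input.copy_(new_input)
        self.graph.replay()
        return self.static_outputs
```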
- `lightllm/common/basemodel/infer_struct.py`
  - Imported `SpeculativeDecodeAlgorithm` (L8).
  - Added `spec_algo` and `spec_info` attributes to `InferStateInfo` (L59-60).
- `lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py`
  - Modified the Triton kernel `_fwd_kernel_copy_kv_index_to_req` to handle `decode_len` and batch size in the grid calculation (L12-17).
  - Updated the `copy_kv_index_to_req` function signature to accept `decode_len` and adjusted the grid calculation (L24-30); a simplified kernel sketch follows this entry.
  - Added a test case for `copy_kv_index_to_req` with `decode_len` 1 and 2 (L46-63).
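As a rough illustration of the variable-decode-length idea (not the repository's actual kernel; argument names and the grid layout are assumptions), a Triton kernel that copies `decode_len` newly allocated KV slots per request could look like this:

```python
import triton
import triton.language as tl


@triton.jit
def _copy_kv_index_to_req_kernel(
    req_to_token_indexs,   # [max_request_num, max_seq_len] token-index table
    b_req_idx,             # [batch] request slot for each sequence
    b_seq_len,             # [batch] sequence length after this decode step
    mem_index,             # [batch * DECODE_LEN] newly allocated KV slots
    stride_req,            # row stride of req_to_token_indexs
    DECODE_LEN: tl.constexpr,
):
    cur_batch = tl.program_id(0)   # which request in the batch
    cur_pos = tl.program_id(1)     # which of the DECODE_LEN new tokens
    req_idx = tl.load(b_req_idx + cur_batch)
    seq_len = tl.load(b_seq_len + cur_batch)
    # the new tokens occupy positions [seq_len - DECODE_LEN, seq_len) of this request
    dest = req_idx * stride_req + seq_len - DECODE_LEN + cur_pos
    src = cur_batch * DECODE_LEN + cur_pos
    tl.store(req_to_token_indexs + dest, tl.load(mem_index + src))


def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, mem_index, decode_len=1):
    batch_size = b_req_idx.shape[0]
    grid = (batch_size, decode_len)   # one program per (request, new token)
    _copy_kv_index_to_req_kernel[grid](
        req_to_token_indexs, b_req_idx, b_seq_len, mem_index,
        req_to_token_indexs.stride(0), DECODE_LEN=decode_len,
    )
```

With `decode_len=1` this degenerates to the ordinary single-token decode path, which matches the new test case covering both lengths.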
- `lightllm/common/spec_info.py`
  - Added a new file defining the `SpeculativeDecodeAlgorithm` IntEnum with `NONE`, `MTP`, `MTP_MOUDLE` values (L1-7).
  - Added helper methods `is_none`, `is_mtp`, `is_mtp_module` (L9-16).
  - Added a `from_string` static method to parse algorithm names (L18-27).
  - Added a `decode_len` method returning the expected decode length for each algorithm (L29-35); a sketch of such an enum follows this entry.
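A plausible shape for such an enum, consistent with the helpers described above; the enum values and the `decode_len` return values are illustrative, not necessarily the file's exact contents.

```python
from enum import IntEnum


class SpeculativeDecodeAlgorithm(IntEnum):
    NONE = 0
    MTP = 1
    MTP_MOUDLE = 2  # spelling follows the identifier in the changelog

    def is_none(self) -> bool:
        return self == SpeculativeDecodeAlgorithm.NONE

    def is_mtp(self) -> bool:
        return self == SpeculativeDecodeAlgorithm.MTP

    def is_mtp_module(self) -> bool:
        return self == SpeculativeDecodeAlgorithm.MTP_MOUDLE

    @staticmethod
    def from_string(name: str) -> "SpeculativeDecodeAlgorithm":
        mapping = {
            "NONE": SpeculativeDecodeAlgorithm.NONE,
            "MTP": SpeculativeDecodeAlgorithm.MTP,
            "MTP_MOUDLE": SpeculativeDecodeAlgorithm.MTP_MOUDLE,
        }
        return mapping[name.upper()]

    def decode_len(self) -> int:
        # illustrative values: MTP feeds one extra draft token per decode step
        if self == SpeculativeDecodeAlgorithm.MTP:
            return 2
        return 1
```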
- `lightllm/models/deepseek2/infer_struct.py`
  - Imported `SpeculativeDecodeAlgorithm` (L6).
- `lightllm/models/deepseek_mtp/deepseek3_mtp_mem_manager.py`
  - Added a new file defining `Deepseek3MTPMemoryManager`, inheriting from `Deepseek2MemoryManager` (L10).
  - Initializes the memory manager with specific head parameters and shared memory for the token count (L11-41).
- `lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py`
  - Added a new file defining `Deepseek3MTPPreLayerInfer`, inheriting from `LlamaPreLayerInfer` (L16).
  - Implemented `mtp_context_forward` and `mtp_token_forward` methods for MTP-specific pre-layer logic involving hidden states and projection (L25-58); see the combination sketch after this entry.
  - Overrode `context_forward` and `token_forward` to call the MTP-specific methods (L60-70).
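In the DeepSeek-V3 MTP design, the draft module combines the token embedding with the main model's hidden state before its transformer layer: both streams are RMS-normalized, concatenated, and projected back to the hidden size. The sketch below shows that combination with placeholder tensor and weight names; it is not the PR's exact implementation.

```python
import torch


def mtp_combine(input_embs: torch.Tensor,
                main_hidden_states: torch.Tensor,
                enorm_weight: torch.Tensor,
                hnorm_weight: torch.Tensor,
                eh_proj_weight: torch.Tensor,
                eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm both streams, concatenate, and project back to hidden size."""

    def rmsnorm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * w

    h = rmsnorm(main_hidden_states, hnorm_weight)  # [num_tokens, hidden]
    e = rmsnorm(input_embs, enorm_weight)          # [num_tokens, hidden]
    cat = torch.cat([h, e], dim=-1)                # [num_tokens, 2 * hidden]
    # eh_proj maps 2*hidden -> hidden for the MTP transformer layer
    return torch.matmul(cat, eh_proj_weight.t())
```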
- `lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py`
  - Added a new file defining `Deepseek3MTPPreAndPostLayerWeight`, inheriting from `LlamaPreAndPostLayerWeight` (L5).
  - Defined MTP-specific weights (`eh_proj_weight_`, `enorm_weight_`, `hnorm_weight_`) and their loading logic from HF weights (L14-21).
  - Implemented `verify_load` for the new weights (L23-28).
- `lightllm/models/deepseek_mtp/model.py`
  - Added a new file defining `Deepseek3MTPModel`, inheriting from `Deepseek2TpPartModel` (L17).
  - Set MTP-specific pre/post weight and layer-infer classes (L19-20).
  - Modified initialization to accept `main_model` and `last_mtp_module` (L22-26).
  - Overrode `_init_req_manager` to share the main model's request manager (L28-40).
  - Overrode `_init_mem_manager` to use `Deepseek3MTPMemoryManager` (L42-51).
  - Overrode `_init_weights` to share the main model's embedding, lm_head, and final norm weights (L53-57).
- `lightllm/models/llama/layer_infer/post_layer_infer.py`
  - Minor reordering of lines in `token_forward` (L68-69).
- `lightllm/server/api_cli.py`
  - Added command-line arguments `--spec_algo`, `--spec_model_dir`, and `--spec_step` for speculative decoding configuration (L384-401); an illustrative argparse sketch follows this entry.
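The new flags could be wired up roughly as follows; the help strings and default values here are illustrative rather than copied from the PR.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--spec_algo", type=str, default="NONE",
    help="speculative decoding algorithm to use, e.g. MTP",
)
parser.add_argument(
    "--spec_model_dir", type=str, default=None,
    help="path to the draft (MTP) model weights",
)
parser.add_argument(
    "--spec_step", type=int, default=1,
    help="number of draft tokens proposed per decode step",
)
args = parser.parse_args()
```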
- `lightllm/server/core/objs/req.py`
  - Added an `mtp_accepted_len` field to the `Req` ctypes structure (L98).
  - Initialized `mtp_accepted_len` to 0 in the `init` method (L150).
- `lightllm/server/httpserver/manager.py`
  - Added calculation and logging of `avg_token_per_step` using `mtp_accepted_len` (L544-545, L555).
  - Included `mtp_accepted_len` in the metadata passed to the detokenization process (L658).
- `lightllm/server/router/manager.py`
  - Added a `spec_step` attribute initialized from args (L65).
  - Passed `spec_algo`, `spec_weight_dir`, and `spec_step` to `model_rpc_client.init_model` (L180-182).
  - Modified prefill and decode batch handling to send `spec_step + 1` None packages to detokenization (L391-393, L410-412).
- `lightllm/server/router/model_infer/infer_batch.py`
  - Added a `cur_accepted_len` attribute to `InferReq` (L260).
  - Added a `get_chunked_input_token_ids_shift` method (L323-328).
  - Added a `set_total_accepted_len` method to update the shared-memory request object (L341-342).
- `lightllm/server/router/model_infer/mode_backend/__init__.py`
  - Imported `ContinuesBatchWithMTPBackend` (L15).
- `lightllm/server/router/model_infer/mode_backend/base_backend.py`
  - Added `spec_algo` to the `model_kvargs` dictionary passed to `get_model` (L113).
- `lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py`
  - Updated `prepare_decode_inputs` and `prepare_prefill_inputs` calls to use `model_input` instead of `kwargs` (L41, L58).
  - Updated model forward calls to use `model_input` and handle `model_output` (L42, L61).
  - Updated sample calls to use `model_output.logits` (L46, L65).
  - Deleted `model_output` after sampling (L52, L71).
- `lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_mtp.py`
  - Added a new file defining `ChunkedPrefillWithMTPBackend`, inheriting from `ContinuesBatchWithMTPBackend` (L27).
  - Implemented the `decode` method with MTP-specific prefill and decode logic, including calling draft models and verifying accepted tokens (L31-120).
  - Implemented the `verify` method to compare main-model output with draft tokens and determine the accepted length (L122-145).
  - Implemented `_save_draft_token_ids` to store draft-model outputs (L147-154).
- `lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py`
  - Updated `prepare_prefill_inputs` and `prepare_decode_inputs` calls to use `model_input` instead of `kwargs` (L33, L51).
  - Updated model forward calls to use `model_input` (L36, L52).
- `lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py`
  - Added a new file defining `ContinuesBatchWithMTPBackend`, inheriting from `ModeBackend` (L29).
  - Implemented `init_model` to load the main model and multiple draft models based on `spec_step` and `spec_model_dir` (L34-81).
  - Implemented `prefill` and `decode` methods with MTP-specific logic, similar to the chunked-prefill version, including calling draft models, verifying, and saving draft tokens (L83-178).
  - Implemented the `verify` and `_save_draft_token_ids` methods for MTP speculative decoding (L180-215); a generic verification sketch follows this entry.
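Conceptually, verification walks the drafted tokens in order and keeps them only while they match what the main model actually sampled at the same positions. This is a self-contained sketch of that comparison, not the backend's exact `verify` implementation.

```python
import torch


def verify_draft_tokens(main_token_ids: torch.Tensor,
                        draft_token_ids: torch.Tensor) -> int:
    """Return how many leading draft tokens agree with the main model.

    main_token_ids:  tokens the main model sampled at each speculative position
    draft_token_ids: tokens the draft (MTP) model proposed for those positions
    """
    accepted = 0
    for main_tok, draft_tok in zip(main_token_ids.tolist(), draft_token_ids.tolist()):
        if main_tok != draft_tok:
            break
        accepted += 1
    return accepted


# example: draft proposed [5, 9], main model sampled [5, 7] -> only 1 token accepted
print(verify_draft_tokens(torch.tensor([5, 7]), torch.tensor([5, 9])))  # 1
```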
- `lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py`
  - Added a new file defining MTP-specific pre-processing functions (L1-81).
  - Implemented `prepare_mtp_prefill_inputs` to prepare inputs for draft-model prefill, potentially shifting input IDs (L11-29); a toy illustration of the shift follows this entry.
  - Implemented `prepare_draft_main_model_decode_inputs` to prepare inputs for the main-model decode step in MTP, including draft tokens (L32-81).
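The draft model's prefill input is essentially the main model's input shifted left by one position with the newly predicted token appended, so each position of the draft model targets the token one step further ahead while reusing the main model's hidden states. A toy, single-request illustration of that shift (hypothetical helper, not the PR's function):

```python
import torch


def shift_input_ids_for_mtp(input_ids: torch.Tensor, next_token_id: int) -> torch.Tensor:
    """Drop the first token and append the freshly predicted one.

    For input_ids = [t0, t1, ..., tn] and next_token_id = t(n+1), the draft
    model sees [t1, ..., tn, t(n+1)].
    """
    tail = torch.tensor([next_token_id], dtype=input_ids.dtype, device=input_ids.device)
    return torch.cat([input_ids[1:], tail])


print(shift_input_ids_for_mtp(torch.tensor([1, 2, 3, 4]), 5))  # tensor([2, 3, 4, 5])
```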
- `lightllm/server/router/model_infer/model_rpc.py`
  - Imported `ContinuesBatchWithMTPBackend` (L24).
  - Added logic in `init_model` to select `ContinuesBatchWithMTPBackend` if `spec_algo` is 'MTP' and `disable_chunked_prefill` is true (L157-159).
- `lightllm/utils/dist_utils.py`
  - Added a `device0_print` helper function for printing only on device 0 (L196-198); a minimal version is sketched after this entry.
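A minimal version of such a helper, assuming `torch.distributed` is the source of the rank (the real helper may obtain the rank differently):

```python
import torch.distributed as dist


def device0_print(*args, **kwargs):
    # print only on the first rank to avoid duplicated logs across tensor-parallel workers
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        print(*args, **kwargs)
```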
- `test/model/model_infer_mtp.py`
  - Added a new file with MTP-specific model inference test logic (L1-282).
  - Includes `init_mtp_model` to load draft models (L20-53).
  - Includes `run_forward_once` to perform a single prefill and decode step with the main and draft models (L113-246).
  - Includes `tppart_model_infer` to run the test inference in a distributed environment (L249-282).
- `test/model/test_model.py`
  - Imported `test_model_inference_mtp` (L8).
  - Modified `test_model_infer` to call `test_model_inference_mtp` if `spec_algo` is 'MTP' (L19-21).
  - Modified default values for `--batch_size`, `--input_len`, and `--output_len` in the argument parser (L30-32).
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Speculative steps,
Drafting tokens, quick and bold,
Main model checks all.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Code Review
This pull request introduces Multi-Token Prediction (MTP) speculative decoding, primarily for DeepSeek models. The changes involve significant refactoring of the model's forward pass to use the `ModelInput` and `ModelOutput` dataclasses, which improves code organization. New classes and methods are added to support the MTP algorithm, including specialized memory managers, layer inference logic, and weight handling for draft models.
The core MTP logic has the main model make a prediction, after which one or more draft models make subsequent predictions using the main model's hidden states. These draft predictions are then verified.
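In pseudocode terms, one decode step under this scheme looks roughly like the following. This is a generic sketch with placeholder callables and attributes (`decode`, `sampled_tokens`, `hidden_states`), not the PR's backend code.

```python
def mtp_decode_step(main_model, draft_models, token_ids, draft_tokens):
    """One speculative decode step: verify previous drafts, then propose new ones.

    draft_tokens are the tokens the draft models proposed on the previous step;
    the main model scores them together with the last accepted token.
    """
    # 1) main model processes the previously drafted tokens in a single forward pass
    main_out = main_model.decode(token_ids + draft_tokens)

    # 2) keep drafts only up to the first mismatch with the main model's choices
    accepted = 0
    for proposed, sampled in zip(draft_tokens, main_out.sampled_tokens):
        if proposed != sampled:
            break
        accepted += 1
    new_tokens = main_out.sampled_tokens[: accepted + 1]  # accepted drafts + 1 corrected token

    # 3) draft models roll forward from the main model's hidden states
    #    to propose the next batch of speculative tokens
    next_drafts = []
    hidden = main_out.hidden_states
    for draft in draft_models:
        d_out = draft.decode(new_tokens[-1:], prev_hidden_states=hidden)
        next_drafts.append(d_out.sampled_tokens[-1])
        hidden = d_out.hidden_states
    return new_tokens, next_drafts
```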
Overall, the changes are substantial and introduce a complex new feature. The refactoring to `ModelInput`/`ModelOutput` is a good step. One potential issue was found in the logging logic that needs addressing. The renaming of `predict_logics` to `predict_logits` throughout the codebase is a welcome improvement for clarity.
Summary of Findings
- Potential ZeroDivisionError in Logging: In `lightllm/server/httpserver/manager.py`, the calculation of `avg_token_per_step` (line 544) can raise a `ZeroDivisionError` if `out_token_counter` equals `metadata["mtp_accepted_len"]`. This needs to be handled to prevent runtime errors.
- Refactoring of Model Forward Pass: The `forward` method in `basemodel.py` and its usage in the various backends have been refactored to use the `ModelInput` and `ModelOutput` dataclasses. This is a good change that improves code clarity and maintainability by grouping related parameters.
- MTP Speculative Decoding Implementation: The core logic for MTP, including draft-model handling, hidden-state passing, and token verification, has been implemented. New classes such as `Deepseek3MTPModel` and `ContinuesBatchWithMTPBackend`, plus specialized pre/post-processing functions, support this.
- CUDA Graph and Kernel Adjustments: Changes to CUDA graph handling and the `copy_kv_index_to_req` Triton kernel accommodate the multi-token nature of speculative decoding.
- Type Hinting in `batch_objs.py`: The `ModelOutput` dataclass uses `torch.tensor` for type hints; `torch.Tensor` is recommended for consistency.
Merge Readiness
This pull request introduces a significant new feature (MTP speculative decoding) and includes substantial refactoring. While the overall structure and MTP logic appear sound, there is a high-severity issue in `lightllm/server/httpserver/manager.py` (a potential `ZeroDivisionError`) that must be addressed before merging. Additionally, a medium-severity suggestion regarding type hinting in `lightllm/common/basemodel/batch_objs.py` should be considered.
Given these points, I recommend addressing the identified issues. I am not authorized to approve pull requests, so please ensure further review and approval from other maintainers after the changes are made.
@@ -540,6 +540,8 @@ async def _wait_to_token_package(
     x_request_id = request.headers.get("X-Request-Id", "") if request is not None else ""
     x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
     prompt_cache_ratio = prompt_cache_len / prompt_tokens
+
+    avg_token_per_step = out_token_counter / (out_token_counter - metadata["mtp_accepted_len"])
The calculation for `avg_token_per_step` could lead to a `ZeroDivisionError` if `out_token_counter` is equal to `metadata["mtp_accepted_len"]`. This might happen if, for example, all tokens generated so far were accepted speculative tokens, or if `out_token_counter` is 0.
Could you add a check to prevent division by zero here? For example, you could set `avg_token_per_step` to a default value (such as `out_token_counter`, `float('inf')`, or `None`) or skip the calculation when the denominator is zero.
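One way to guard the computation, as a sketch of the suggestion above (variable names follow the quoted diff, and the fallback value is just one possible choice):

```python
denominator = out_token_counter - metadata["mtp_accepted_len"]
if denominator > 0:
    avg_token_per_step = out_token_counter / denominator
else:
    # every emitted token was an accepted speculative token, or nothing was emitted yet
    avg_token_per_step = float(out_token_counter)
```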
    logits: torch.tensor
    hidden_states: torch.tensor
Consider using `torch.Tensor` for type hints instead of `torch.tensor`, for consistency with PyTorch's official type hinting. While `torch.tensor` is a function that creates tensors, `torch.Tensor` is the type. This is a minor point but improves consistency.
Suggested change:
-    logits: torch.tensor
-    hidden_states: torch.tensor
+    logits: torch.Tensor
+    hidden_states: torch.Tensor