DeepSeek MTP #913

Open · wants to merge 23 commits into main

Conversation

shihaobai (Collaborator)

No description provided.

@gemini-code-assist (bot) left a comment

Hello @shihaobai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, gemini-code-assist here to provide a summary of this pull request titled "DeepSeek MTP". Based on the title and the extensive code changes, this PR introduces support for DeepSeek MTP (Multi-Token Prediction) by implementing a speculative decoding approach. The changes add new components for MTP-specific logic, integrate them into the core model inference flow, and update the server backend and CLI to support the new model and the new speculative decoding parameters.

Highlights

  • DeepSeek MTP Model Support: Adds a new model implementation (Deepseek3MTPModel) specifically for the DeepSeek MTP architecture, inheriting from the Deepseek2 model.
  • Speculative Decoding Implementation: Introduces a framework for speculative decoding, including new enums (SpeculativeDecodeAlgorithm), batch objects (ModelInput, ModelOutput), and pre-processing logic (mtp_pre_process.py); a rough sketch of these objects follows this list. This allows the model to generate multiple tokens speculatively in a single step.
  • MTP-Specific Components: Adds MTP-specific memory management (Deepseek3MTPMemoryManager) and pre/post layer inference logic (Deepseek3MTPPreLayerInfer, Deepseek3MTPPreAndPostLayerWeight) to handle the unique aspects of the DeepSeek MTP model.
  • Integration into Inference Pipeline: Modifies the core BaseModel forward pass, InferStateInfo, and CUDA graph handling to accommodate the new ModelInput/ModelOutput structure and the speculative decoding algorithm (spec_algo). Updates the copy_kv_index_to_req kernel to support variable decode lengths.
  • Backend and CLI Updates: Adds new backend implementations (ContinuesBatchWithMTPBackend, ChunkedPrefillWithMTPBackend) to manage the speculative decoding workflow. Introduces new CLI arguments (--spec_algo, --spec_model_dir, --spec_step) to control speculative decoding settings.
  • Testing and Logging: Adds a new test script (test/model/model_infer_mtp.py) to specifically test the MTP model inference with speculative decoding. Updates logging in the HTTP server manager to include avg_token_per_step based on MTP acceptance length.
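
For orientation, here is a rough sketch of the new enum and batch objects referenced above and in the changelog below. It is illustrative only: apart from ModelOutput's logits and hidden_states fields and the enum member names, the method bodies and return values are assumptions, not the PR's actual code.

from dataclasses import dataclass
from enum import IntEnum
import torch

class SpeculativeDecodeAlgorithm(IntEnum):
    NONE = 0
    MTP = 1
    MTP_MOUDLE = 2  # member name as listed in the changelog

    def is_none(self) -> bool:
        return self == SpeculativeDecodeAlgorithm.NONE

    def decode_len(self) -> int:
        # Assumed mapping: plain decoding emits one token per step,
        # while MTP variants verify an extra draft token per step.
        return 1 if self.is_none() else 2

@dataclass
class ModelOutput:
    logits: torch.Tensor         # per-token output logits
    hidden_states: torch.Tensor  # exposed so MTP draft modules can reuse them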

Changelog

  • lightllm/common/basemodel/basemodel.py
    • Imported SpeculativeDecodeAlgorithm, ModelInput, ModelOutput (L23-24).
    • Added spec_algo attribute to BaseModel initialization (L77-78).
    • Refactored forward method to accept ModelInput object (L235).
    • Introduced _create_inferstate helper method (L243-275).
    • Updated _prefill and _decode methods to use ModelInput and _create_inferstate (L277-310).
    • Modified return type of _context_forward and _token_forward to ModelOutput and included hidden states conditionally (L456-459, L484-487).
    • Renamed predict_logics to predict_logits in various methods (L360, L368, L372, L429, L433, L507, L512, L526, L530).
    • Updated _check_max_len_infer to use ModelInput and ModelOutput (L550-565).
  • lightllm/common/basemodel/batch_objs.py
    • Added new file defining ModelInput and ModelOutput dataclasses (L1-23).
  • lightllm/common/basemodel/cuda_graph.py
    • Imported ModelInput, ModelOutput, SpeculativeDecodeAlgorithm (L8-9).
    • Updated _capture_decode and _replay to handle ModelOutput (L51-54, L99-103).
    • Added decode_len calculation in warmup based on spec_algo (L133).
    • Updated warmup to use ModelInput and ModelOutput and handle potential hidden states for MTP modules (L145-166, L175-191).
  • lightllm/common/basemodel/infer_struct.py
    • Imported SpeculativeDecodeAlgorithm (L8).
    • Added spec_algo and spec_info attributes to InferStateInfo (L59-60).
  • lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py
    • Modified the Triton kernel _fwd_kernel_copy_kv_index_to_req to handle decode_len and batch size in grid calculation (L12-17).
    • Updated copy_kv_index_to_req function signature to accept decode_len and adjusted grid calculation (L24-30).
    • Added a test case for copy_kv_index_to_req with decode_len 1 and 2 (L46-63).
  • lightllm/common/spec_info.py
    • Added new file defining SpeculativeDecodeAlgorithm IntEnum with NONE, MTP, MTP_MOUDLE values (L1-7).
    • Added helper methods is_none, is_mtp, is_mtp_module (L9-16).
    • Added from_string static method to parse algorithm names (L18-27).
    • Added decode_len method to return the expected decode length for each algorithm (L29-35).
  • lightllm/models/deepseek2/infer_struct.py
    • Imported SpeculativeDecodeAlgorithm (L6).
  • lightllm/models/deepseek_mtp/deepseek3_mtp_mem_manager.py
    • Added new file defining Deepseek3MTPMemoryManager inheriting from Deepseek2MemoryManager (L10).
    • Initializes memory manager with specific head parameters and shared memory for token count (L11-41).
  • lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
    • Added new file defining Deepseek3MTPPreLayerInfer inheriting from LlamaPreLayerInfer (L16).
    • Implemented mtp_context_forward and mtp_token_forward methods for MTP-specific pre-layer logic involving hidden states and projection (L25-58); a rough sketch of this projection follows the changelog.
    • Overrode context_forward and token_forward to call the MTP-specific methods (L60-70).
  • lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
    • Added new file defining Deepseek3MTPPreAndPostLayerWeight inheriting from LlamaPreAndPostLayerWeight (L5).
    • Defined MTP-specific weights (eh_proj_weight_, enorm_weight_, hnorm_weight_) and loading logic from HF weights (L14-21).
    • Implemented verify_load for the new weights (L23-28).
  • lightllm/models/deepseek_mtp/model.py
    • Added new file defining Deepseek3MTPModel inheriting from Deepseek2TpPartModel (L17).
    • Set MTP-specific pre/post weight and layer infer classes (L19-20).
    • Modified initialization to accept main_model and last_mtp_module (L22-26).
    • Overrode _init_req_manager to share the main model's request manager (L28-40).
    • Overrode _init_mem_manager to use Deepseek3MTPMemoryManager (L42-51).
    • Overrode _init_weights to share main model's embedding, lm_head, and final norm weights (L53-57).
  • lightllm/models/llama/layer_infer/post_layer_infer.py
    • Minor reordering of lines in token_forward (L68-69).
  • lightllm/server/api_cli.py
    • Added command-line arguments --spec_algo, --spec_model_dir, and --spec_step for speculative decoding configuration (L384-401).
  • lightllm/server/core/objs/req.py
    • Added mtp_accepted_len field to the Req ctypes structure (L98).
    • Initialized mtp_accepted_len to 0 in the init method (L150).
  • lightllm/server/httpserver/manager.py
    • Added calculation and logging for avg_token_per_step using mtp_accepted_len (L544-545, L555).
    • Included mtp_accepted_len in the metadata passed to the detokenization process (L658).
  • lightllm/server/router/manager.py
    • Added spec_step attribute initialized from args (L65).
    • Passed spec_algo, spec_weight_dir, and spec_step to model_rpc_client.init_model (L180-182).
    • Modified prefill and decode batch handling to send spec_step + 1 None packages to detokenization (L391-393, L410-412).
  • lightllm/server/router/model_infer/infer_batch.py
    • Added cur_accepted_len attribute to InferReq (L260).
    • Added get_chunked_input_token_ids_shift method (L323-328).
    • Added set_total_accepted_len method to update the shared memory request object (L341-342).
  • lightllm/server/router/model_infer/mode_backend/__init__.py
    • Imported ContinuesBatchWithMTPBackend (L15).
  • lightllm/server/router/model_infer/mode_backend/base_backend.py
    • Added spec_algo to the model_kvargs dictionary passed to get_model (L113).
  • lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
    • Updated prepare_decode_inputs and prepare_prefill_inputs calls to use model_input instead of kwargs (L41, L58).
    • Updated model forward calls to use model_input and handle model_output (L42, L61).
    • Updated sample calls to use model_output.logits (L46, L65).
    • Deleted model_output after sampling (L52, L71).
  • lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_mtp.py
    • Added new file defining ChunkedPrefillWithMTPBackend inheriting from ContinuesBatchWithMTPBackend (L27).
    • Implemented decode method with MTP-specific prefill and decode logic, including calling draft models and verifying accepted tokens (L31-120).
    • Implemented verify method to compare main model output with draft tokens and determine accepted length (L122-145).
    • Implemented _save_draft_token_ids to store draft model outputs (L147-154).
  • lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py
    • Updated prepare_prefill_inputs and prepare_decode_inputs calls to use model_input instead of kwargs (L33, L51).
    • Updated model forward calls to use model_input (L36, L52).
  • lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py
    • Added new file defining ContinuesBatchWithMTPBackend inheriting from ModeBackend (L29).
    • Implemented init_model to load the main model and multiple draft models based on spec_step and spec_model_dir (L34-81).
    • Implemented prefill and decode methods with MTP-specific logic, similar to the chunked prefill version, including calling draft models, verifying, and saving draft tokens (L83-178).
    • Implemented verify and _save_draft_token_ids methods for MTP speculative decoding (L180-215).
  • lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py
    • Added new file defining MTP-specific pre-processing functions (L1-81).
    • Implemented prepare_mtp_prefill_inputs to prepare inputs for draft model prefill, potentially shifting input IDs (L11-29).
    • Implemented prepare_draft_main_model_decode_inputs to prepare inputs for the main model decode step in MTP, including draft tokens (L32-81).
  • lightllm/server/router/model_infer/model_rpc.py
    • Imported ContinuesBatchWithMTPBackend (L24).
    • Added logic in init_model to select ContinuesBatchWithMTPBackend if spec_algo is 'MTP' and disable_chunked_prefill is true (L157-159).
  • lightllm/utils/dist_utils.py
    • Added device0_print helper function for printing only on device 0 (L196-198).
  • test/model/model_infer_mtp.py
    • Added new file with MTP-specific model inference test logic (L1-282).
    • Includes init_mtp_model to load draft models (L20-53).
    • Includes run_forward_once to perform a single prefill and decode step with main and draft models (L113-246).
    • Includes tppart_model_infer to run the test inference in a distributed environment (L249-282).
  • test/model/test_model.py
    • Imported test_model_inference_mtp (L8).
    • Modified test_model_infer to call test_model_inference_mtp if spec_algo is 'MTP' (L19-21).
    • Modified default values for --batch_size, --input_len, and --output_len in the argument parser (L30-32).
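
The layer_infer/pre_layer_infer.py entry above describes how the MTP module fuses the token embedding with the main model's hidden state before projecting it back to the hidden size. A rough sketch of that projection, using the weight names from the layer_weights entry (everything else, including shapes and the RMSNorm epsilon, is assumed):

import torch

def mtp_pre_layer(emb, prev_hidden, enorm_weight, hnorm_weight, eh_proj_weight, eps=1e-6):
    # RMS-normalize the token embedding and the incoming hidden state separately,
    # then project their concatenation back down to the hidden size via eh_proj.
    def rmsnorm(x, w):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
    fused = torch.cat([rmsnorm(emb, enorm_weight), rmsnorm(prev_hidden, hnorm_weight)], dim=-1)
    return fused @ eh_proj_weight.t()  # [tokens, 2 * hidden] -> [tokens, hidden]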

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces Multi-Token Prediction (MTP) speculative decoding, primarily for DeepSeek models. The changes involve significant refactoring of the model's forward pass to use ModelInput and ModelOutput dataclasses, which improves code organization. New classes and methods are added to support the MTP algorithm, including specialized memory managers, layer inference logic, and weight handling for draft models.

The core MTP logic involves the main model making a prediction, followed by one or more draft models making subsequent predictions using the main model's hidden states. These draft predictions are then verified.
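
As a rough illustration of that flow (all names here are hypothetical and only approximate the backends added in this PR):

# Hypothetical sketch of one MTP decode step; not the PR's actual API.
def mtp_decode_step(main_model, draft_models, model_input):
    # The main model decodes the current position and exposes its hidden states.
    main_out = main_model.forward(model_input)  # ModelOutput(logits, hidden_states)
    next_token = main_out.logits.argmax(dim=-1)

    # Each draft (MTP) module proposes one additional token, conditioning on the
    # previously chosen token and the hidden states of the previous model.
    draft_tokens, token, hidden = [], next_token, main_out.hidden_states
    for draft in draft_models:
        draft_out = draft.forward_with_hidden(token, hidden)  # assumed helper
        token = draft_out.logits.argmax(dim=-1)
        hidden = draft_out.hidden_states
        draft_tokens.append(token)

    # On the next step, the main model scores the draft tokens in one pass and the
    # verify step accepts the longest prefix that matches its own greedy choices.
    return next_token, draft_tokens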

Overall, the changes are substantial and introduce a complex new feature. The refactoring to ModelInput/ModelOutput is a good step. One potential issue in the logging logic needs to be addressed before merging. The renaming of predict_logics to predict_logits throughout the codebase is a welcome improvement for clarity.

Summary of Findings

  • Potential ZeroDivisionError in Logging: In lightllm/server/httpserver/manager.py, the calculation of avg_token_per_step (line 544) can lead to a ZeroDivisionError if out_token_counter equals metadata["mtp_accepted_len"]. This needs to be handled to prevent runtime errors.
  • Refactoring of Model Forward Pass: The forward method in basemodel.py and its usage in various backends have been refactored to use ModelInput and ModelOutput dataclasses. This is a good change that improves code clarity and maintainability by grouping related parameters.
  • MTP Speculative Decoding Implementation: The core logic for MTP, including draft model handling, hidden state passing, and token verification, has been implemented. New classes like Deepseek3MTPModel, ContinuesBatchWithMTPBackend, and specialized pre/post processing functions support this.
  • CUDA Graph and Kernel Adjustments: Changes to CUDA graph handling and the copy_kv_index_to_req Triton kernel accommodate the multi-token nature of speculative decoding.
  • Type Hinting in batch_objs.py: The ModelOutput dataclass uses torch.tensor for its type hints; torch.Tensor is recommended instead for consistency.

Merge Readiness

This pull request introduces a significant new feature (MTP speculative decoding) and includes substantial refactoring. While the overall structure and MTP logic appear sound, there is a high-severity issue in lightllm/server/httpserver/manager.py (potential ZeroDivisionError) that must be addressed before merging. Additionally, a medium-severity suggestion regarding type hinting in lightllm/common/basemodel/batch_objs.py should be considered.

Given these points, I recommend addressing the identified issues. I am not authorized to approve pull requests, so please ensure further review and approval from other maintainers after the changes are made.

@@ -540,6 +540,8 @@ async def _wait_to_token_package(
x_request_id = request.headers.get("X-Request-Id", "") if request is not None else ""
x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
prompt_cache_ratio = prompt_cache_len / prompt_tokens

avg_token_per_step = out_token_counter / (out_token_counter - metadata["mtp_accepted_len"])

Severity: high

The calculation for avg_token_per_step could lead to a ZeroDivisionError if out_token_counter is equal to metadata["mtp_accepted_len"]. This might happen if, for example, all tokens generated so far were accepted speculative tokens, or if out_token_counter is 0.

Could you add a check to prevent division by zero here? For example, you could set avg_token_per_step to a default value (like out_token_counter or float('inf') or None) or skip this calculation if the denominator is zero.
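
One minimal way to guard this, keeping the variable names from the diff above (the fallback value is an arbitrary choice, not the PR's actual fix):

# Hypothetical guard; out_token_counter and metadata come from the surrounding code.
denominator = out_token_counter - metadata["mtp_accepted_len"]
if denominator > 0:
    avg_token_per_step = out_token_counter / denominator
else:
    # Every emitted token so far was an accepted speculative token (or nothing
    # has been emitted yet); report None rather than dividing by zero.
    avg_token_per_step = None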

Comment on lines +22 to +23
logits: torch.tensor
hidden_states: torch.tensor

Severity: medium

Consider using torch.Tensor for type hints instead of torch.tensor, in line with PyTorch's official type hinting: torch.tensor is a function that creates tensors, while torch.Tensor is the type. This is a minor point, but it improves consistency.

Suggested change
- logits: torch.tensor
- hidden_states: torch.tensor
+ logits: torch.Tensor
+ hidden_states: torch.Tensor
