diff --git a/common/common.cpp b/common/common.cpp
index 224f3e0df..888f5e421 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -195,10 +195,46 @@ bool common_params_speculative::has_stage_type(common_speculative_type stage_typ
     });
 }
 
+// Drop every configured stage of `stage_type` and, when the legacy `type` field named that
+// stage, re-derive `type` from the first stage that is still resolved (NONE if none is left).
+void common_params_speculative::remove_stage_type(common_speculative_type stage_type) {
+    stages.erase(std::remove_if(stages.begin(), stages.end(), [stage_type](const common_speculative_stage_params & stage) {
+        return stage.type == stage_type;
+    }), stages.end());
+
+    if (type == stage_type) {
+        // Reset `type` before resolving, exactly as the server-side helper this method replaces
+        // did; otherwise a resolver that consults `type` could hand the removed stage back.
+        type = COMMON_SPECULATIVE_TYPE_NONE;
+        const auto resolved = get_resolved_stages();
+        type = resolved.empty() ? COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type;
+    }
+}
+
 bool common_params_speculative::has_composite_stage_chain() const {
     return get_resolved_stages().size() > 1;
 }
 
+// True when the stage chain contains a DRAFT stage, or contains an MTP stage while has_dft()
+// also holds. Callers use this to decide whether the draft-model settings must be kept.
+bool common_params_speculative::needs_dft_model() const {
+    return has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT) ||
+           (has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && has_dft());
+}
+
+// Free the loaded draft model (if any) and reset the draft model path, extra params and context params.
+void common_params_speculative::clear_dft() {
+    if (model_dft != nullptr) {
+        llama_free_model(model_dft);
+        model_dft = nullptr;
+    }
+
+    model.clear();
+    params.clear();
+    mparams_dft.path.clear();
+    cparams_dft = llama_context_default_params();
+}
+
 int32_t common_params_speculative::get_max_stage_n_max() const {
     const auto resolved = get_resolved_stages();
     if (resolved.empty()) {
diff --git a/common/common.h b/common/common.h
index 0125ccd59..070d3b8a5 100644
--- a/common/common.h
+++ b/common/common.h
@@ -252,7 +252,10 @@ struct common_params_speculative {
     common_params_speculative with_stage_overrides(const common_speculative_stage_params & stage) const;
     bool has_stage_chain() const;
     bool has_stage_type(common_speculative_type stage_type) const;
+    void remove_stage_type(common_speculative_type stage_type);
     bool has_composite_stage_chain() const;
+    bool needs_dft_model() const;
+    void clear_dft();
 
     int32_t get_max_stage_n_max() const;
     int32_t get_min_usable_stage_n_min() const;
diff --git 
a/common/speculative.cpp b/common/speculative.cpp index 758202ac2..357bc881f 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -47,6 +47,18 @@ const std::map common_speculative_typ {"suffix", COMMON_SPECULATIVE_TYPE_SUFFIX} }; +void common_speculative_checkpoint::clear() { + valid = false; + per_step_enabled = false; + n_past = 0; + sampled = LLAMA_TOKEN_NULL; + + if (sampler != nullptr) { + common_sampler_free(sampler); + sampler = nullptr; + } +} + struct common_speculative_config { common_speculative_stage_params stage; common_speculative_type type; @@ -172,6 +184,17 @@ struct common_speculative_state_mtp; static common_speculative_state_mtp * common_speculative_get_mtp_state(common_speculative * spec); static const common_speculative_state_mtp * common_speculative_get_mtp_state(const common_speculative * spec); static void mtp_invalidate_cached_drafts(common_speculative_state_mtp & state); +static bool common_speculative_checkpoint_save( + common_speculative_checkpoint & ckpt, + llama_model * model, + llama_context * ctx, + common_sampler * sampler_src, + const common_params_sampling & sparams, + llama_seq_id seq_id, + llama_pos n_past, + llama_token sampled, + int max_tokens, + int ckpt_mode); static std::vector mtp_speculative_gen_draft( common_speculative_state_mtp & state, @@ -1002,12 +1025,17 @@ struct common_speculative_state_suffix : public common_speculative_state { }; struct common_speculative { + common_speculative_checkpoint checkpoint; std::vector configs; // resolved stage config for each implementation std::vector> impls; // list of implementations to use and their states common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) std::unique_ptr tuner; int last_n_drafted = 0; int64_t t_step_start_us = 0; + + ~common_speculative() { + checkpoint.clear(); + } }; static bool common_speculative_stage_chain_matches( @@ -1315,6 +1343,7 @@ common_speculative * common_speculative_init( } auto * result = 
new common_speculative { + /* .checkpoint = */ {}, /* .configs = */ std::move(configs), /* .impls = */ std::move(impls) }; @@ -1340,6 +1369,170 @@ common_speculative * common_speculative_init( return result; } +common_speculative_init_status common_speculative_try_init( + common_params_speculative & params, + llama_context * ctx_tgt, + common_speculative ** out_spec) { + if (out_spec != nullptr) { + *out_spec = nullptr; + } + + if (!params.has_stage_chain()) { + return COMMON_SPECULATIVE_INIT_SKIPPED; + } + + common_speculative * spec = common_speculative_init(params, ctx_tgt); + if (spec != nullptr) { + if (out_spec != nullptr) { + *out_spec = spec; + } + return COMMON_SPECULATIVE_INIT_READY; + } + + const llama_model * model = ctx_tgt != nullptr ? llama_get_model(ctx_tgt) : nullptr; + if (model != nullptr && llama_model_has_recurrent(model)) { + return COMMON_SPECULATIVE_INIT_ERR_RECURRENT; + } + if (params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) { + return COMMON_SPECULATIVE_INIT_ERR_MTP; + } + return COMMON_SPECULATIVE_INIT_ERR_GENERIC; +} + +void common_speculative_prepare_startup( + gpt_params & params_base, + bool allow_parallel_mtp) { + auto & params = params_base.speculative; + + if (!allow_parallel_mtp && params_base.n_parallel > 1 && params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) { + LOG_WRN("%s: MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption. 
n_parallel=%d, stage_chain=%s\n", + __func__, params_base.n_parallel, common_speculative_stage_chain_to_str(params).c_str()); + params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP); + } + + if (!params.needs_dft_model()) { + params.clear_dft(); + } + + params_base.has_mtp = params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); +} + +bool common_speculative_finalize_startup( + gpt_params & params_base, + const llama_model * model) { + auto & params = params_base.speculative; + + if (!params.needs_dft_model()) { + params.clear_dft(); + } + + if (params.has_dft()) { + LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); + if (!common_speculative_load_draft_model(params, params_base)) { + return false; + } + } + + params_base.has_mtp = params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); + const bool has_external_mtp = params_base.has_mtp && + llama_model_is_gemma4_mtp_assistant(params.model_dft); + + params_base.has_mtp = common_speculative_prepare_mtp_runtime( + params, + params_base, + model, + has_external_mtp); + if (params_base.has_mtp) { + params_base.pooling_type = LLAMA_POOLING_TYPE_NONE; + } + + return true; +} + +bool common_speculative_load_draft_model( + common_params_speculative & params, + const gpt_params & params_base) { + if (!params.has_dft()) { + return true; + } + + gpt_params params_dft; + params_dft.devices = params.devices; + params_dft.model = params.model; + params_dft.main_gpu = params_base.main_gpu; + params_dft.n_gpu_layers = params.n_gpu_layers; + params_dft.rpc_servers = params_base.rpc_servers; + params_dft.cache_type_k = params.cache_type_k.empty() ? params_base.cache_type_k : params.cache_type_k; + params_dft.cache_type_v = params.cache_type_v.empty() ? 
params_base.cache_type_v : params.cache_type_v; + params_dft.flash_attn = params_base.flash_attn; + params_dft.k_cache_hadamard = params_base.k_cache_hadamard; + params_dft.v_cache_hadamard = params_base.v_cache_hadamard; + + if (!params.params.empty()) { + auto [argc, argv] = parse_command_line("llama-server " + params.params); + if (!gpt_params_parse(argc, argv, params_dft)) { + gpt_params_print_usage(argc, argv, params_dft); + free_command_line(argc, argv); + return false; + } + free_command_line(argc, argv); + } + + LOG_INF("%s: loading draft model '%s'\n", __func__, params_dft.model.c_str()); + + if (params_dft.n_ctx == 0) { + params_dft.n_ctx = params.n_ctx; + } + params_dft.n_ctx = params_dft.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx; + params_dft.n_parallel = 1; + params_dft.n_batch = params_dft.n_ctx; + + params.mparams_dft.path = params_dft.model; + + llama_model_params mparams_dft = common_model_params_to_llama(params_dft); + llama_model * loaded_model = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft); + if (loaded_model == nullptr) { + LOG_ERR("%s: failed to load draft model '%s'\n", __func__, params.model.c_str()); + return false; + } + + params.model_dft = loaded_model; + params.cparams_dft = common_context_params_to_llama(params_dft); + return true; +} + +bool common_speculative_prepare_mtp_runtime( + common_params_speculative & params, + const gpt_params & params_base, + const llama_model * model, + bool has_external_mtp) { + if (!params.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) { + return false; + } + + if (llama_model_n_nextn_layer(model) == 0 && !has_external_mtp) { + LOG_WRN("%s: MTP speculative stage requested, but model has 0 NextN layers. 
Removing MTP from the configured stage chain.\n", + __func__); + params.remove_stage_type(COMMON_SPECULATIVE_TYPE_MTP); + if (!params.needs_dft_model()) { + params.clear_dft(); + } + return false; + } + + if (!has_external_mtp) { + gpt_params params_mtp = params_base; + params_mtp.pooling_type = LLAMA_POOLING_TYPE_NONE; + params.cparams_dft = common_context_params_to_llama(params_mtp); + } + + params.cparams_dft.mtp = true; + params.cparams_dft.mtp_op_type = MTP_OP_WARMUP; + params.cparams_dft.embeddings = true; + + return true; +} + void common_speculative_free(common_speculative * spec) { if (spec == nullptr) { return; @@ -1353,6 +1546,11 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr return; } + spec->checkpoint.clear(); + spec->curr_impl = nullptr; + spec->last_n_drafted = 0; + spec->t_step_start_us = 0; + for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); impl->begin(prompt); @@ -1456,6 +1654,34 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { } } +bool common_speculative_before_draft( + common_speculative * spec, + llama_model * model, + llama_context * ctx, + common_sampler * sampler_src, + const common_params_sampling & sparams, + llama_seq_id seq_id, + llama_pos n_past, + llama_token sampled, + int max_tokens, + int ckpt_mode) { + if (spec == nullptr) { + return false; + } + + return common_speculative_checkpoint_save( + spec->checkpoint, + model, + ctx, + sampler_src, + sparams, + seq_id, + n_past, + sampled, + max_tokens, + ckpt_mode); +} + static bool common_speculative_has_type(const common_speculative * spec, common_speculative_type type) { if (spec == nullptr) { return false; @@ -1663,6 +1889,38 @@ bool common_speculative_ensure_sequence_hidden( return common_speculative_capture_output_hidden(spec, ctx, -1, seq_id, pos); } +common_speculative_draft_result common_speculative_draft_ex( + common_speculative * spec, + llama_context * ctx, + 
common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_pos draft_base_pos, + llama_seq_id draft_seq_id) { + common_speculative_draft_result result = {}; + + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) { + if (!common_speculative_ensure_sequence_hidden(spec, ctx, draft_seq_id, draft_base_pos - 1)) { + LOG_ERR("%s: seq_id=%d MTP hidden state is empty during speculation\n", + __func__, (int) draft_seq_id); + return result; + } + } + + result.tokens = common_speculative_draft( + spec, + params, + prompt_tgt, + id_last, + draft_base_pos, + draft_seq_id); + result.type = spec != nullptr && spec->curr_impl != nullptr + ? spec->curr_impl->type + : COMMON_SPECULATIVE_TYPE_NONE; + + return result; +} + int32_t common_speculative_on_target_seq_batch( common_speculative * spec, llama_context * ctx_tgt, @@ -1834,6 +2092,234 @@ bool common_speculative_commit_accepted_output( hidden_rows); } +static bool common_speculative_checkpoint_save( + common_speculative_checkpoint & ckpt, + llama_model * model, + llama_context * ctx, + common_sampler * sampler_src, + const common_params_sampling & sparams, + llama_seq_id seq_id, + llama_pos n_past, + llama_token sampled, + int max_tokens, + int ckpt_mode) { + ckpt.clear(); + ckpt.n_past = n_past; + ckpt.sampled = sampled; + + const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens); + if (actual_mode == LLAMA_SPEC_CKPT_NONE) { + return false; + } + ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP); + + ckpt.valid = llama_spec_ckpt_save(ctx, seq_id); + if (!ckpt.valid) { + llama_spec_ckpt_discard(ctx); + return false; + } + + ckpt.sampler = common_sampler_init(model, sparams); + if (ckpt.sampler == nullptr) { + common_speculative_checkpoint_discard(ckpt, ctx); + return false; + } + + if (sampler_src != nullptr) { + common_sampler_clone(sampler_src, ckpt.sampler); + } + + return true; +} + +const common_speculative_checkpoint * 
common_speculative_get_checkpoint(const common_speculative * spec) { + return spec != nullptr ? &spec->checkpoint : nullptr; +} + +void common_speculative_checkpoint_discard( + common_speculative_checkpoint & ckpt, + llama_context * ctx) { + ckpt.clear(); + llama_spec_ckpt_discard(ctx); +} + +void common_speculative_checkpoint_restore( + common_speculative_checkpoint & ckpt, + common_speculative * spec, + llama_context * ctx, + common_sampler * sampler_dst, + llama_seq_id seq_id, + common_speculative_type spec_type_used, + llama_token sampled_before, + const std::vector & ids, + int n_draft, + const std::vector & mtp_hidden_state_pre, + int32_t mtp_n_past_base) { + if (!ckpt.valid) { + return; + } + + if (ckpt.per_step_enabled) { + const int step = (int) ids.size() - 1; + llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, step); + + if (ckpt.sampler != nullptr && sampler_dst != nullptr) { + common_sampler_clone(ckpt.sampler, sampler_dst); + } + if (sampler_dst != nullptr) { + for (llama_token id : ids) { + common_sampler_accept(sampler_dst, ctx, id, true); + } + } + + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !mtp_hidden_state_pre.empty()) { + if (!common_speculative_commit_accepted_hidden_rows( + spec, + spec_type_used, + seq_id, + mtp_n_past_base, + sampled_before, + ids, + mtp_hidden_state_pre)) { + common_speculative_clear_sequence_hidden(spec, seq_id); + } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) { + LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows after per-step restore\n", + __func__, (int) seq_id); + } + } + + LOG_DBG("%s: seq_id=%d per-step restore: step=%d (rejected %d drafts)\n", + __func__, (int) seq_id, step, (int) (n_draft - (ids.size() - 1))); + } else { + llama_spec_ckpt_restore(ctx, seq_id, ckpt.n_past, 0); + + if (ckpt.sampler != nullptr && sampler_dst != nullptr) { + common_sampler_clone(ckpt.sampler, sampler_dst); + } + + if (!ids.empty()) { + const int n_re = (int) 
ids.size(); + llama_batch re_batch = llama_batch_init(n_re, 0, 1); + common_batch_add(re_batch, ckpt.sampled, ckpt.n_past, { seq_id }, n_re == 1); + for (int j = 0; j < n_re - 1; ++j) { + common_batch_add(re_batch, ids[j], ckpt.n_past + 1 + j, { seq_id }, j == n_re - 2); + } + + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) { + for (int j = 0; j < re_batch.n_tokens; ++j) { + re_batch.logits[j] = true; + } + llama_set_embeddings(ctx, true); + } + + const int ret = llama_decode(ctx, re_batch); + if (ret != 0) { + LOG_ERR("%s: seq_id=%d failed to re-decode accepted tokens after checkpoint restore: %d\n", + __func__, (int) seq_id, ret); + } + + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP)) { + std::vector redecoded_indices(n_re); + for (int j = 0; j < n_re; ++j) { + redecoded_indices[j] = j; + } + + if (!common_speculative_commit_accepted_output( + spec, + ctx, + spec_type_used, + seq_id, + ckpt.n_past, + sampled_before, + ids, + redecoded_indices)) { + common_speculative_clear_sequence_hidden(spec, seq_id); + } + } + + if (sampler_dst != nullptr) { + for (llama_token id : ids) { + common_sampler_accept(sampler_dst, ctx, id, true); + } + } + + llama_batch_free(re_batch); + LOG_DBG("%s: seq_id=%d spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n", + __func__, (int) seq_id, n_re, (int) (n_draft - (ids.size() - 1))); + } + } + + common_speculative_checkpoint_discard(ckpt, ctx); +} + +void common_speculative_commit( + common_speculative * spec, + llama_context * ctx, + common_sampler * sampler_dst, + llama_seq_id seq_id, + llama_token sampled_before, + const std::vector & ids, + int n_draft, + llama_pos pos_base, + const std::vector & accepted_output_indices) { + GGML_ASSERT(spec != nullptr); + GGML_ASSERT(!ids.empty()); + + common_speculative_checkpoint & ckpt = spec->checkpoint; + const common_speculative_type spec_type_used = spec->curr_impl != nullptr + ? 
spec->curr_impl->type + : COMMON_SPECULATIVE_TYPE_NONE; + const bool any_rejected = (int) ids.size() - 1 < n_draft; + std::vector mtp_hidden_state_pre; + + common_speculative_accept(spec, ids.size() - 1); + + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && + any_rejected && + ckpt.valid && + !accepted_output_indices.empty()) { + if (!common_speculative_copy_output_hidden_rows(spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) { + mtp_hidden_state_pre.clear(); + } + } + + if (any_rejected && ckpt.valid) { + common_speculative_checkpoint_restore( + ckpt, + spec, + ctx, + sampler_dst, + seq_id, + spec_type_used, + sampled_before, + ids, + n_draft, + mtp_hidden_state_pre, + pos_base); + return; + } + + if (common_speculative_has_type(spec, COMMON_SPECULATIVE_TYPE_MTP) && !accepted_output_indices.empty()) { + if (!common_speculative_commit_accepted_output( + spec, + ctx, + spec_type_used, + seq_id, + pos_base, + sampled_before, + ids, + accepted_output_indices)) { + common_speculative_clear_sequence_hidden(spec, seq_id); + } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) { + LOG_DBG("%s: seq_id=%d synced MTP target hidden state from accepted-prefix rows\n", + __func__, (int) seq_id); + } + } + + llama_kv_cache_seq_rm(ctx, seq_id, pos_base + (llama_pos) (ids.size() - 1), -1); + common_speculative_checkpoint_discard(ckpt, ctx); +} + void common_speculative_print_stats(const common_speculative * spec, double slot_tps, int n_decoded, int n_past, common_params_speculative * active_params) { if (spec == nullptr) { return; @@ -1980,6 +2466,50 @@ void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_s mtp_clear_target_hidden(*mtp_state, seq_id); } +void common_speculative_clear_sequence( + common_speculative * spec, + llama_seq_id seq_id, + bool clear_companion_ctx) { + if (spec != nullptr) { + spec->checkpoint.clear(); + spec->curr_impl = nullptr; + spec->last_n_drafted = 0; + spec->t_step_start_us = 0; + } + + 
common_speculative_clear_sequence_hidden(spec, seq_id); + + if (clear_companion_ctx) { + if (auto * ctx_mtp = common_speculative_get_companion_ctx(spec); ctx_mtp != nullptr) { + llama_kv_cache_clear(ctx_mtp); + } + } +} + +bool common_speculative_trim_sequence( + common_speculative * spec, + llama_context * ctx, + llama_seq_id seq_id, + llama_pos pos_begin) { + const bool target_trimmed = llama_kv_cache_seq_rm(ctx, seq_id, pos_begin, -1); + if (auto * ctx_mtp = common_speculative_get_companion_ctx(spec); ctx_mtp != nullptr) { + return target_trimmed && llama_kv_cache_seq_rm(ctx_mtp, seq_id, pos_begin, -1); + } + + return target_trimmed; +} + +void common_speculative_clear_sequence_kv( + common_speculative * spec, + llama_context * ctx, + llama_seq_id seq_id) { + common_speculative_clear_sequence(spec, seq_id); + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); + if (auto * ctx_mtp = common_speculative_get_companion_ctx(spec); ctx_mtp != nullptr) { + llama_kv_cache_seq_rm(ctx_mtp, seq_id, -1, -1); + } +} + llama_context * common_speculative_get_companion_ctx(common_speculative * spec) { if (auto * mtp_state = common_speculative_get_mtp_state(spec); mtp_state != nullptr) { return mtp_state->ctx_mtp; @@ -2123,25 +2653,29 @@ std::vector mtp_speculative_gen_draft( const int n_embd = llama_mtp_state_n_embd(ctx); auto & last = mtp_get_last_embd(state, seq_id); - int i0 = 0; + bool use_cached_hidden = false; if (last.last_id >= 0) { - if (last.prob < p_min) { + if (last.last_id != id_last) { + LOG_DBG("%s: seq_id=%d dropping stale cached MTP token: cached=%d current=%d\n", + __func__, (int) seq_id, last.last_id, id_last); + last.last_id = -1; + last.prob = 0.0f; + } else if (last.prob < p_min) { n_draft = 1; + use_cached_hidden = true; + } else { + use_cached_hidden = true; } - current_input_id = last.last_id; last.last_id = -1; - drafts.push_back(current_input_id); - current_n_past++; - if (!llama_set_draft_input_hidden_state_copy(ctx, last.embd.data(), last.embd.size())) { + 
} + if (use_cached_hidden && !llama_set_draft_input_hidden_state_copy(ctx, last.embd.data(), last.embd.size())) { llama_batch_free(mtp_batch); llama_set_mtp_op_type(ctx, MTP_OP_NONE); return drafts; - } - i0 = 1; } int n_decode = 0; - for (int i = i0; i < n_draft; ++i) { + for (int i = 0; i < n_draft; ++i) { mtp_batch.n_tokens = 0; const llama_pos draft_pos = constant_draft_positions ? n_past : current_n_past; common_batch_add(mtp_batch, current_input_id, draft_pos, {seq_id}, true); @@ -2180,17 +2714,10 @@ std::vector mtp_speculative_gen_draft( llama_batch_free(mtp_batch); llama_set_mtp_op_type(ctx, MTP_OP_NONE); - // Purge the metadata for the draft tokens. - // This prevents cache state corruption where two cells map to the same logical position. - // If the state contained in `last` had a valid token id and probability, it means that we - // have previously run an "accept" batch, where the token sampled from the main model was included. - // In that case, we need to discard all tokens that we ran here to get the KV cache to the correct state. - // => for i0 = 1 we discard from n_past - // But if we did not have a valid last token_id, it means the first token we run was sampled from the - // main model. Hence we want to keep this token in the KV cache and discard all other tokens. - // => for i0 = 0 we discard from n_past + 1 + // Keep `id_last` in the draft KV cache and discard any speculative tail beyond it so the + // next accept/update pass always starts from the same committed sequence prefix. 
if (n_decode > 0) { - llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1 - i0, n_past + n_decode + 2); + llama_kv_cache_seq_rm(ctx, seq_id, n_past + 1, -1); } return drafts; diff --git a/common/speculative.h b/common/speculative.h index 06d4b5809..da740c0b6 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -7,6 +7,14 @@ struct common_speculative; +enum common_speculative_init_status { + COMMON_SPECULATIVE_INIT_SKIPPED, + COMMON_SPECULATIVE_INIT_READY, + COMMON_SPECULATIVE_INIT_ERR_RECURRENT, + COMMON_SPECULATIVE_INIT_ERR_MTP, + COMMON_SPECULATIVE_INIT_ERR_GENERIC, +}; + using common_speculative_feature_kind = llama_spec_feature_kind; using common_speculative_feature_row_view = llama_spec_feature_row_view; using common_speculative_feature_view = llama_spec_feature_view; @@ -14,6 +22,21 @@ using common_speculative_feature_view = llama_spec_feature_view; static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_NONE = LLAMA_SPEC_FEATURE_NONE; static constexpr common_speculative_feature_kind COMMON_SPECULATIVE_FEATURE_HIDDEN_STATE = LLAMA_SPEC_FEATURE_HIDDEN_STATE; +struct common_speculative_checkpoint { + bool valid = false; + bool per_step_enabled = false; + llama_pos n_past = 0; + llama_token sampled = LLAMA_TOKEN_NULL; + common_sampler * sampler = nullptr; + + void clear(); +}; + +struct common_speculative_draft_result { + llama_tokens tokens; + common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; +}; + // comma separated list of all types std::string common_speculative_type_name_str(); @@ -31,6 +54,29 @@ common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt); +common_speculative_init_status common_speculative_try_init( + common_params_speculative & params, + llama_context * ctx_tgt, + common_speculative ** out_spec); + +void common_speculative_prepare_startup( + gpt_params & params_base, + bool allow_parallel_mtp = true); + +bool common_speculative_finalize_startup( + gpt_params 
& params_base, + const llama_model * model); + +bool common_speculative_load_draft_model( + common_params_speculative & params, + const gpt_params & params_base); + +bool common_speculative_prepare_mtp_runtime( + common_params_speculative & params, + const gpt_params & params_base, + const llama_model * model, + bool has_external_mtp); + void common_speculative_free(common_speculative * spec); // optionally call once at the beginning of a new generation @@ -46,9 +92,30 @@ llama_tokens common_speculative_draft( llama_pos draft_base_pos = -1, llama_seq_id draft_seq_id = 0); +common_speculative_draft_result common_speculative_draft_ex( + common_speculative * spec, + llama_context * ctx, + common_params_speculative & params, + const llama_tokens & prompt, + llama_token id_last, + llama_pos draft_base_pos = -1, + llama_seq_id draft_seq_id = 0); + // informs the speculative decoder that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +bool common_speculative_before_draft( + common_speculative * spec, + llama_model * model, + llama_context * ctx, + common_sampler * sampler_src, + const common_params_sampling & sparams, + llama_seq_id seq_id, + llama_pos n_past, + llama_token sampled, + int max_tokens, + int ckpt_mode); + bool common_speculative_ensure_sequence_hidden( common_speculative * spec, llama_context * ctx, @@ -87,10 +154,56 @@ bool common_speculative_commit_accepted_output( const std::vector & ids, const std::vector & output_indices); +const common_speculative_checkpoint * common_speculative_get_checkpoint(const common_speculative * spec); + +void common_speculative_checkpoint_discard( + common_speculative_checkpoint & ckpt, + llama_context * ctx); + +void common_speculative_checkpoint_restore( + common_speculative_checkpoint & ckpt, + common_speculative * spec, + llama_context * ctx, + common_sampler * sampler_dst, + llama_seq_id seq_id, + common_speculative_type spec_type_used, 
+ llama_token sampled_before, + const std::vector & ids, + int n_draft, + const std::vector & mtp_hidden_state_pre, + int32_t mtp_n_past_base); + +void common_speculative_commit( + common_speculative * spec, + llama_context * ctx, + common_sampler * sampler_dst, + llama_seq_id seq_id, + llama_token sampled_before, + const std::vector & ids, + int n_draft, + llama_pos pos_base, + const std::vector & accepted_output_indices); + bool common_speculative_has_sequence_hidden(const common_speculative * spec, llama_seq_id seq_id); void common_speculative_clear_sequence_hidden(common_speculative * spec, llama_seq_id seq_id); +void common_speculative_clear_sequence( + common_speculative * spec, + llama_seq_id seq_id, + bool clear_companion_ctx = false); + +bool common_speculative_trim_sequence( + common_speculative * spec, + llama_context * ctx, + llama_seq_id seq_id, + llama_pos pos_begin); + +void common_speculative_clear_sequence_kv( + common_speculative * spec, + llama_context * ctx, + llama_seq_id seq_id); + llama_context * common_speculative_get_companion_ctx(common_speculative * spec); int32_t common_speculative_on_target_seq_batch( diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 4b1499f1f..7e2595217 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -45,11 +45,6 @@ static void log_text(const gpt_params & params_base, const std::string & text) { } } -static bool params_use_gemma4_external_mtp(const gpt_params & params_base) { - return params_base.has_mtp && - llama_model_is_gemma4_mtp_assistant(params_base.speculative.model_dft); -} - struct server_mtp_warmup { llama_context * ctx_tgt; server_slot * slot; @@ -72,72 +67,6 @@ static bool server_response_needs_chat_parse(oaicompat_type oaicompat) { oaicompat == OAICOMPAT_TYPE_RESP; } -void server_speculative_checkpoint::clear() { - valid = false; - per_step_enabled = false; - n_past = 0; - sampled = LLAMA_TOKEN_NULL; - - if (sampler != 
nullptr) { - common_sampler_free(sampler); - sampler = nullptr; - } -} - -static void discard_speculative_checkpoint(server_slot & slot, llama_context * ctx) { - slot.spec_ckpt.clear(); - llama_spec_ckpt_discard(ctx); -} - -static bool save_speculative_checkpoint(server_slot & slot, llama_model * model, llama_context * ctx, int ckpt_mode) { - slot.spec_ckpt.clear(); - const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1); - slot.spec_ckpt.n_past = slot.cache_tokens.pos_next(n_pre_spec_tokens); - slot.spec_ckpt.sampled = slot.sampled; - - const int max_tokens = (int)slot.drafted.size() + 1; - const int actual_mode = llama_spec_ckpt_init(ctx, ckpt_mode, max_tokens); - if (actual_mode == LLAMA_SPEC_CKPT_NONE) { - return false; - } - slot.spec_ckpt.per_step_enabled = (actual_mode == LLAMA_SPEC_CKPT_PER_STEP); - - slot.spec_ckpt.valid = llama_spec_ckpt_save(ctx, slot.id); - if (!slot.spec_ckpt.valid) { - llama_spec_ckpt_discard(ctx); - return false; - } - - slot.spec_ckpt.sampler = common_sampler_init(model, slot.sparams); - if (slot.spec_ckpt.sampler == nullptr) { - discard_speculative_checkpoint(slot, ctx); - return false; - } - - common_sampler_clone(slot.ctx_sampling, slot.spec_ckpt.sampler); - return true; -} - -static void server_remove_speculative_stage(common_params_speculative & spec, common_speculative_type type) { - spec.stages.erase(std::remove_if(spec.stages.begin(), spec.stages.end(), [type](const common_speculative_stage_params & stage) { - return stage.type == type; - }), spec.stages.end()); - - if (spec.type == type) { - spec.type = COMMON_SPECULATIVE_TYPE_NONE; - const auto resolved = spec.get_resolved_stages(); - spec.type = resolved.empty() ? 
COMMON_SPECULATIVE_TYPE_NONE : resolved.front().type; - } -} - -static bool server_speculative_needs_draft_model(const common_params_speculative & spec) { - return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_DRAFT); -} - -static bool server_speculative_has_mtp(const common_params_speculative & spec) { - return spec.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); -} - static bool server_speculative_same_stage_types( const common_params_speculative & lhs, const common_params_speculative & rhs) { @@ -233,29 +162,17 @@ server_context::~server_context() { } // Free multimodal mtmd_free(mctx); - // Free draft model and context if they exist - if (ctx_draft) { - llama_free(ctx_draft); - ctx_draft = nullptr; - } - if (model_draft) { - llama_free_model(model_draft); - model_draft = nullptr; - } // Clear any sampling context for (server_slot& slot : slots) { if (slot.ctx_sampling != nullptr) { common_sampler_free(slot.ctx_sampling); } - slot.spec_ckpt.clear(); - if (slot.ctx_dft) { - llama_free(slot.ctx_dft); - } common_speculative_free(slot.spec); - llama_batch_free(slot.batch_spec); } + params_base.speculative.clear_dft(); + llama_batch_free(batch); } @@ -278,24 +195,10 @@ bool server_context::load_model(const gpt_params& params_) { add_bos_token = llama_should_add_bos_token(model); has_eos_token = llama_add_eos_token(model) != 1; - if (params_base.n_parallel > 1 && server_speculative_has_mtp(params_base.speculative)) { - LOG_WARNING("MTP is not supported with parallel slots yet, removing the MTP stage to avoid cross-slot corruption.\n", { - {"n_parallel", params_base.n_parallel}, - {"stage_chain", common_speculative_stage_chain_to_str(params_base.speculative)}, - }); - - params_base.has_mtp = false; - server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP); + common_speculative_prepare_startup(params_base, false); - if (!server_speculative_needs_draft_model(params_base.speculative)) { - params_base.speculative.model.clear(); - 
params_base.speculative.params.clear(); - params_base.speculative.model_dft = nullptr; - } - } - - bool has_draft_model = !params_base.speculative.model.empty() || !params_base.speculative.params.empty(); - std::string& mmproj_path = params_base.mmproj.path; + const bool has_draft_model = params_base.speculative.has_dft(); + std::string & mmproj_path = params_base.mmproj.path; if (!mmproj_path.empty()) { mtmd_context_params mparams = mtmd_context_params_default(); mparams.use_gpu = params_base.mmproj_use_gpu; @@ -309,10 +212,10 @@ bool server_context::load_model(const gpt_params& params_) { mparams.image_max_tokens = params_base.image_max_tokens; mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); if (mctx == nullptr) { - LOG_ERROR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + LOG_ERROR("failed to load multimodal model, %s\n", mmproj_path.c_str()); return false; } - LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + LOG_INFO("loaded multimodal model, %s\n", mmproj_path.c_str()); //if (params.n_cache_reuse) { // params_base.n_cache_reuse = 0; @@ -323,71 +226,22 @@ bool server_context::load_model(const gpt_params& params_) { LOG_ERROR("%s\n", "err: speculative decode is not supported by multimodal"); return false; } - const auto spec_stages = params_base.speculative.get_resolved_stages(); - const bool multimodal_spec_supported = spec_stages.empty() || - (spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP); - if (!multimodal_spec_supported) { + + const auto spec_stages = params_base.speculative.get_resolved_stages(); + const bool multimodal_spec_supported = spec_stages.empty() || + (spec_stages.size() == 1 && spec_stages.front().type == COMMON_SPECULATIVE_TYPE_MTP); + if (!multimodal_spec_supported) { params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; params_base.speculative.stages.clear(); params_base.has_mtp = false; SRV_WRN("%s\n", "speculative decoding is not supported by 
multimodal, it will be disabled"); } } - // Load draft model for speculative decoding if specified - if (has_draft_model) { - LLAMA_LOG_INFO("\n\n==================================loading DRAFT model==================================\n\n"); - - gpt_params params_dft; - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.main_gpu = params_base.main_gpu; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; - params_dft.rpc_servers = params_base.rpc_servers; - params_dft.cache_type_k = params_base.speculative.cache_type_k.empty() ? params_base.cache_type_k : params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v.empty() ? params_base.cache_type_v : params_base.speculative.cache_type_v; - params_dft.flash_attn = params_base.flash_attn; - params_dft.k_cache_hadamard = params_base.k_cache_hadamard; - params_dft.v_cache_hadamard = params_base.v_cache_hadamard; - if (!params_base.speculative.params.empty()) { - auto [argc, argv] = parse_command_line("llama-server " + params_base.speculative.params); - if (!gpt_params_parse(argc, argv, params_dft)) { - gpt_params_print_usage(argc, argv, params_dft); - free_command_line(argc, argv); - return false; - }; - free_command_line(argc, argv); - } - LOG_INFO("", { {"model", params_dft.model} }); - if (params_dft.n_ctx == 0) { - params_dft.n_ctx = params_base.speculative.n_ctx; - } - params_dft.n_ctx = params_dft.n_ctx == 0 ? 
params_base.n_ctx / params_base.n_parallel : params_dft.n_ctx; - params_dft.n_parallel = 1; - params_dft.n_batch = params_dft.n_ctx; - - params_base.speculative.mparams_dft.path = params_dft.model; // - - llama_model_params mparams_dft = common_model_params_to_llama(params_dft); - - llama_model * model_dft = llama_model_load_from_file(params_dft.model.c_str(), mparams_dft); - if (model_dft == nullptr) { - LOG_ERROR("failed to load draft model", { {"model", params_base.speculative.model} }); - return false; - } - - cparams_dft = common_context_params_to_llama(params_dft); - params_base.speculative.model_dft = model_dft; - params_base.speculative.cparams_dft = cparams_dft; - - } - if (server_speculative_has_mtp(params_base.speculative) && - llama_model_n_nextn_layer(model) == 0 && - !params_use_gemma4_external_mtp(params_base)) { - LOG_WARNING("WARNING: MTP speculative stage requested, but model has 0 NextN layers. MTP will be disabled.\n", {}); - params_base.has_mtp = false; - server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP); + if (!common_speculative_finalize_startup(params_base, model)) { + return false; } + return true; } @@ -396,6 +250,20 @@ void server_context::init() { LOG_INFO("initializing slots", { {"n_slots", params_base.n_parallel} }); + if (params_base.has_mtp) { + SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); + llama_set_embeddings(ctx, true); + } + + const bool requested_spec = params_base.speculative.has_stage_chain(); + bool can_spec = true; + if (!params_base.dry_run) { + can_spec = common_speculative_is_compat(ctx); + } + if (!can_spec && requested_spec) { + SRV_WRN("%s", "speculative decoding not supported by this context\n"); + } + for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; @@ -440,68 +308,27 @@ void server_context::init() { slot.params.speculative = params_base.speculative; slot.sparams = params_base.sparams; - const bool wants_mtp_stage = 
server_speculative_has_mtp(params_base.speculative); - if (wants_mtp_stage) { - const bool has_external_mtp = params_use_gemma4_external_mtp(params_base); - - if (llama_model_n_nextn_layer(model) > 0 || has_external_mtp) { - params_base.pooling_type = LLAMA_POOLING_TYPE_NONE; - - if (!has_external_mtp) { - params_base.speculative.cparams_dft = common_context_params_to_llama(params_base); - } - - params_base.speculative.cparams_dft.mtp = true; - params_base.speculative.cparams_dft.mtp_op_type = MTP_OP_WARMUP; - params_base.speculative.cparams_dft.embeddings = true; - - slot.has_mtp = true; - slot.params.speculative.cparams_dft = params_base.speculative.cparams_dft; - - slot.batch_spec = llama_batch_init(slot.params.speculative.get_max_stage_n_max() + 1, 0, 1); - SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); - - SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); - llama_set_embeddings(ctx, true); - } - else { - SRV_WRN("%s\n", "MTP speculative stage requested, but model has 0 NextN layers. 
Removing MTP from the configured stage chain."); - params_base.has_mtp = false; - server_remove_speculative_stage(params_base.speculative, COMMON_SPECULATIVE_TYPE_MTP); - slot.params.speculative = params_base.speculative; - slot.has_mtp = false; - } - } - - const bool requested_spec = !params_base.speculative.get_resolved_stages().empty(); - - bool can_spec = true; - if (!params_base.dry_run) { - can_spec = common_speculative_is_compat(ctx); - } - if (!can_spec) { - SRV_WRN("%s", "speculative decoding not supported by this context\n"); - } // try speculative decoding if (can_spec && requested_spec) { - slot.spec = common_speculative_init(params_base.speculative, slot.ctx); - if (slot.spec) { - if (mctx && !slot.has_mtp) { + switch (common_speculative_try_init(params_base.speculative, slot.ctx, &slot.spec)) { + case COMMON_SPECULATIVE_INIT_READY: + if (mctx && !slot.uses_mtp()) { SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); return; } SLT_INF(slot, "%s", "speculative decoding context initialized\n"); - } else { - if (llama_model_has_recurrent(model)) { - SRV_ERR("%s", "failed to initialize recurrent speculative context\n"); - throw std::runtime_error("recurrent speculative context initialization failed"); - } else if (slot.has_mtp) { - SRV_ERR("%s", "failed to initialize MTP speculative context\n"); - throw std::runtime_error("MTP speculative context initialization failed"); - } else { - SRV_ERR("%s", "failed to initialize speculative decoding context\n"); - throw std::runtime_error("speculative decoding context initialization failed"); - } + break; + case COMMON_SPECULATIVE_INIT_ERR_RECURRENT: + SRV_ERR("%s", "failed to initialize recurrent speculative context\n"); + throw std::runtime_error("recurrent speculative context initialization failed"); + case COMMON_SPECULATIVE_INIT_ERR_MTP: + SRV_ERR("%s", "failed to initialize MTP speculative context\n"); + throw std::runtime_error("MTP speculative context initialization failed"); + case 
COMMON_SPECULATIVE_INIT_ERR_GENERIC: + SRV_ERR("%s", "failed to initialize speculative decoding context\n"); + throw std::runtime_error("speculative decoding context initialization failed"); + case COMMON_SPECULATIVE_INIT_SKIPPED: + break; } } @@ -620,9 +447,7 @@ void server_slot::reset() { n_kept_prompt = 0; n_sent_text = 0; drafted.clear(); - drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; i_batch_dft.clear(); - spec_ckpt.clear(); n_sent_token_probs = 0; infill = false; ga_i = 0; @@ -640,7 +465,7 @@ void server_slot::reset() { image_just_processed = false; do_checkpoint = false; if (spec != nullptr) { - common_speculative_clear_sequence_hidden(spec, id); + common_speculative_clear_sequence(spec, id); } positional_bans.clear(); @@ -675,7 +500,11 @@ void server_slot::reset() { } bool server_slot::need_embd() const { - return embedding || has_mtp; + return embedding || uses_mtp(); +} + +bool server_slot::uses_mtp() const { + return params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP); } bool server_slot::has_budget(gpt_params& global_params) { @@ -711,7 +540,7 @@ void server_slot::add_token_string(const completion_token_output& token) { } bool server_slot::can_speculate() const { - return (!!spec || has_mtp); + return (!!spec || uses_mtp()); } int server_slot::get_n_draft_max() const { @@ -1327,7 +1156,7 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) throw std::runtime_error("Error: per-request speculative stages must match the server startup stage types; only stage parameter overrides are supported"); } - if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !slot.has_mtp) { + if (slot.params.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP) && !params_base.has_mtp) { throw std::runtime_error("Error: MTP speculative stage requested, but the server was not started with MTP support"); } @@ -2107,10 +1936,7 @@ void server_context::kv_cache_clear() { continue; } - 
common_speculative_clear_sequence_hidden(slot.spec, slot.id); - if (auto * ctx_companion = common_speculative_get_companion_ctx(slot.spec); ctx_companion != nullptr) { - llama_kv_cache_clear(ctx_companion); - } + common_speculative_clear_sequence(slot.spec, slot.id, true); } clean_kv_cache = false; } @@ -3359,7 +3185,7 @@ void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_sl const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(kv_keep), slot.cache_tokens.pos_next(kv_keep + kv_discard)); llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard); - if (slot.has_mtp && slot.spec) { + if (slot.uses_mtp() && slot.spec) { common_speculative_context_shift(slot.spec, slot.id, kv_keep, kv_discard, kv_past); } if (slot.params.cache_prompt) { @@ -3568,33 +3394,27 @@ void server_context::add_sampled_tokens() { // perform the speculative drafting for all sequences at the same time in a single batch const int n_draft_max_pre = slot.get_n_draft_max(); if (n_draft_max_pre > 0) { - if (mctx && !slot.has_mtp) { + if (mctx && !slot.uses_mtp()) { // we should never reach this, as speculative is automatically disabled if mmproj is loaded GGML_ABORT("not supported by multimodal"); } static const llama_tokens empty_prompt; - const llama_tokens & cached_text_tokens = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain() + const llama_tokens & cached_text_tokens = slot.uses_mtp() && !slot.params.speculative.has_composite_stage_chain() ? empty_prompt : slot.cache_tokens.get_text_tokens(); auto & params_spec = slot.params.speculative; - const llama_pos draft_base_pos = slot.has_mtp ? 
slot.cache_tokens.pos_next() : -1; - - if (slot.has_mtp) { - if (!common_speculative_ensure_sequence_hidden(slot.spec, ctx, slot.id, draft_base_pos - 1)) { - LOG_ERROR("MTP hidden state is empty during speculation", {}); - } - } - - llama_tokens draft = common_speculative_draft( + const llama_pos draft_base_pos = slot.uses_mtp() ? slot.cache_tokens.pos_next() : -1; + common_speculative_draft_result draft_result = common_speculative_draft_ex( slot.spec, + ctx, params_spec, cached_text_tokens, slot.sampled, draft_base_pos, slot.id); - slot.drafted_spec_type = common_speculative_current_type(slot.spec); + llama_tokens & draft = draft_result.tokens; const int n_draft_max = slot.get_n_draft_max(); @@ -3619,7 +3439,6 @@ void server_context::add_sampled_tokens() { // fallback to normal decoding slot.i_batch = slot.i_batch_dft[0]; slot.drafted.clear(); - slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; slot.i_batch_dft.clear(); } else { // keep track of total number of drafted tokens tested @@ -3636,7 +3455,6 @@ void server_context::add_sampled_tokens() { } else { // no speculative decoding - slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; slot.i_batch = batch.n_tokens; common_batch_add(batch, slot.sampled, slot.cache_tokens.pos_next(), { slot.id }, true); @@ -3976,15 +3794,10 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.cache_tokens.keep_first(slot.n_past); int p0 = (int)system_tokens.size() + slot.n_past; p0 = system_tokens.size() + slot.cache_tokens.pos_next(); - auto * ctx_companion = slot.spec ? 
common_speculative_get_companion_ctx(slot.spec) : nullptr; - const bool target_trimmed = llama_kv_cache_seq_rm(ctx, slot.id, p0, -1); - const bool companion_trimmed = ctx_companion == nullptr || llama_kv_cache_seq_rm(ctx_companion, slot.id, p0, -1); - if (!target_trimmed || !companion_trimmed) { + const bool trimmed = common_speculative_trim_sequence(slot.spec, ctx, slot.id, p0); + if (!trimmed) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); - if (ctx_companion != nullptr) { - llama_kv_cache_seq_rm(ctx_companion, slot.id, -1, -1); - } + common_speculative_clear_sequence_kv(slot.spec, ctx, slot.id); p0 = (int)system_tokens.size(); if (p0 != 0) { @@ -4021,7 +3834,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t llama_pos p1 = slot.cache_tokens.pos_next() + slot.n_past_prompt - slot.n_past; // add offset to prompt server_mtp_warmup mtp_media_warmup { ctx, - slot.has_mtp && slot.spec ? &slot : nullptr, + slot.uses_mtp() && slot.spec ? &slot : nullptr, }; mtmd_helper_eval_batch_callback mtp_media_callback = mtp_media_warmup.slot ? server_mtp_media_warmup_callback : nullptr; @@ -4163,103 +3976,6 @@ void server_context::extend_context(const int32_t n_tokens) { } } -// Restore recurrent state and re-decode accepted tokens after speculative-decode rejection. 
-static void restore_speculative_checkpoint( - server_slot & slot, llama_context * ctx, llama_model * model, - common_speculative_type spec_type_used, - llama_token sampled_before, - const std::vector & ids, int n_draft, - const std::vector & mtp_hidden_state_pre, int32_t mtp_n_past_base) { - if (slot.spec_ckpt.per_step_enabled) { - const int step = (int)ids.size() - 1; - llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, step); - - if (slot.spec_ckpt.sampler) { - common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling); - } - for (llama_token id : ids) { - common_sampler_accept(slot.ctx_sampling, ctx, id, true); - } - - // Update MTP KV cache and hidden state using embeddings collected before checkpoint restore. - if (slot.has_mtp && !mtp_hidden_state_pre.empty()) { - if (!common_speculative_commit_accepted_hidden_rows( - slot.spec, - spec_type_used, - slot.id, - mtp_n_past_base, - sampled_before, - ids, - mtp_hidden_state_pre)) { - common_speculative_clear_sequence_hidden(slot.spec, slot.id); - } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) { - SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows after per-step restore"); - } - } - - SLT_DBG(slot, "per-step restore: step=%d (rejected %d drafts)\n", - step, (int)(n_draft - (ids.size() - 1))); - } else { - // Restore pre-speculation recurrent state then re-decode accepted tokens. - llama_spec_ckpt_restore(ctx, slot.id, slot.spec_ckpt.n_past, 0); - - if (slot.spec_ckpt.sampler) { - common_sampler_clone(slot.spec_ckpt.sampler, slot.ctx_sampling); - } - - if (!ids.empty()) { - // Re-decode to advance recurrent state to the accepted position. 
- const int n_re = (int)ids.size(); - llama_batch re_batch = llama_batch_init(n_re, 0, 1); - common_batch_add(re_batch, slot.spec_ckpt.sampled, slot.spec_ckpt.n_past, { slot.id }, n_re == 1); - for (int j = 0; j < n_re - 1; j++) { - common_batch_add(re_batch, ids[j], slot.spec_ckpt.n_past + 1 + j, { slot.id }, j == n_re - 2); - } - - if (slot.has_mtp) { - for (int j = 0; j < re_batch.n_tokens; j++) { - re_batch.logits[j] = true; - } - llama_set_embeddings(ctx, true); - } - - const int ret = llama_decode(ctx, re_batch); - if (ret != 0) { - SLT_ERR(slot, "failed to re-decode accepted tokens after checkpoint restore: %d\n", ret); - } - if (slot.has_mtp) { - const int n_accepted = (int)ids.size(); - std::vector redecoded_indices(n_accepted); - for (int j = 0; j < n_accepted; ++j) { - redecoded_indices[j] = j; - } - - if (!common_speculative_commit_accepted_output( - slot.spec, - ctx, - spec_type_used, - slot.id, - slot.spec_ckpt.n_past, - sampled_before, - ids, - redecoded_indices)) { - common_speculative_clear_sequence_hidden(slot.spec, slot.id); - } - } - - for (llama_token id : ids) { - common_sampler_accept(slot.ctx_sampling, ctx, id, true); - } - - llama_batch_free(re_batch); - SLT_DBG(slot, "spec checkpoint restored: re-decoded %d tokens (rejected %d drafts)\n", - n_re, (int)(n_draft - (ids.size() - 1))); - } - } - - discard_speculative_checkpoint(slot, ctx); -} - void server_context::speculative_decoding_accept() { for (auto& slot : slots) { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) { @@ -4267,7 +3983,6 @@ void server_context::speculative_decoding_accept() { } const llama_token sampled_before = slot.sampled; - const common_speculative_type spec_type_used = slot.drafted_spec_type; size_t n_draft = slot.drafted.size(); slot.ctx_sampling->to_generated_text = &slot.generated_text; @@ -4297,28 +4012,15 @@ void server_context::speculative_decoding_accept() { continue; } - const bool any_rejected = (ids.size() - 1) < n_draft; - int32_t 
mtp_n_past_base = 0; - std::vector mtp_hidden_state_pre; std::vector accepted_output_indices; - if (slot.has_mtp) { - const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t)(slot.drafted.size() + 1); - mtp_n_past_base = slot.cache_tokens.pos_next(n_pre_spec_tokens); - + if (slot.uses_mtp()) { if (!ids.empty()) { accepted_output_indices.assign(slot.i_batch_dft.begin(), slot.i_batch_dft.begin() + ids.size()); } - - if (any_rejected && slot.spec_ckpt.valid && !accepted_output_indices.empty()) { - if (!common_speculative_copy_output_hidden_rows(slot.spec, ctx, accepted_output_indices, mtp_hidden_state_pre)) { - mtp_hidden_state_pre.clear(); - } - } } slot.i_batch_dft.clear(); slot.drafted.clear(); - slot.drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; slot.n_past += ids.size(); slot.n_decoded += ids.size(); @@ -4328,11 +4030,9 @@ void server_context::speculative_decoding_accept() { // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - // inform the speculative decoding about the number of accepted tokens - common_speculative_accept(slot.spec, ids.size() - 1); - // rollback to the state before sampling the draft tokens slot.cache_tokens.keep_first(slot.cache_tokens.n_tokens() - n_draft); + const llama_pos spec_pos_base = slot.cache_tokens.pos_next(); // add accepted tokens to the prompt for (auto it = ids.begin(); it != ids.end() - 1; ++it) { @@ -4341,28 +4041,16 @@ void server_context::speculative_decoding_accept() { slot.sampled = ids.back(); // last accepted token slot.n_past = slot.cache_tokens.n_tokens(); - // for recurrent/hybrid models: if any drafts were rejected, restore recurrent state - if (any_rejected && slot.spec_ckpt.valid) { - restore_speculative_checkpoint(slot, ctx, model, spec_type_used, sampled_before, ids, n_draft, mtp_hidden_state_pre, mtp_n_past_base); - } else { - if (slot.has_mtp && !accepted_output_indices.empty()) { - if (!common_speculative_commit_accepted_output( - 
slot.spec, - ctx, - spec_type_used, - slot.id, - mtp_n_past_base, - sampled_before, - ids, - accepted_output_indices)) { - common_speculative_clear_sequence_hidden(slot.spec, slot.id); - } else if (spec_type_used != COMMON_SPECULATIVE_TYPE_MTP) { - SLT_DBG(slot, "%s", "synced MTP target hidden state from accepted-prefix rows"); - } - } - llama_kv_cache_seq_rm(ctx, slot.id, slot.cache_tokens.pos_next(slot.n_past), -1); - discard_speculative_checkpoint(slot, ctx); - } + common_speculative_commit( + slot.spec, + ctx, + slot.ctx_sampling, + slot.id, + sampled_before, + ids, + n_draft, + spec_pos_base, + accepted_output_indices); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; @@ -4736,9 +4424,9 @@ void server_context::process_batch_tokens(int32_t & n_batch) { continue; // continue loop of n_batch } - if (server_speculative_has_mtp(params_base.speculative)) { + if (params_base.speculative.has_stage_type(COMMON_SPECULATIVE_TYPE_MTP)) { for (auto & slot : slots) { - if (!slot.spec || !slot.has_mtp) { + if (!slot.spec || !slot.uses_mtp()) { continue; } @@ -4778,7 +4466,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { if (slot.n_decoded == 0 && slot.can_speculate()) { static const llama_tokens empty_prompt; - const llama_tokens & spec_prompt = slot.has_mtp && !slot.params.speculative.has_composite_stage_chain() + const llama_tokens & spec_prompt = slot.uses_mtp() && !slot.params.speculative.has_composite_stage_chain() ? 
empty_prompt : slot.cache_tokens.get_text_tokens(); common_speculative_begin(slot.spec, spec_prompt); @@ -4802,7 +4490,7 @@ void server_context::process_batch_tokens(int32_t & n_batch) { completion_token_output result; const int tok_idx = slot.i_batch - i; - if (slot.has_mtp && slot.n_decoded == 0) { + if (slot.uses_mtp() && slot.n_decoded == 0) { (void) common_speculative_capture_output_hidden(slot.spec, ctx, tok_idx, slot.id, slot.n_past); } @@ -4934,10 +4622,25 @@ void server_context::update_slots() { if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch_dft.empty()) { continue; } - if (save_speculative_checkpoint(slot, model, ctx, ckpt_mode)) { - const char * mode_name = slot.spec_ckpt.per_step_enabled ? "per-step" : "shadow/cpu"; + const int32_t n_pre_spec_tokens = slot.cache_tokens.n_tokens() - (int32_t) (slot.drafted.size() + 1); + const llama_pos n_past_pre_spec = slot.cache_tokens.pos_next(n_pre_spec_tokens); + const int max_tokens = (int) slot.drafted.size() + 1; + if (common_speculative_before_draft( + slot.spec, + model, + ctx, + slot.ctx_sampling, + slot.sparams, + slot.id, + n_past_pre_spec, + slot.sampled, + max_tokens, + ckpt_mode)) { + const common_speculative_checkpoint * ckpt = common_speculative_get_checkpoint(slot.spec); + GGML_ASSERT(ckpt != nullptr); + const char * mode_name = ckpt->per_step_enabled ? 
"per-step" : "shadow/cpu"; SLT_DBG(slot, "spec checkpoint saved (mode=%s), n_past_pre_spec=%d\n", - mode_name, slot.spec_ckpt.n_past); + mode_name, ckpt->n_past); } else { SLT_WRN(slot, "%s", "failed to save spec checkpoint\n"); } diff --git a/examples/server/server-context.h b/examples/server/server-context.h index a33c2113b..1a0d1e4ac 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -22,16 +22,6 @@ enum slot_command { SLOT_COMMAND_RELEASE, }; -struct server_speculative_checkpoint { - bool valid = false; - bool per_step_enabled = false; // per-step SSM checkpoints active - llama_pos n_past = 0; - llama_token sampled = LLAMA_TOKEN_NULL; - common_sampler * sampler = nullptr; // saved sampler state - - void clear(); -}; - struct server_slot { int id; int id_task = -1; @@ -39,9 +29,6 @@ struct server_slot { struct slot_params params; - llama_batch batch_spec = {}; - llama_context * ctx_dft = nullptr; - bool released = false; slot_state state = SLOT_STATE_IDLE; slot_command command = SLOT_COMMAND_NONE; @@ -136,7 +123,6 @@ struct server_slot { // sampling llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; - common_speculative_type drafted_spec_type = COMMON_SPECULATIVE_TYPE_NONE; json json_schema; @@ -171,11 +157,6 @@ struct server_slot { // expiring logit bias std::vector prev_elb_states; - bool has_mtp = false; - - // saves recurrent state before a speculative batch so it can be restored on rejection - server_speculative_checkpoint spec_ckpt; - // speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted @@ -195,6 +176,7 @@ struct server_slot { void reset(); bool need_embd() const; + bool uses_mtp() const; bool has_budget(gpt_params& global_params); @@ -266,11 +248,6 @@ struct server_context { // multimodal mtmd_context* mctx = nullptr; - // For speculative decoding - llama_model* model_draft 
= nullptr; - llama_context* ctx_draft = nullptr; - llama_context_params cparams_dft; - int32_t n_ctx; // total context for all clients / slots // system prompt