diff --git a/common/common.cpp b/common/common.cpp index ffb8d5fdc8..4047aaa8dc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2689,6 +2689,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cache_ram_n_min = std::stoi(argv[i]); return true; } + if (arg == "--cache-ram-reuse-n-min") { + CHECK_ARG + params.cache_ram_reuse_n_min = std::stoi(argv[i]); + return true; + } if (arg == "--pos") { CHECK_ARG params.i_pos = std::stoi(argv[i]); @@ -2887,6 +2892,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib }); options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity }); options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min }); + options.push_back({ "*", "--cache-ram-reuse-n-min N", "minimum reusable common-prefix tokens required to restore from prompt cache (default: %d).", params.cache_ram_reuse_n_min }); options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); diff --git a/common/common.h b/common/common.h index bc68ca0f02..510c2cba53 100644 --- a/common/common.h +++ b/common/common.h @@ -519,6 +519,7 @@ struct gpt_params { int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - 
disable, 1 = 1 MiB, etc. int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram + int32_t cache_ram_reuse_n_min = 0; // min reusable common-prefix tokens required to restore from ram cache float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens // batched-bench params diff --git a/docs/parameters.md b/docs/parameters.md index 77ec4fa230..1758867a9f 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -150,6 +150,11 @@ Note: When the available memory is very limited, turn this option off (`-cram 0` | `-cram, --cache-ram N` | Set the maximum cache size in MiB | 8192 | -1 = no limit, 0 = disable Very useful when the variations of the same prompt are re-sent to the model (coding agents, etc.). [PR 954](https://github.com/ikawrakow/ik_llama.cpp/pull/954) | | `-crs, --cache-ram-similarity N` | Max similarity of prompt tokens to cache tokens that triggers prompt cache | 0.50 | | | `-cram-n-min, --cache-ram-n-min N` | Minimum number of cached tokens that triggers prompt cache | 0 | | +| `--cache-ram-reuse-n-min N` | Minimum reusable common-prefix tokens required to restore from prompt cache | 0 | | + +When restoring a prompt from RAM cache, the server ranks viable candidates by similarity first and uses the reusable common prefix as a tie-breaker. A candidate is skipped if its reusable prefix is below `--cache-ram-reuse-n-min`, if the reusable fraction is below `--cache-ram-similarity`, or if the cached KV range cannot be rewound safely for SWA/sliding-window attention. + +For SWA models, cached prompt entries track the KV position range saved with the state. If the active KV window starts after the reusable prefix, restore is only allowed when an earlier context checkpoint can rewind to that prefix. Otherwise the candidate is rejected before loading state, so a bad RAM-cache hit does not mutate the slot and then force full prompt re-processing. 
When context checkpoints are capped, the server keeps the earliest checkpoint as a rewind anchor when possible. ## Sampling diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index 99345458f1..b5368f9973 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -590,8 +590,8 @@ void server_slot::prompt_save(server_prompt_cache& prompt_cache) const { llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0); } -void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) { - bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id); +void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens, size_t min_reusable_prefix, float min_reusable_fraction) { + bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id, min_reusable_prefix, min_reusable_fraction); if (!res) { LLAMA_LOG_INFO("failed to load prompt from cache\n"); } @@ -1031,6 +1031,8 @@ void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, se slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt; slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt; slot.server_cached_prompt.think_tokens = slot.params.think_tokens; + slot.server_cached_prompt.pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); + slot.server_cached_prompt.pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); } server_slot* server_context::get_available_slot(const server_task& task) { @@ -1146,7 +1148,7 @@ server_slot* server_context::get_available_slot(const server_task& task) { const int64_t t_start = ggml_time_us(); copy_data_to_cached_prompt(tokens, *ret); - ret->prompt_load(*prompt_cache, task.tokens); + ret->prompt_load(*prompt_cache, task.tokens, (size_t) std::max(0, cache_ram_reuse_n_min), cache_ram_similarity); prompt_cache->update(); ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens @@ 
-3719,13 +3721,17 @@ bool server_context::create_checkpoint(server_slot & slot) { if (do_checkpoint) { const int64_t t_start = ggml_time_us(); while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) { - // make room for the new checkpoint, if needed - const auto & cur = slot.server_cached_prompt.checkpoints.front(); + // Preserve the earliest checkpoint as an SWA rewind anchor when the cap allows it. + auto erase_it = slot.server_cached_prompt.checkpoints.begin(); + if (params_base.ctx_checkpoints_n > 1 && slot.server_cached_prompt.checkpoints.size() > 1) { + ++erase_it; + } + const auto & cur = *erase_it; SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024); - slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin()); + slot.server_cached_prompt.checkpoints.erase(erase_it); } auto & cur = slot.server_cached_prompt.checkpoints.emplace_back(); diff --git a/examples/server/server-context.h b/examples/server/server-context.h index a33c2113b9..e8c77677de 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -127,7 +127,7 @@ struct server_slot { void prompt_save(server_prompt_cache& prompt_cache) const; - void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens); + void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens, size_t min_reusable_prefix, float min_reusable_fraction); size_t checkpoint_pos = 0; bool do_checkpoint = false; @@ -297,6 +297,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; int32_t cache_ram_n_min = 0; + int32_t cache_ram_reuse_n_min = 0; float cache_ram_similarity = 0.5f; ~server_context(); diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp index 
95287820b9..6186c9d8b1 100644 --- a/examples/server/server-task.cpp +++ b/examples/server/server-task.cpp @@ -1073,7 +1073,7 @@ size_t server_prompt_cache::n_tokens() const { } -bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) { +bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot, size_t min_reusable_prefix, float min_reusable_fraction) { thinking_tokens think_tokens; for (auto it = states.begin(); it != states.end(); ++it) { think_tokens = it->think_tokens; @@ -1086,37 +1086,49 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token tokens_new_ex = tokens_new.get_tokens_exclude_think(ctx, think_tokens); } else { - prompt_tokens = std::move(prompt.tokens); + prompt_tokens = prompt.tokens.clone(); tokens_new_ex = tokens_new.clone(); } const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex); - float f_keep_best = float(lcp_best.second) / prompt_tokens.size(); + float f_keep_best = prompt_tokens.empty() ? 
0.0f : float(lcp_best.first) / prompt_tokens.size(); float sim_best = prompt_tokens.get_tokens_similarity(ctx, tokens_new_ex, prompt.n_kept_prompt, prompt.n_discarded_prompt); - LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt); + size_t lcp_best_tokens = lcp_best.first; + LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, lcp = %zu, min_reusable_prefix = %zu, min_reusable_fraction = %.3f, n_keep = %d, n_discarded_prompt = %d\n", + f_keep_best, sim_best, lcp_best_tokens, min_reusable_prefix, min_reusable_fraction, prompt.n_kept_prompt, prompt.n_discarded_prompt); auto it_best = states.end(); - // find the most similar cached prompt, that would also preserve the most context + // find the most similar viable cached prompt; common prefix breaks ties for (auto it = states.begin(); it != states.end(); ++it) { server_tokens tokens; if (think_tokens.exclude) { tokens = it->tokens.get_tokens_exclude_think(ctx, think_tokens); } else { - tokens = std::move(it->tokens); + tokens = it->tokens.clone(); } const auto lcp_cur = tokens.get_common_prefix(ctx, tokens_new_ex); - const float f_keep_cur = float(lcp_cur.first) / tokens.size(); + const float f_keep_cur = tokens.empty() ? 
0.0f : float(lcp_cur.first) / tokens.size(); const float sim_cur = tokens.get_tokens_similarity(ctx, tokens_new_ex, it->n_kept_prompt, it->n_discarded_prompt); - if (sim_best < sim_cur) { + const bool prefix_ok = lcp_cur.first >= min_reusable_prefix; + const bool fraction_ok = f_keep_cur >= min_reusable_fraction; + const bool rewind_ok = it->has_rewind_checkpoint(lcp_cur.first); + if (!prefix_ok || !fraction_ok || !rewind_ok) { + LLAMA_LOG_INFO(" - skipping prompt cache candidate: lcp = %zu, f_keep = %.3f, sim = %.3f, pos_min = %d, checkpoints = %zu, prefix_ok = %d, fraction_ok = %d, rewind_ok = %d\n", + lcp_cur.first, f_keep_cur, sim_cur, it->pos_min, it->checkpoints.size(), (int) prefix_ok, (int) fraction_ok, (int) rewind_ok); + continue; + } + if (sim_best < sim_cur || (sim_best == sim_cur && lcp_best_tokens < lcp_cur.first)) { f_keep_best = f_keep_cur; sim_best = sim_cur; + lcp_best_tokens = lcp_cur.first; it_best = it; } } if (it_best != states.end()) { - LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt); + LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, lcp = %zu, pos_min = %d, checkpoints = %zu, n_keep = %d, n_discarded_prompt = %d\n", + f_keep_best, sim_best, lcp_best_tokens, it_best->pos_min, it_best->checkpoints.size(), it_best->n_kept_prompt, it_best->n_discarded_prompt); const size_t size = it_best->data.size(); const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot, 0); if (n != size) { @@ -1182,6 +1194,8 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st /*.n_keep =*/ prompt.n_kept_prompt, /*.n_discarded_prompt =*/ prompt.n_discarded_prompt, /*.think_tokens =*/ prompt.think_tokens, + /*.pos_min =*/ prompt.pos_min, + /*.pos_max =*/ prompt.pos_max, /*.data =*/ std::move(state_data), /*.checkpoints =*/ prompt.checkpoints, }; diff --git 
a/examples/server/server-task.h b/examples/server/server-task.h index 76a6bad36a..6669133a2d 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -386,6 +386,8 @@ struct server_prompt { int n_kept_prompt; int n_discarded_prompt; thinking_tokens think_tokens; + llama_pos pos_min = -1; + llama_pos pos_max = -1; std::vector data; @@ -397,12 +399,26 @@ struct server_prompt { return tokens.size(); } + bool has_rewind_checkpoint(size_t lcp) const { + if (pos_min < 0 || pos_min <= (llama_pos) lcp) { + return true; + } + for (const auto & checkpoint : checkpoints) { + if (checkpoint.pos_max <= (llama_pos) lcp) { + return true; + } + } + return false; + } + server_prompt clone() const { return server_prompt{ tokens.clone(), n_kept_prompt, n_discarded_prompt, think_tokens, + pos_min, + pos_max, data, checkpoints }; @@ -414,6 +430,8 @@ struct server_prompt { j["tokens"] = tokens.to_json(); j["n_kept_prompt"] = n_kept_prompt; j["n_discarded_prompt"] = n_discarded_prompt; + j["pos_min"] = pos_min; + j["pos_max"] = pos_max; return j; } @@ -422,6 +440,8 @@ struct server_prompt { n_kept_prompt = j.value("n_kept_prompt", 0); n_discarded_prompt = j.value("n_discarded_prompt", 0); n_kept_prompt = j.value("n_kept_prompt", 0); + pos_min = j.value("pos_min", -1); + pos_max = j.value("pos_max", -1); } }; @@ -446,7 +466,7 @@ struct server_prompt_cache { server_prompt* alloc(const server_prompt& prompt, size_t state_size); - bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot); + bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot, size_t min_reusable_prefix, float min_reusable_fraction); void update(); }; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e7e556349d..ada5bfbde8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -553,6 +553,7 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt 
for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; ctx_server.cache_ram_n_min = params.cache_ram_n_min; + ctx_server.cache_ram_reuse_n_min = params.cache_ram_reuse_n_min; ctx_server.cache_ram_similarity = params.cache_ram_similarity; #ifdef SQLITE3_MODERN_CPP_SUPPORT auto db_handle = std::make_shared(params.sql_save_file); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 878d4a3403..325261624c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -202,6 +202,14 @@ llama_build_and_test( ) llama_build_and_test(test-regex-partial.cpp) +llama_build_and_test(test-server-prompt-cache.cpp) +target_link_libraries(test-server-prompt-cache PRIVATE mtmd) +target_include_directories(test-server-prompt-cache PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server + ${CMAKE_CURRENT_SOURCE_DIR}/../examples/mtmd + ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + # llama_target_and_test(test-opt.cpp) # SLOW llama_target_and_test(test-model-load-cancel.cpp LABEL "model") @@ -212,5 +220,3 @@ llama_target_and_test(test-autorelease.cpp LABEL "model") get_filename_component(TEST_TARGET test-c.c NAME_WE) add_executable(${TEST_TARGET} test-c.c) target_link_libraries(${TEST_TARGET} PRIVATE llama) - - diff --git a/tests/test-server-prompt-cache.cpp b/tests/test-server-prompt-cache.cpp new file mode 100644 index 0000000000..7637b02b37 --- /dev/null +++ b/tests/test-server-prompt-cache.cpp @@ -0,0 +1,33 @@ +#include "server-task.h" + +#include + +int main() { + server_prompt prompt; + + prompt.pos_min = 100; + assert(!prompt.has_rewind_checkpoint(64)); + + server_prompt_checkpoint early; + early.pos_min = 0; + early.pos_max = 48; + early.pos_min_prompt = 0; + early.pos_max_prompt = 48; + prompt.checkpoints.push_back(early); + assert(prompt.has_rewind_checkpoint(64)); + + prompt.checkpoints.clear(); + server_prompt_checkpoint prompt_aligned; + prompt_aligned.pos_min = 120; + prompt_aligned.pos_max = 140; + prompt_aligned.pos_min_prompt = 40; + 
prompt_aligned.pos_max_prompt = 64; + prompt.checkpoints.push_back(prompt_aligned); + assert(!prompt.has_rewind_checkpoint(64)); + + prompt.checkpoints.clear(); + prompt.pos_min = 32; + assert(prompt.has_rewind_checkpoint(64)); + + return 0; +}