Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2689,6 +2689,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.cache_ram_n_min = std::stoi(argv[i]);
return true;
}
if (arg == "--cache-ram-reuse-n-min") {
CHECK_ARG
params.cache_ram_reuse_n_min = std::stoi(argv[i]);
return true;
}
if (arg == "--pos") {
CHECK_ARG
params.i_pos = std::stoi(argv[i]);
Expand Down Expand Up @@ -2887,6 +2892,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-cram, --cache-ram N", "set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)",params.cache_ram_mib });
options.push_back({ "*", "-crs, --cache-ram-similarity N", "max of similarity of prompt tokens to cache tokens that triggers prompt cache (default: %.2f).",params.cache_ram_similarity });
options.push_back({ "*", "-cram-n-min --cache-ram-n-min N", "minimum number of the cached tokens that triggers prompt cache (default: %d).", params.cache_ram_n_min });
options.push_back({ "*", "--cache-ram-reuse-n-min N", "minimum reusable common-prefix tokens required to restore from prompt cache (default: %d).", params.cache_ram_reuse_n_min });
options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch });
options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ struct gpt_params {
int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
int32_t cache_ram_n_min = 0; // min number of tokens required to save in the ram
int32_t cache_ram_reuse_n_min = 0; // min reusable common-prefix tokens required to restore from ram cache
float cache_ram_similarity = 0.5f; // similarity of tokens to cached tokens

// batched-bench params
Expand Down
5 changes: 5 additions & 0 deletions docs/parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ Note: When the available memory is very limited, turn this option off (`-cram 0`
| `-cram, --cache-ram N` | Set the maximum cache size in MiB | 8192 | -1 = no limit, 0 = disable Very useful when the variations of the same prompt are re-sent to the model (coding agents, etc.). [PR 954](https://github.com/ikawrakow/ik_llama.cpp/pull/954) |
| `-crs, --cache-ram-similarity N` | Max similarity of prompt tokens to cache tokens that triggers prompt cache | 0.50 | |
| `-cram-n-min, --cache-ram-n-min N` | Minimum number of cached tokens that triggers prompt cache | 0 | |
| `--cache-ram-reuse-n-min N` | Minimum reusable common-prefix tokens required to restore from prompt cache | 0 | |

When restoring a prompt from RAM cache, the server ranks candidates by reusable common prefix first and uses similarity as a tie-breaker. A candidate is skipped if its reusable prefix is below `--cache-ram-reuse-n-min`, if the reusable fraction is below `--cache-ram-similarity`, or if the cached KV range cannot be rewound safely for SWA/sliding-window attention.

For SWA models, cached prompt entries track the KV position range saved with the state. If the active KV window starts after the reusable prefix, restore is only allowed when an earlier context checkpoint can rewind to that prefix. Otherwise the candidate is rejected before loading state, so a bad RAM-cache hit does not mutate the slot and then force full prompt re-processing. When context checkpoints are capped, the server keeps the earliest checkpoint as a rewind anchor when possible.

## Sampling

Expand Down
18 changes: 12 additions & 6 deletions examples/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,8 @@ void server_slot::prompt_save(server_prompt_cache& prompt_cache) const {
llama_state_seq_get_data(ctx, cur->data.data(), cur_size, id, 0);
}

void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens) {
bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id);
void server_slot::prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens, size_t min_reusable_prefix, float min_reusable_fraction) {
bool res = prompt_cache.load(server_cached_prompt, tokens, ctx, id, min_reusable_prefix, min_reusable_fraction);
if (!res) {
LLAMA_LOG_INFO("failed to load prompt from cache\n");
}
Expand Down Expand Up @@ -1031,6 +1031,8 @@ void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, se
slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt;
slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt;
slot.server_cached_prompt.think_tokens = slot.params.think_tokens;
slot.server_cached_prompt.pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id);
slot.server_cached_prompt.pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id);
}

server_slot* server_context::get_available_slot(const server_task& task) {
Expand Down Expand Up @@ -1146,7 +1148,7 @@ server_slot* server_context::get_available_slot(const server_task& task) {
const int64_t t_start = ggml_time_us();
copy_data_to_cached_prompt(tokens, *ret);

ret->prompt_load(*prompt_cache, task.tokens);
ret->prompt_load(*prompt_cache, task.tokens, (size_t) std::max(0, cache_ram_reuse_n_min), cache_ram_similarity);
prompt_cache->update();

ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens
Expand Down Expand Up @@ -3719,13 +3721,17 @@ bool server_context::create_checkpoint(server_slot & slot) {
if (do_checkpoint) {
const int64_t t_start = ggml_time_us();
while (slot.server_cached_prompt.checkpoints.size() >= (size_t)params_base.ctx_checkpoints_n) {
// make room for the new checkpoint, if needed
const auto & cur = slot.server_cached_prompt.checkpoints.front();
// Preserve the earliest checkpoint as an SWA rewind anchor when the cap allows it.
auto erase_it = slot.server_cached_prompt.checkpoints.begin();
if (params_base.ctx_checkpoints_n > 1 && slot.server_cached_prompt.checkpoints.size() > 1) {
++erase_it;
}
const auto & cur = *erase_it;

SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float)cur.data.size() / 1024 / 1024);

slot.server_cached_prompt.checkpoints.erase(slot.server_cached_prompt.checkpoints.begin());
slot.server_cached_prompt.checkpoints.erase(erase_it);
}

auto & cur = slot.server_cached_prompt.checkpoints.emplace_back();
Expand Down
3 changes: 2 additions & 1 deletion examples/server/server-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct server_slot {

void prompt_save(server_prompt_cache& prompt_cache) const;

void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens);
void prompt_load(server_prompt_cache& prompt_cache, const server_tokens& tokens, size_t min_reusable_prefix, float min_reusable_fraction);

size_t checkpoint_pos = 0;
bool do_checkpoint = false;
Expand Down Expand Up @@ -297,6 +297,7 @@ struct server_context {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
int32_t cache_ram_n_min = 0;
int32_t cache_ram_reuse_n_min = 0;
float cache_ram_similarity = 0.5f;

~server_context();
Expand Down
32 changes: 23 additions & 9 deletions examples/server/server-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ size_t server_prompt_cache::n_tokens() const {

}

bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot) {
bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot, size_t min_reusable_prefix, float min_reusable_fraction) {
thinking_tokens think_tokens;
for (auto it = states.begin(); it != states.end(); ++it) {
think_tokens = it->think_tokens;
Expand All @@ -1086,37 +1086,49 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token
tokens_new_ex = tokens_new.get_tokens_exclude_think(ctx, think_tokens);
}
else {
prompt_tokens = std::move(prompt.tokens);
prompt_tokens = prompt.tokens.clone();
tokens_new_ex = tokens_new.clone();
}
const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex);
float f_keep_best = float(lcp_best.second) / prompt_tokens.size();
float f_keep_best = prompt_tokens.empty() ? 0.0f : float(lcp_best.first) / prompt_tokens.size();
float sim_best = prompt_tokens.get_tokens_similarity(ctx, tokens_new_ex, prompt.n_kept_prompt, prompt.n_discarded_prompt);
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, prompt.n_kept_prompt, prompt.n_discarded_prompt);
size_t lcp_best_tokens = lcp_best.first;
LLAMA_LOG_INFO(" - looking for better prompt, base f_keep = %.3f, sim = %.3f, lcp = %zu, min_reusable_prefix = %zu, min_reusable_fraction = %.3f, n_keep = %d, n_discarded_prompt = %d\n",
f_keep_best, sim_best, lcp_best_tokens, min_reusable_prefix, min_reusable_fraction, prompt.n_kept_prompt, prompt.n_discarded_prompt);

auto it_best = states.end();

// find the most similar cached prompt, that would also preserve the most context
// find the most similar viable cached prompt; common prefix breaks ties
for (auto it = states.begin(); it != states.end(); ++it) {
server_tokens tokens;
if (think_tokens.exclude) {
tokens = it->tokens.get_tokens_exclude_think(ctx, think_tokens);
}
else {
tokens = std::move(it->tokens);
tokens = it->tokens.clone();
}
const auto lcp_cur = tokens.get_common_prefix(ctx, tokens_new_ex);
const float f_keep_cur = float(lcp_cur.first) / tokens.size();
const float f_keep_cur = tokens.empty() ? 0.0f : float(lcp_cur.first) / tokens.size();
const float sim_cur = tokens.get_tokens_similarity(ctx, tokens_new_ex, it->n_kept_prompt, it->n_discarded_prompt);
if (sim_best < sim_cur) {
const bool prefix_ok = lcp_cur.first >= min_reusable_prefix;
const bool fraction_ok = f_keep_cur >= min_reusable_fraction;
const bool rewind_ok = it->has_rewind_checkpoint(lcp_cur.first);
if (!prefix_ok || !fraction_ok || !rewind_ok) {
LLAMA_LOG_INFO(" - skipping prompt cache candidate: lcp = %zu, f_keep = %.3f, sim = %.3f, pos_min = %d, checkpoints = %zu, prefix_ok = %d, fraction_ok = %d, rewind_ok = %d\n",
lcp_cur.first, f_keep_cur, sim_cur, it->pos_min, it->checkpoints.size(), (int) prefix_ok, (int) fraction_ok, (int) rewind_ok);
continue;
}
if (sim_best < sim_cur || (sim_best == sim_cur && lcp_best_tokens < lcp_cur.first)) {
f_keep_best = f_keep_cur;
sim_best = sim_cur;
lcp_best_tokens = lcp_cur.first;
it_best = it;
}
}

if (it_best != states.end()) {
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, n_keep = %d, n_discarded_prompt = %d\n", f_keep_best, sim_best, it_best->n_kept_prompt, it_best->n_discarded_prompt);
LLAMA_LOG_INFO(" - found better prompt with f_keep = %.3f, sim = %.3f, lcp = %zu, pos_min = %d, checkpoints = %zu, n_keep = %d, n_discarded_prompt = %d\n",
f_keep_best, sim_best, lcp_best_tokens, it_best->pos_min, it_best->checkpoints.size(), it_best->n_kept_prompt, it_best->n_discarded_prompt);
const size_t size = it_best->data.size();
const size_t n = llama_state_seq_set_data(ctx, it_best->data.data(), size, id_slot, 0);
if (n != size) {
Expand Down Expand Up @@ -1182,6 +1194,8 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st
/*.n_keep =*/ prompt.n_kept_prompt,
/*.n_discarded_prompt =*/ prompt.n_discarded_prompt,
/*.think_tokens =*/ prompt.think_tokens,
/*.pos_min =*/ prompt.pos_min,
/*.pos_max =*/ prompt.pos_max,
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};
Expand Down
22 changes: 21 additions & 1 deletion examples/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ struct server_prompt {
int n_kept_prompt;
int n_discarded_prompt;
thinking_tokens think_tokens;
llama_pos pos_min = -1;
llama_pos pos_max = -1;

std::vector<uint8_t> data;

Expand All @@ -397,12 +399,26 @@ struct server_prompt {
return tokens.size();
}

bool has_rewind_checkpoint(size_t lcp) const {
if (pos_min < 0 || pos_min <= (llama_pos) lcp) {
return true;
}
for (const auto & checkpoint : checkpoints) {
if (checkpoint.pos_max <= (llama_pos) lcp) {
return true;
}
}
return false;
}

server_prompt clone() const {
return server_prompt{
tokens.clone(),
n_kept_prompt,
n_discarded_prompt,
think_tokens,
pos_min,
pos_max,
data,
checkpoints
};
Expand All @@ -414,6 +430,8 @@ struct server_prompt {
j["tokens"] = tokens.to_json();
j["n_kept_prompt"] = n_kept_prompt;
j["n_discarded_prompt"] = n_discarded_prompt;
j["pos_min"] = pos_min;
j["pos_max"] = pos_max;
return j;
}

Expand All @@ -422,6 +440,8 @@ struct server_prompt {
n_kept_prompt = j.value<llama_pos>("n_kept_prompt", 0);
n_discarded_prompt = j.value<llama_pos>("n_discarded_prompt", 0);
n_kept_prompt = j.value<llama_pos>("n_kept_prompt", 0);
pos_min = j.value<llama_pos>("pos_min", -1);
pos_max = j.value<llama_pos>("pos_max", -1);
}
};

Expand All @@ -446,7 +466,7 @@ struct server_prompt_cache {

server_prompt* alloc(const server_prompt& prompt, size_t state_size);

bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot);
bool load(server_prompt& prompt, const server_tokens& tokens_new, llama_context* ctx, int32_t id_slot, size_t min_reusable_prefix, float min_reusable_fraction);

void update();
};
1 change: 1 addition & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ int main(int argc, char ** argv) {
// Necessary similarity of prompt for slot selection
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
ctx_server.cache_ram_n_min = params.cache_ram_n_min;
ctx_server.cache_ram_reuse_n_min = params.cache_ram_reuse_n_min;
ctx_server.cache_ram_similarity = params.cache_ram_similarity;
#ifdef SQLITE3_MODERN_CPP_SUPPORT
auto db_handle = std::make_shared<DatabaseHandle>(params.sql_save_file);
Expand Down
10 changes: 8 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ llama_build_and_test(
)
llama_build_and_test(test-regex-partial.cpp)

llama_build_and_test(test-server-prompt-cache.cpp)
target_link_libraries(test-server-prompt-cache PRIVATE mtmd)
target_include_directories(test-server-prompt-cache PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../examples/server
${CMAKE_CURRENT_SOURCE_DIR}/../examples/mtmd
${CMAKE_CURRENT_SOURCE_DIR}/..
)

# llama_target_and_test(test-opt.cpp) # SLOW

llama_target_and_test(test-model-load-cancel.cpp LABEL "model")
Expand All @@ -212,5 +220,3 @@ llama_target_and_test(test-autorelease.cpp LABEL "model")
get_filename_component(TEST_TARGET test-c.c NAME_WE)
add_executable(${TEST_TARGET} test-c.c)
target_link_libraries(${TEST_TARGET} PRIVATE llama)


33 changes: 33 additions & 0 deletions tests/test-server-prompt-cache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "server-task.h"

#include <cassert>

int main() {
server_prompt prompt;

prompt.pos_min = 100;
assert(!prompt.has_rewind_checkpoint(64));

server_prompt_checkpoint early;
early.pos_min = 0;
early.pos_max = 48;
early.pos_min_prompt = 0;
early.pos_max_prompt = 48;
prompt.checkpoints.push_back(early);
assert(prompt.has_rewind_checkpoint(64));

prompt.checkpoints.clear();
server_prompt_checkpoint prompt_aligned;
prompt_aligned.pos_min = 120;
prompt_aligned.pos_max = 140;
prompt_aligned.pos_min_prompt = 40;
prompt_aligned.pos_max_prompt = 64;
prompt.checkpoints.push_back(prompt_aligned);
assert(!prompt.has_rewind_checkpoint(64));

prompt.checkpoints.clear();
prompt.pos_min = 32;
assert(prompt.has_rewind_checkpoint(64));

return 0;
}