Commit 5e0a1e2

wire to grpc
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 01e2e3d commit 5e0a1e2

File tree

1 file changed (+53 additions, -7 deletions)


backend/cpp/llama/grpc-server.cpp

Lines changed: 53 additions & 7 deletions
@@ -527,6 +527,12 @@ struct llama_server_context
             }
         }
 
+        // Enable reranking if embeddings are enabled
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         common_init_result common_init = common_init_from_params(params);
         model = common_init.model.release();
         ctx = common_init.context.release();
@@ -1424,11 +1430,11 @@ struct llama_server_context
 
         float score = -1e6f; // Default score if we fail to get embeddings
 
-        if (!params.rerank)
+        if (!params.reranking)
        {
            LOG_WARNING("reranking disabled", {
-                {"params.rerank", params.rerank},
-            });
+                {"params.reranking", params.reranking},
+            });
        }
        else
        {
@@ -1455,7 +1461,7 @@ struct llama_server_context
        res.result_json = json
        {
            {"score", score},
-            {"tokens", slot.n_prompt_tokens}
+            {"tokens", slot.num_prompt_tokens}
        };
 
        queue_results.send(res);
@@ -2547,7 +2553,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
        json data = parse_options(true, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
        while (true)
        {
            task_result result = llama.queue_results.recv(task_id);
@@ -2601,7 +2607,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
        json data = parse_options(false, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
        std::string completion_text;
        task_result result = llama.queue_results.recv(task_id);
        if (!result.error && result.stop) {
@@ -2638,7 +2644,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
        json data = parse_options(false, request, llama);
        const int task_id = llama.queue_tasks.get_new_id();
        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
        // get the result
        task_result result = llama.queue_results.recv(task_id);
        //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
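
Each of the three call sites above now passes one extra boolean to llama.request_completion, and the new Rerank handler in the next hunk passes it as true. The updated declaration itself is not part of this diff; the following is only a sketch of how it presumably looks, with the new rerank flag placed between the existing embedding flag and the multitask id (the parameter names are assumptions, not taken from this commit):

// Sketch only: presumed shape of the updated declaration in llama_server_context.
// The actual declaration is outside the hunks shown in this commit.
void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id);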
@@ -2670,6 +2676,46 @@ class BackendServiceImpl final : public backend::Backend::Service {
        return grpc::Status::OK;
    }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
    grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
        llama_client_slot* active_slot = llama.get_active_slot();
 
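
For reference, below is a minimal client-side sketch of calling the new Rerank RPC through the generated C++ stub. The field names (query, documents, top_n, results, usage, and the per-document index/text/relevance_score) are inferred from the accessors used by the handler above; the generated header name, the backend::Backend stub, the listen address, and the sample strings are assumptions, not taken from this commit.

#include <grpcpp/grpcpp.h>
#include "backend.grpc.pb.h"  // assumed name of the header generated from the backend proto

#include <iostream>
#include <memory>

int main() {
    // Connect to the llama gRPC backend (address/port are placeholders).
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    std::unique_ptr<backend::Backend::Stub> stub = backend::Backend::NewStub(channel);

    // Build the rerank request: one query plus the candidate documents.
    backend::RerankRequest request;
    request.set_query("What is the capital of France?");
    request.add_documents("Paris is the capital of France.");
    request.add_documents("Berlin is the capital of Germany.");
    request.set_top_n(2);

    backend::RerankResult response;
    grpc::ClientContext context;
    grpc::Status status = stub->Rerank(&context, request, &response);
    if (!status.ok()) {
        std::cerr << "Rerank failed: " << status.error_message() << std::endl;
        return 1;
    }

    // Print each document with the relevance score the backend assigned.
    for (const auto& doc : response.results()) {
        std::cout << doc.index() << " (" << doc.relevance_score() << "): " << doc.text() << std::endl;
    }
    std::cout << "prompt tokens: " << response.usage().prompt_tokens() << std::endl;
    return 0;
}

As wired in the handler above, every returned document carries the same relevance score: the single score from the task result is copied into each DocumentResult.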
