@@ -527,6 +527,12 @@ struct llama_server_context
             }
         }
 
+        // Enable reranking if embeddings are enabled
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         common_init_result common_init = common_init_from_params(params);
         model = common_init.model.release();
         ctx = common_init.context.release();
@@ -1424,11 +1430,11 @@ struct llama_server_context
 
         float score = -1e6f; // Default score if we fail to get embeddings
 
-        if (!params.rerank)
+        if (!params.reranking)
         {
             LOG_WARNING("reranking disabled", {
-                {"params.rerank", params.rerank},
-            });
+                {"params.reranking", params.reranking},
+            });
         }
         else
         {
@@ -1455,7 +1461,7 @@ struct llama_server_context
         res.result_json = json
         {
             {"score", score},
-            {"tokens", slot.n_prompt_tokens}
+            {"tokens", slot.num_prompt_tokens}
         };
 
         queue_results.send(res);
@@ -2547,7 +2553,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(true, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
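The call-site changes in this and the next two hunks pass one extra boolean to request_completion, and the new Rerank handler further down passes true in that position. The updated declaration itself is not part of this diff; a minimal sketch of what it presumably looks like, with parameter names chosen for illustration only, is:

    // Hypothetical signature, inferred from the six-argument call sites in this diff;
    // only the added rerank flag is new, the other parameter names are illustrative.
    void request_completion(int task_id, json data,
                            bool infill,
                            bool embedding,
                            bool rerank,      // new: true only for rerank tasks
                            int multitask_id);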
@@ -2601,7 +2607,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
@@ -2638,7 +2644,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, {"n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, {"n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
         // get the result
         task_result result = llama.queue_results.recv(task_id);
         // std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
@@ -2670,6 +2676,46 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
 
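For context, a minimal sketch of how a gRPC client might exercise the new Rerank RPC. The channel address, the generated header name (backend.grpc.pb.h), and the exact proto field types are assumptions; the message and field names mirror the accessors used in the handler above (query, documents, top_n, results, usage).

    #include <cstdio>
    #include <grpcpp/grpcpp.h>
    #include "backend.grpc.pb.h"   // assumed name of the generated header

    int main() {
        // The address is an assumption; use whatever the backend actually listens on.
        auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
        auto stub = backend::Backend::NewStub(channel);

        backend::RerankRequest request;
        request.set_query("what is quantization?");
        request.add_documents("Quantization reduces model precision to save memory.");
        request.add_documents("Tokenization splits text into smaller units.");
        request.set_top_n(2);

        backend::RerankResult result;
        grpc::ClientContext context;
        grpc::Status status = stub->Rerank(&context, request, &result);

        if (status.ok()) {
            // index/text/relevance_score are the fields the server handler fills in
            for (const auto& doc : result.results()) {
                std::printf("doc %d (%.3f): %s\n", doc.index(), doc.relevance_score(), doc.text().c_str());
            }
        }
        return 0;
    }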