@@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
         params.n_parallel = 1;
     }
 
+
     const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
     if (llama_grpc_servers != NULL) {
         add_rpc_devices(std::string(llama_grpc_servers));
@@ -291,6 +292,7 @@ static void params_parse(const backend::ModelOptions* request,
     params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
 
     params.embedding = request->embeddings();
+    params.reranking = request->reranking();
 
     if (request->ropescaling() == "none")      { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
     else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
@@ -791,6 +793,93 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
+            return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+        }
+
+        // Validate request
+        if (request->query().empty()) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
+        }
+
+        if (request->documents_size() == 0) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
+        }
+
+        // Tokenize the query
+        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0];
+
+        // Create and queue the task
+        json responses = json::array();
+        bool error = false;
+        std::unordered_set<int> task_ids;
+        {
+            std::vector<server_task> tasks;
+            std::vector<std::string> documents;
+            for (int i = 0; i < request->documents_size(); i++) {
+                documents.push_back(request->documents(i));
+            }
+
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            tasks.reserve(tokenized_docs.size());
+            for (size_t i = 0; i < tokenized_docs.size(); i++) {
+                auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                tasks.push_back(std::move(task));
+            }
+
+            task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
+        }
+
+        // Get the results
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            error = true;
+        }, [&]() {
+            return false;
+        });
+
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
+        if (error) {
+            return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
+        }
+
+        // Set usage information
+        backend::Usage* usage = rerankResult->mutable_usage();
+        int total_tokens = 0;
+        int prompt_tokens = 0;
+
+        // Create document results
+        for (const auto & response : responses) {
+            backend::DocumentResult* doc_result = rerankResult->add_results();
+            doc_result->set_index(response.value("index", 0));
+            doc_result->set_text(request->documents(response.value("index", 0)));
+            doc_result->set_relevance_score(response.value("score", 0.0f));
+
+            // Add tokens evaluated for this document
+            int tokens_evaluated = response.value("tokens_evaluated", 0);
+            total_tokens += tokens_evaluated;
+            prompt_tokens += tokens_evaluated;
+        }
+
+        // Set the total tokens in usage
+        usage->set_total_tokens(total_tokens);
+        usage->set_prompt_tokens(prompt_tokens);
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
         json body = parse_options(false, request);
         body["stream"] = false;