
Commit 3b0cf52

feat(llama.cpp): add reranking (#5396)
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent bac3022 commit 3b0cf52

File tree: 5 files changed, +105 -0 lines changed

backend/backend.proto
backend/cpp/llama/grpc-server.cpp
backend/go/llm/llama/llama.go
core/backend/options.go
core/config/backend_config.go

backend/backend.proto

Lines changed: 2 additions & 0 deletions
@@ -255,6 +255,8 @@ message ModelOptions {
   string CacheTypeValue = 64;
 
   repeated GrammarTrigger GrammarTriggers = 65;
+
+  bool Reranking = 71;
 }
 
 message Result {
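
The new field extends ModelOptions so that a gRPC caller can ask the llama.cpp backend to load a model in reranking mode. A minimal Go sketch of populating the generated message follows; the pb import path and the Model field are assumptions for illustration, while the Reranking and Embeddings field names come from this commit.

package main

import (
    "fmt"

    // Assumed import path for the generated backend protobufs.
    pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

// buildLoadOptions sketches how a caller could request reranking mode.
// The server refuses to rerank when embedding mode is enabled, so the
// two flags are kept mutually exclusive here.
func buildLoadOptions(modelPath string) *pb.ModelOptions {
    return &pb.ModelOptions{
        Model:      modelPath, // assumed field; only Embeddings and Reranking appear in this diff
        Embeddings: false,
        Reranking:  true,
    }
}

func main() {
    fmt.Printf("%+v\n", buildLoadOptions("path/to/reranker.gguf"))
}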

backend/cpp/llama/grpc-server.cpp

Lines changed: 89 additions & 0 deletions
@@ -231,6 +231,7 @@ static void params_parse(const backend::ModelOptions* request,
         params.n_parallel = 1;
     }
 
+
     const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
     if (llama_grpc_servers != NULL) {
         add_rpc_devices(std::string(llama_grpc_servers));
@@ -291,6 +292,7 @@ static void params_parse(const backend::ModelOptions* request,
     params.ctx_shift = false; // We control context-shifting in any case (and we disable it as it could just lead to infinite loops)
 
     params.embedding = request->embeddings();
+    params.reranking = request->reranking();
 
     if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
     else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
@@ -791,6 +793,93 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
+            return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+        }
+
+        // Validate request
+        if (request->query().empty()) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
+        }
+
+        if (request->documents_size() == 0) {
+            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
+        }
+
+        // Tokenize the query
+        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, request->query(), /* add_special */ false, true)[0];
+
+        // Create and queue the task
+        json responses = json::array();
+        bool error = false;
+        std::unordered_set<int> task_ids;
+        {
+            std::vector<server_task> tasks;
+            std::vector<std::string> documents;
+            for (int i = 0; i < request->documents_size(); i++) {
+                documents.push_back(request->documents(i));
+            }
+
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            tasks.reserve(tokenized_docs.size());
+            for (size_t i = 0; i < tokenized_docs.size(); i++) {
+                auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                tasks.push_back(std::move(task));
+            }
+
+            task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
+        }
+
+        // Get the results
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            error = true;
+        }, [&]() {
+            return false;
+        });
+
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
+        if (error) {
+            return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
+        }
+
+        // Set usage information
+        backend::Usage* usage = rerankResult->mutable_usage();
+        int total_tokens = 0;
+        int prompt_tokens = 0;
+
+        // Create document results
+        for (const auto& response : responses) {
+            backend::DocumentResult* doc_result = rerankResult->add_results();
+            doc_result->set_index(response.value("index", 0));
+            doc_result->set_text(request->documents(response.value("index", 0)));
+            doc_result->set_relevance_score(response.value("score", 0.0f));
+
+            // Add tokens evaluated for this document
+            int tokens_evaluated = response.value("tokens_evaluated", 0);
+            total_tokens += tokens_evaluated;
+            prompt_tokens += tokens_evaluated;
+        }
+
+        // Set the total tokens in usage
+        usage->set_total_tokens(total_tokens);
+        usage->set_prompt_tokens(prompt_tokens);
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
         json body = parse_options(false, request);
         body["stream"] = false;

backend/go/llm/llama/llama.go

Lines changed: 3 additions & 0 deletions
@@ -58,6 +58,9 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
     if opts.Embeddings {
         llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
     }
+    if opts.Reranking {
+        llamaOpts = append(llamaOpts, llama.EnableReranking)
+    }
     if opts.NGPULayers != 0 {
         llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
     }
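
Load turns the protobuf flag into the go-llama.cpp functional option llama.EnableReranking, the same way the existing EnableEmbeddings option is applied. For readers unfamiliar with that pattern, here is a hypothetical, self-contained Go sketch of how such option functions typically work; it is not the binding's actual implementation, only an illustration of why appending options like these configures the loader.

package main

import "fmt"

// modelConfig stands in for the binding's internal load configuration.
type modelConfig struct {
    Embeddings bool
    Reranking  bool
    GPULayers  int
}

// ModelOption mirrors the functional-options style used in Load above.
type ModelOption func(*modelConfig)

// Hypothetical options, named after the ones appended in llama.go.
func EnableEmbeddings(c *modelConfig) { c.Embeddings = true }
func EnableReranking(c *modelConfig)  { c.Reranking = true }
func SetGPULayers(n int) ModelOption  { return func(c *modelConfig) { c.GPULayers = n } }

// load applies every accumulated option to a fresh configuration.
func load(opts ...ModelOption) modelConfig {
    cfg := modelConfig{}
    for _, o := range opts {
        o(&cfg)
    }
    return cfg
}

func main() {
    cfg := load(EnableReranking, SetGPULayers(35))
    fmt.Printf("%+v\n", cfg) // {Embeddings:false Reranking:true GPULayers:35}
}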

core/backend/options.go

Lines changed: 6 additions & 0 deletions
@@ -94,6 +94,11 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
         lowVRAM = *c.LowVRAM
     }
 
+    reranking := false
+    if c.Reranking != nil {
+        reranking = *c.Reranking
+    }
+
     mmap := false
     if c.MMap != nil {
         mmap = *c.MMap
@@ -178,6 +183,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
         RopeFreqScale: c.RopeFreqScale,
         NUMA:          c.NUMA,
         Embeddings:    embeddings,
+        Reranking:     reranking,
         LowVRAM:       lowVRAM,
         NGPULayers:    int32(nGPULayers),
         MMap:          mmap,

core/config/backend_config.go

Lines changed: 5 additions & 0 deletions
@@ -120,6 +120,7 @@ type LLMConfig struct {
     MMap       *bool    `yaml:"mmap"`
     MMlock     *bool    `yaml:"mmlock"`
     LowVRAM    *bool    `yaml:"low_vram"`
+    Reranking  *bool    `yaml:"reranking"`
     Grammar    string   `yaml:"grammar"`
     StopWords  []string `yaml:"stopwords"`
     Cutstrings []string `yaml:"cutstrings"`
@@ -372,6 +373,10 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
         cfg.Embeddings = &falseV
     }
 
+    if cfg.Reranking == nil {
+        cfg.Reranking = &falseV
+    }
+
     if threads == 0 {
         // Threads can't be 0
         threads = 4
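
Together, the last two files thread the new option through LocalAI's model configuration: a reranking: key in a model's YAML definition lands in LLMConfig.Reranking, SetDefaults fills it with false when it is unset, and grpcModelOpts dereferences the pointer into the plain bool sent over gRPC. Below is a small self-contained Go sketch of that pointer-with-default pattern; the struct is trimmed to the one relevant field, and the YAML snippet is only an assumed example of a model definition.

package main

import (
    "fmt"

    "gopkg.in/yaml.v3"
)

// llmConfig mirrors just the relevant slice of LLMConfig: a *bool so that
// "unset" can be told apart from an explicit "reranking: false".
type llmConfig struct {
    Reranking *bool `yaml:"reranking"`
}

// setDefaults follows the same nil-check used in BackendConfig.SetDefaults.
func setDefaults(cfg *llmConfig) {
    falseV := false
    if cfg.Reranking == nil {
        cfg.Reranking = &falseV
    }
}

func main() {
    // Hypothetical model definition enabling reranking for this backend.
    doc := []byte("reranking: true\n")

    var cfg llmConfig
    if err := yaml.Unmarshal(doc, &cfg); err != nil {
        panic(err)
    }
    setDefaults(&cfg)

    // grpcModelOpts-style dereference into the value placed on pb.ModelOptions.
    reranking := false
    if cfg.Reranking != nil {
        reranking = *cfg.Reranking
    }
    fmt.Println("reranking:", reranking) // prints: reranking: true
}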
