Support for Re-ranking Retrieval Optimization in Advanced RAG #1366

Open
@kevintsai1202

I want to add a re-ranking step to RAG, but the current Advisor design makes this awkward.
My first plan was to add a re-ranking Advisor after QuestionAnswerAdvisor. However, the code shows that QuestionAnswerAdvisor already adds the search results to the request as context via .withUserParams(advisedUserParams).

A subsequent Advisor can override the UserParams. However, a re-ranking pipeline typically runs a vector search that retrieves 50-100 chunks and then re-ranks them down to a handful. Adding all of those results to UserParams first wastes time and memory, even if a later Advisor overrides them.
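
One way to avoid that cost is to pass the candidate documents between stages and build the context string only once, at the end. The following is a minimal sketch of that idea, not Spring AI code: the Candidate record, buildContext, and keepTop are hypothetical names that stand in for Spring AI's Document type and a real re-ranking call.

```java
import java.util.Comparator;
import java.util.List;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

// Hypothetical sketch: none of these types exist in Spring AI.
final class DeferredContextSketch {

    // Simplified stand-in for a retrieved chunk and its relevance score.
    record Candidate(String content, double score) {}

    // Run every stage over the candidate list first, then join the survivors once.
    static String buildContext(List<Candidate> retrieved,
                               List<UnaryOperator<List<Candidate>>> stages) {
        List<Candidate> current = retrieved;
        for (UnaryOperator<List<Candidate>> stage : stages) {
            current = stage.apply(current);
        }
        return current.stream()
                .map(Candidate::content)
                .collect(Collectors.joining(System.lineSeparator()));
    }

    // Assumed re-ranking stage: keep the topK highest-scoring candidates.
    static UnaryOperator<List<Candidate>> keepTop(int topK) {
        return candidates -> candidates.stream()
                .sorted(Comparator.comparingDouble(Candidate::score).reversed())
                .limit(topK)
                .toList();
    }

    public static void main(String[] args) {
        List<Candidate> retrieved = List.of(
                new Candidate("chunk a", 0.1),
                new Candidate("chunk b", 0.9),
                new Candidate("chunk c", 0.5));
        // Only the two best chunks are ever turned into prompt text.
        System.out.println(buildContext(retrieved, List.of(keepTop(2))));
    }
}
```

With this shape, a wide first-stage search costs only a list of objects; the prompt text is assembled from whatever remains after the last stage.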

I hope the RAG-related Advisors can be improved in the following ways:

  1. Build the document context and attach it to UserParams only after all Advisors have finished processing.
  2. Add support for re-ranking models, a common retrieval optimization technique.
  3. A re-ranking model typically returns only the index and relevance score of each result. Ideally, the re-ranking step would still return a List<Document> that keeps the metadata.
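
Point 3 can be illustrated by resolving each returned index against the list that was sent to the re-ranker. This is only a sketch under stated assumptions: Doc and RerankHit are simplified, hypothetical stand-ins for Spring AI's Document and a re-ranker response entry, and the "rerank_score" metadata key is made up for the example.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: compiles without Spring AI on the classpath.
final class RerankMappingSketch {

    // Simplified document: text plus metadata.
    record Doc(String content, Map<String, Object> metadata) {}

    // What a re-ranker typically returns: position in the submitted list plus a score.
    record RerankHit(int index, double relevanceScore) {}

    // Resolve each hit back to its original document, keeping the re-ranker's order
    // and copying the score into the metadata under an assumed key.
    static List<Doc> toDocuments(List<Doc> candidates, List<RerankHit> hits) {
        List<Doc> ranked = new ArrayList<>();
        for (RerankHit hit : hits) {
            if (hit.index() < 0 || hit.index() >= candidates.size()) {
                continue; // skip indexes that do not point into the candidate list
            }
            Doc original = candidates.get(hit.index());
            Map<String, Object> metadata = new HashMap<>(original.metadata());
            metadata.put("rerank_score", hit.relevanceScore());
            ranked.add(new Doc(original.content(), metadata));
        }
        return ranked;
    }

    public static void main(String[] args) {
        List<Doc> candidates = List.of(
                new Doc("alpha", Map.of("source", "a.md")),
                new Doc("beta", Map.of("source", "b.md")),
                new Doc("gamma", Map.of("source", "c.md")));
        List<RerankHit> hits = List.of(new RerankHit(2, 0.91), new RerankHit(0, 0.40));
        for (Doc doc : toDocuments(candidates, hits)) {
            System.out.println(doc.content() + " from " + doc.metadata().get("source"));
        }
    }
}
```

Because the original metadata is copied rather than replaced, downstream code could still read fields such as the source file while also seeing the re-ranker's score.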

Below is the code I have modified.

public class RerankRAGAdvisor implements RequestResponseAdvisor {
	private static final String DEFAULT_USER_TEXT_ADVISE = """
			Context information is below.
			---------------------
			{question_answer_context}
			---------------------
			Given the context and provided history information and not prior knowledge,
			reply to the user comment. If the answer is not in the context, inform
			the user that you can't answer the question.
			""";
	private final VectorStore vectorStore;
	private final String userTextAdvise;
	private final SearchRequest searchRequest;
	public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";
	public static final String FILTER_EXPRESSION = "qa_filter_expression";
	private final RestClient restClient;
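	// The Voyage AI API key is read from the VOYAGE_KEY environment variable.
	private final String apiKey = System.getenv("VOYAGE_KEY");
	// Response payload returned by the re-ranking API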
	@JsonInclude(Include.NON_NULL)
	public record RerankList(
			@JsonProperty("object") String object,
			@JsonProperty("data") List<Rerank> data,
			@JsonProperty("model") String model,
			@JsonProperty("usage") Usage usage) {
	}
	@JsonInclude(Include.NON_NULL)
	public record Rerank(
			@JsonProperty("index") Integer index,
			@JsonProperty("relevance_score") float relevanceScore,
			@JsonProperty("document") String document) {
	}
	
	// Calls the Voyage AI re-ranking endpoint and asks for the 5 most relevant documents.
	public ResponseEntity<RerankList> rerankDocuments(String query, List<Document> documents) {
        String url = "https://api.voyageai.com/v1/rerank";
        String bearerStr = "Bearer " + this.apiKey;
        Map<String, Object> requestBody = new HashMap<>();
        requestBody.put("query", query);
        requestBody.put("model", "rerank-1");
        requestBody.put("top_k", 5);
        requestBody.put("return_documents", true);
        requestBody.put("documents", documents.stream().map(Document::getContent).toList());

        return restClient.post()
            .uri(url)
            .contentType(MediaType.APPLICATION_JSON)
            .header("Authorization",bearerStr)
            .body(requestBody)
            .retrieve()
            .toEntity(new ParameterizedTypeReference<>() {
			});
    }
	public RerankRAGAdvisor(RestClient restClient, VectorStore vectorStore) {
		this(restClient, vectorStore, SearchRequest.defaults(), DEFAULT_USER_TEXT_ADVISE);
	}
	public RerankRAGAdvisor(RestClient restClient, VectorStore vectorStore, SearchRequest searchRequest) {
		this(restClient, vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE);
	}
	public RerankRAGAdvisor(RestClient restClient, VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
		Assert.notNull(restClient, "The restClient must not be null!");
		Assert.notNull(vectorStore, "The vectorStore must not be null!");
		Assert.notNull(searchRequest, "The searchRequest must not be null!");
		Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");

		this.restClient = restClient;
		this.vectorStore = vectorStore;
		this.searchRequest = searchRequest;
		this.userTextAdvise = userTextAdvise;
	}
	@Override
	public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
		// 1. Advise the user text.
		String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;
		// Retrieve a wide candidate set (100 chunks) for the re-ranker to narrow down.
		var searchRequestToUse = SearchRequest.from(this.searchRequest)
			.withQuery(request.userText())
			.withTopK(100)
			.withFilterExpression(doGetFilterExpression(context));
		// 2. Search for similar documents in the vector store.
		List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
		// 3. Re-rank the candidates; fall back to an empty list if the response has no body.
		RerankList rerankList = rerankDocuments(request.userText(), documents).getBody();
		List<Rerank> rerankDocs = (rerankList != null && rerankList.data() != null) ? rerankList.data() : List.of();
		context.put(RETRIEVED_DOCUMENTS, rerankDocs);
		// 4. Create the context from the documents.
		String documentContext = rerankDocs.stream()
			.map(Rerank::document)
			.collect(Collectors.joining(System.lineSeparator()));
		// 5. Advise the user parameters.
		Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
		advisedUserParams.put("question_answer_context", documentContext);
		AdvisedRequest advisedRequest = AdvisedRequest.from(request)
			.withUserText(advisedUserText)
			.withUserParams(advisedUserParams)
			.build();
		return advisedRequest;
	}
	@Override
	public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
		ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response);
		chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
		return chatResponseBuilder.build();
	}
	@Override
	public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
		return fluxResponse.map(cr -> {
			ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
			chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
			return chatResponseBuilder.build();
		});
	}
	protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
		if (!context.containsKey(FILTER_EXPRESSION)
				|| !StringUtils.hasText(context.get(FILTER_EXPRESSION).toString())) {
			return this.searchRequest.getFilterExpression();
		}
		return new FilterExpressionTextParser().parse(context.get(FILTER_EXPRESSION).toString());
	}
}
