Skip to content

Commit 932fc87

Browse files
timosalmtzolov
authored andcommitted
Take userText parameters into account for QuestionAnswerAdvisor's similarity search
Resolves #1234
1 parent 560315c commit 932fc87

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
3434
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
3535
import org.springframework.ai.chat.model.ChatResponse;
36+
import org.springframework.ai.chat.prompt.PromptTemplate;
3637
import org.springframework.ai.document.Document;
3738
import org.springframework.ai.model.Content;
3839
import org.springframework.ai.vectorstore.SearchRequest;
@@ -47,6 +48,7 @@
4748
* user text.
4849
*
4950
* @author Christian Tzolov
51+
* @author Timo Salm
5052
* @since 1.0.0
5153
*/
5254
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
@@ -106,7 +108,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
106108
* @param vectorStore The vector store to use
107109
* @param searchRequest The search request defined using the portable filter
108110
* expression syntax
109-
* @param userTextAdvise the user text to append to the existing user prompt. The text
111+
* @param userTextAdvise The user text to append to the existing user prompt. The text
110112
* should contain a placeholder named "question_answer_context".
111113
*/
112114
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
@@ -119,9 +121,9 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
119121
* @param vectorStore The vector store to use
120122
* @param searchRequest The search request defined using the portable filter
121123
* expression syntax
122-
* @param userTextAdvise the user text to append to the existing user prompt. The text
124+
* @param userTextAdvise The user text to append to the existing user prompt. The text
123125
* should contain a placeholder named "question_answer_context".
124-
* @param protectFromBlocking if true the advisor will protect the execution from
126+
* @param protectFromBlocking If true the advisor will protect the execution from
125127
* blocking threads. If false the advisor will not protect the execution from blocking
126128
* threads. This is useful when the advisor is used in a non-blocking environment. It
127129
* is true by default.
@@ -137,13 +139,13 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
137139
* @param vectorStore The vector store to use
138140
* @param searchRequest The search request defined using the portable filter
139141
* expression syntax
140-
* @param userTextAdvise the user text to append to the existing user prompt. The text
142+
* @param userTextAdvise The user text to append to the existing user prompt. The text
141143
* should contain a placeholder named "question_answer_context".
142-
* @param protectFromBlocking if true the advisor will protect the execution from
144+
* @param protectFromBlocking If true the advisor will protect the execution from
143145
* blocking threads. If false the advisor will not protect the execution from blocking
144146
* threads. This is useful when the advisor is used in a non-blocking environment. It
145147
* is true by default.
146-
* @param order the order of the advisor.
148+
* @param order The order of the advisor.
147149
*/
148150
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
149151
boolean protectFromBlocking, int order) {
@@ -213,16 +215,17 @@ private AdvisedRequest before(AdvisedRequest request) {
213215
// 1. Advise the system text.
214216
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;
215217

218+
// 2. Search for similar documents in the vector store.
219+
String query = new PromptTemplate(request.userText(), request.userParams()).render();
216220
var searchRequestToUse = SearchRequest.from(this.searchRequest)
217-
.withQuery(request.userText())
221+
.withQuery(query)
218222
.withFilterExpression(doGetFilterExpression(context));
219223

220-
// 2. Search for similar documents in the vector store.
221224
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
222225

226+
// 3. Create the context from the documents.
223227
context.put(RETRIEVED_DOCUMENTS, documents);
224228

225-
// 3. Create the context from the documents.
226229
String documentContext = documents.stream()
227230
.map(Content::getContent)
228231
.collect(Collectors.joining(System.lineSeparator()));

spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.springframework.ai.chat.model.ChatResponse;
3939
import org.springframework.ai.chat.model.Generation;
4040
import org.springframework.ai.chat.prompt.Prompt;
41+
import org.springframework.ai.chat.prompt.PromptTemplate;
4142
import org.springframework.ai.document.Document;
4243
import org.springframework.ai.vectorstore.SearchRequest;
4344
import org.springframework.ai.vectorstore.VectorStore;
@@ -48,6 +49,7 @@
4849

4950
/**
5051
* @author Christian Tzolov
52+
* @author Timo Salm
5153
*/
5254
@ExtendWith(MockitoExtension.class)
5355
public class QuestionAnswerAdvisorTests {
@@ -178,7 +180,63 @@ public Duration getTokensReset() {
178180
assertThat(this.vectorSearchCaptor.getValue().getFilterExpression()).isEqualTo(new FilterExpressionBuilder().eq("type", "Spring").build());
179181
assertThat(this.vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d);
180182
assertThat(this.vectorSearchCaptor.getValue().getTopK()).isEqualTo(6);
183+
}
184+
185+
@Test
186+
public void qaAdvisorTakesUserTextParametersIntoAccountForSimilaritySearch() {
187+
given(this.chatModel.call(this.promptCaptor.capture()))
188+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))),
189+
ChatResponseMetadata.builder().build()));
190+
191+
given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture()))
192+
.willReturn(List.of(new Document("doc1"), new Document("doc2")));
181193

194+
var chatClient = ChatClient.builder(this.chatModel).build();
195+
var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.defaults());
182196

197+
var userTextTemplate = "Please answer my question {question}";
198+
// @formatter:off
199+
chatClient.prompt()
200+
.user(u -> u.text(userTextTemplate).param("question", "XYZ"))
201+
.advisors(qaAdvisor)
202+
.call()
203+
.chatResponse();
204+
//formatter:on
205+
206+
var expectedQuery = "Please answer my question XYZ";
207+
var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getContent();
208+
assertThat(userPrompt).doesNotContain(userTextTemplate);
209+
assertThat(userPrompt).contains(expectedQuery);
210+
assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery);
183211
}
212+
213+
@Test
214+
public void qaAdvisorTakesUserParameterizedUserMessagesIntoAccountForSimilaritySearch() {
215+
given(this.chatModel.call(this.promptCaptor.capture()))
216+
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your answer is ZXY"))),
217+
ChatResponseMetadata.builder().build()));
218+
219+
given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture()))
220+
.willReturn(List.of(new Document("doc1"), new Document("doc2")));
221+
222+
var chatClient = ChatClient.builder(this.chatModel).build();
223+
var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.defaults());
224+
225+
var userTextTemplate = "Please answer my question {question}";
226+
var userPromptTemplate = new PromptTemplate(userTextTemplate, Map.of("question", "XYZ"));
227+
var userMessage = userPromptTemplate.createMessage();
228+
// @formatter:off
229+
chatClient.prompt(new Prompt(userMessage))
230+
.advisors(qaAdvisor)
231+
.call()
232+
.chatResponse();
233+
//formatter:on
234+
235+
var expectedQuery = "Please answer my question XYZ";
236+
var userPrompt = this.promptCaptor.getValue().getInstructions().get(0).getContent();
237+
assertThat(userPrompt).doesNotContain(userTextTemplate);
238+
assertThat(userPrompt).contains(expectedQuery);
239+
assertThat(this.vectorSearchCaptor.getValue().getQuery()).isEqualTo(expectedQuery);
240+
}
241+
184242
}

0 commit comments

Comments
 (0)