Skip to content

Commit 49eee92

Browse files
committed
[GH-3723] Vertex AI Gemini logprobs support
Signed-off-by: Rodrigo Malara <[email protected]>
1 parent af07517 commit 49eee92

File tree

5 files changed

+116
-4
lines changed

5 files changed

+116
-4
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
5050
import org.slf4j.Logger;
5151
import org.slf4j.LoggerFactory;
52+
import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi;
5253
import reactor.core.publisher.Flux;
5354
import reactor.core.scheduler.Schedulers;
5455

@@ -580,8 +581,29 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
580581
int candidateIndex = candidate.getIndex();
581582
FinishReason candidateFinishReason = candidate.getFinishReason();
582583

584+
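// Convert the per-step top candidates of the log-probability result from the
// Vertex AI protobuf into VertexAiGeminiApi DTOs, skipping steps with no candidates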
585+
List<VertexAiGeminiApi.LogProbs.TopContent> topCandidates = candidate.getLogprobsResult()
586+
.getTopCandidatesList()
587+
.stream()
588+
.filter(topCandidate -> !topCandidate.getCandidatesList().isEmpty())
589+
.map(topCandidate -> new VertexAiGeminiApi.LogProbs.TopContent(topCandidate.getCandidatesList()
590+
.stream()
591+
.map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId()))
592+
.toList()))
593+
.toList();
594+
595+
// Convert the chosen candidates of the log-probability result from the
// Vertex AI protobuf into VertexAiGeminiApi DTOs
596+
List<VertexAiGeminiApi.LogProbs.Content> chosenCandidates = candidate.getLogprobsResult()
597+
.getChosenCandidatesList()
598+
.stream()
599+
.map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId()))
600+
.toList();
601+
602+
VertexAiGeminiApi.LogProbs logprobs = new VertexAiGeminiApi.LogProbs(candidate.getAvgLogprobs(), topCandidates,
603+
chosenCandidates);
604+
583605
Map<String, Object> messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason",
584-
candidateFinishReason);
606+
candidateFinishReason, "logprobs", logprobs);
585607

586608
ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.builder()
587609
.finishReason(candidateFinishReason.name())
@@ -737,6 +759,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
737759
if (options.getPresencePenalty() != null) {
738760
generationConfigBuilder.setPresencePenalty(options.getPresencePenalty().floatValue());
739761
}
762+
if (options.getLogprobs() != null) {
763+
generationConfigBuilder.setLogprobs(options.getLogprobs());
764+
}
765+
generationConfigBuilder.setResponseLogprobs(options.getResponseLogprobs());
740766

741767
return generationConfigBuilder.build();
742768
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
6464
*/
6565
private @JsonProperty("temperature") Double temperature;
6666

67+
/**
68+
* Optional. Enable returning the log probabilities of the top candidate tokens at each generation step.
69+
* The model's chosen token might not be the same as the top candidate token at each step.
70+
* Specify the number of candidates to return by using an integer value in the range of 1-20.
71+
* Should not be set unless responseLogprobs is set to true.
72+
*/
73+
private @JsonProperty("logprobs") Integer logprobs;
74+
75+
/**
76+
* Optional. If true, returns the log probabilities of the tokens that were chosen by the model at each step.
77+
* By default, this parameter is set to false.
78+
*/
79+
private @JsonProperty("responseLogprobs") boolean responseLogprobs;
80+
6781
/**
6882
* Optional. If specified, nucleus sampling will be used.
6983
*/
@@ -162,6 +176,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
162176
options.setSafetySettings(fromOptions.getSafetySettings());
163177
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
164178
options.setToolContext(fromOptions.getToolContext());
179+
options.setLogprobs(fromOptions.getLogprobs());
180+
options.setResponseLogprobs(fromOptions.getResponseLogprobs());
165181
return options;
166182
}
167183

@@ -183,6 +199,10 @@ public void setTemperature(Double temperature) {
183199
this.temperature = temperature;
184200
}
185201

202+
public void setResponseLogprobs(boolean responseLogprobs) {
203+
this.responseLogprobs = responseLogprobs;
204+
}
205+
186206
@Override
187207
public Double getTopP() {
188208
return this.topP;
@@ -326,6 +346,18 @@ public void setToolContext(Map<String, Object> toolContext) {
326346
this.toolContext = toolContext;
327347
}
328348

349+
public Integer getLogprobs() {
350+
return logprobs;
351+
}
352+
353+
public void setLogprobs(Integer logprobs) {
354+
this.logprobs = logprobs;
355+
}
356+
357+
public boolean getResponseLogprobs() {
358+
return responseLogprobs;
359+
}
360+
329361
@Override
330362
public boolean equals(Object o) {
331363
if (this == o) {
@@ -346,15 +378,17 @@ public boolean equals(Object o) {
346378
&& Objects.equals(this.toolNames, that.toolNames)
347379
&& Objects.equals(this.safetySettings, that.safetySettings)
348380
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
349-
&& Objects.equals(this.toolContext, that.toolContext);
381+
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.logprobs, that.logprobs)
382+
&& Objects.equals(this.responseLogprobs, that.responseLogprobs);
350383
}
351384

352385
@Override
353386
public int hashCode() {
354387
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
355388
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
356389
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
357-
this.internalToolExecutionEnabled, this.toolContext);
390+
this.internalToolExecutionEnabled, this.toolContext, this.toolContext, this.logprobs,
391+
this.responseLogprobs);
358392
}
359393

360394
@Override
@@ -365,7 +399,8 @@ public String toString() {
365399
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
366400
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
367401
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
368-
+ ", safetySettings=" + this.safetySettings + '}';
402+
+ ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs="
403+
+ this.responseLogprobs + '}';
369404
}
370405

371406
@Override
@@ -488,6 +523,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
488523
return this;
489524
}
490525

526+
public Builder logprobs(Integer logprobs) {
527+
this.options.setLogprobs(logprobs);
528+
return this;
529+
}
530+
531+
public Builder responseLogprobs(Boolean responseLogprobs) {
532+
this.options.setResponseLogprobs(responseLogprobs);
533+
return this;
534+
}
535+
491536
public VertexAiGeminiChatOptions build() {
492537
return this.options;
493538
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.springframework.ai.vertexai.gemini.api;
2+
3+
import java.util.List;
4+
5+
/**
 * Holder for the data-transfer records that expose Vertex AI Gemini log-probability
 * information.
 */
public class VertexAiGeminiApi {

	/**
	 * Log-probability information attached to a single response candidate.
	 *
	 * @param avgLogprobs the average log probability reported for the candidate
	 * @param topCandidates the top candidate tokens, one {@link TopContent} entry per
	 * element
	 * @param chosenCandidates the tokens that were chosen
	 */
	public record LogProbs(Double avgLogprobs, List<TopContent> topCandidates, List<Content> chosenCandidates) {

		/**
		 * A group of candidate tokens.
		 *
		 * @param candidates the candidate tokens in this group
		 */
		public record TopContent(List<Content> candidates) {
		}

		/**
		 * A single token together with its log probability and token id.
		 *
		 * @param token the token text
		 * @param logprob the log probability of the token
		 * @param id the token id
		 */
		public record Content(String token, Float logprob, Integer id) {
		}

	}

}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ public void createRequestWithGenerationConfigOptions() {
262262
.stopSequences(List.of("stop1", "stop2"))
263263
.candidateCount(1)
264264
.responseMimeType("application/json")
265+
.responseLogprobs(true)
266+
.logprobs(2)
265267
.build())
266268
.build();
267269

@@ -280,6 +282,8 @@ public void createRequestWithGenerationConfigOptions() {
280282
assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1");
281283
assertThat(request.model().getGenerationConfig().getStopSequences(1)).isEqualTo("stop2");
282284
assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json");
285+
assertThat(request.model().getGenerationConfig().getLogprobs()).isEqualTo(2);
286+
assertThat(request.model().getGenerationConfig().getResponseLogprobs()).isEqualTo(true);
283287
}
284288

285289
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.springframework.ai.model.tool.ToolCallingManager;
4848
import org.springframework.ai.tool.annotation.Tool;
4949
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
50+
import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi;
5051
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
5152
import org.springframework.beans.factory.annotation.Autowired;
5253
import org.springframework.beans.factory.annotation.Value;
@@ -226,6 +227,26 @@ void textStream() {
226227
assertThat(generationTextFromStream).isNotEmpty();
227228
}
228229

230+
@Test
231+
void logprobs() {
232+
VertexAiGeminiChatOptions chatOptions = VertexAiGeminiChatOptions.builder()
233+
.logprobs(1)
234+
.responseLogprobs(true)
235+
.build();
236+
237+
var logprobs = (VertexAiGeminiApi.LogProbs) this.chatModel
238+
.call(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.", chatOptions))
239+
.getResult()
240+
.getOutput()
241+
.getMetadata()
242+
.get("logprobs");
243+
244+
assertThat(logprobs).isNotNull();
245+
assertThat(logprobs.avgLogprobs()).isNotZero();
246+
assertThat(logprobs.topCandidates()).isNotEmpty();
247+
assertThat(logprobs.chosenCandidates()).isNotEmpty();
248+
}
249+
229250
@Test
230251
void beanStreamOutputConverterRecords() {
231252

0 commit comments

Comments
 (0)