From 49eee924bf07465094a8a1c134f836da288930fc Mon Sep 17 00:00:00 2001 From: Rodrigo Malara Date: Wed, 2 Jul 2025 11:05:05 -0300 Subject: [PATCH 1/3] [GH-3723] Vertex AI Gemini logprobs support Signed-off-by: Rodrigo Malara --- .../gemini/VertexAiGeminiChatModel.java | 28 +++++++++- .../gemini/VertexAiGeminiChatOptions.java | 51 +++++++++++++++++-- .../gemini/api/VertexAiGeminiApi.java | 16 ++++++ .../gemini/CreateGeminiRequestTests.java | 4 ++ .../gemini/VertexAiGeminiChatModelIT.java | 21 ++++++++ 5 files changed, 116 insertions(+), 4 deletions(-) create mode 100644 models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 01ab8b96c02..d1c1a1e327f 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -49,6 +49,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -580,8 +581,29 @@ protected List responseCandidateToGeneration(Candidate candidate) { int candidateIndex = candidate.getIndex(); FinishReason candidateFinishReason = candidate.getFinishReason(); + // Convert from VertexAI protobuf to VertexAiGeminiApi DTOs + List topCandidates = candidate.getLogprobsResult() + .getTopCandidatesList() + .stream() + .filter(topCandidate -> !topCandidate.getCandidatesList().isEmpty()) + .map(topCandidate -> new VertexAiGeminiApi.LogProbs.TopContent(topCandidate.getCandidatesList() + .stream() + .map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId())) + .toList())) + .toList(); + + // Convert from VertexAI protobuf to VertexAiGeminiApi DTOs + List chosenCandidates = candidate.getLogprobsResult() + .getChosenCandidatesList() + .stream() + .map(c -> new VertexAiGeminiApi.LogProbs.Content(c.getToken(), c.getLogProbability(), c.getTokenId())) + .toList(); + + VertexAiGeminiApi.LogProbs logprobs = new VertexAiGeminiApi.LogProbs(candidate.getAvgLogprobs(), topCandidates, + chosenCandidates); + Map messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason", - candidateFinishReason); + candidateFinishReason, "logprobs", logprobs); ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.builder() .finishReason(candidateFinishReason.name()) @@ -737,6 +759,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) { if (options.getPresencePenalty() != null) { generationConfigBuilder.setPresencePenalty(options.getPresencePenalty().floatValue()); } + if (options.getLogprobs() != null) { + generationConfigBuilder.setLogprobs(options.getLogprobs()); + } + generationConfigBuilder.setResponseLogprobs(options.getResponseLogprobs()); return generationConfigBuilder.build(); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 68ae24a92e2..06cfcb97393 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -64,6 +64,20 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("temperature") Double temperature; + /** + * Optional. Enable returning the log probabilities of the top candidate tokens at each generation step. + * The model's chosen token might not be the same as the top candidate token at each step. + * Specify the number of candidates to return by using an integer value in the range of 1-20. + * Should not be set unless responseLogprobs is set to true. + */ + private @JsonProperty("logprobs") Integer logprobs; + + /** + * Optional. If true, returns the log probabilities of the tokens that were chosen by the model at each step. + * By default, this parameter is set to false. + */ + private @JsonProperty("responseLogprobs") boolean responseLogprobs; + /** * Optional. If specified, nucleus sampling will be used. */ @@ -162,6 +176,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setSafetySettings(fromOptions.getSafetySettings()); options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); options.setToolContext(fromOptions.getToolContext()); + options.setLogprobs(fromOptions.getLogprobs()); + options.setResponseLogprobs(fromOptions.getResponseLogprobs()); return options; } @@ -183,6 +199,10 @@ public void setTemperature(Double temperature) { this.temperature = temperature; } + public void setResponseLogprobs(boolean responseLogprobs) { + this.responseLogprobs = responseLogprobs; + } + @Override public Double getTopP() { return this.topP; @@ -326,6 +346,18 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + public Integer getLogprobs() { + return logprobs; + } + + public void setLogprobs(Integer logprobs) { + this.logprobs = logprobs; + } + + public boolean getResponseLogprobs() { + return responseLogprobs; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -346,7 +378,8 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) - && Objects.equals(this.toolContext, that.toolContext); + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.logprobs, that.logprobs) + && Objects.equals(this.responseLogprobs, that.responseLogprobs); } @Override @@ -354,7 +387,8 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, - this.internalToolExecutionEnabled, this.toolContext); + this.internalToolExecutionEnabled, this.toolContext, this.toolContext, this.logprobs, + this.responseLogprobs); } @Override @@ -365,7 +399,8 @@ public String toString() { + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval - + ", safetySettings=" + this.safetySettings + '}'; + + ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs=" + + this.responseLogprobs + '}'; } @Override @@ -488,6 +523,16 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder logprobs(Integer logprobs) { + this.options.setLogprobs(logprobs); + return this; + } + + public Builder responseLogprobs(Boolean responseLogprobs) { + this.options.setResponseLogprobs(responseLogprobs); + return this; + } + public VertexAiGeminiChatOptions build() { return this.options; } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java new file mode 100644 index 00000000000..5ce9f93265f --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java @@ -0,0 +1,16 @@ +package org.springframework.ai.vertexai.gemini.api; + +import java.util.List; + +public class VertexAiGeminiApi { + + public record LogProbs(Double avgLogprobs, List topCandidates, + List chosenCandidates) { + public record Content(String token, Float logprob, Integer id) { + } + + public record TopContent(List candidates) { + } + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index bcb32a748fa..a0ce5d23305 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -262,6 +262,8 @@ public void createRequestWithGenerationConfigOptions() { .stopSequences(List.of("stop1", "stop2")) .candidateCount(1) .responseMimeType("application/json") + .responseLogprobs(true) + .logprobs(2) .build()) .build(); @@ -280,6 +282,8 @@ public void createRequestWithGenerationConfigOptions() { assertThat(request.model().getGenerationConfig().getStopSequences(0)).isEqualTo("stop1"); assertThat(request.model().getGenerationConfig().getStopSequences(1)).isEqualTo("stop2"); assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json"); + assertThat(request.model().getGenerationConfig().getLogprobs()).isEqualTo(2); + assertThat(request.model().getGenerationConfig().getResponseLogprobs()).isEqualTo(true); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java index 2c37f0608a6..8bec557cdd5 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java @@ -47,6 +47,7 @@ import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel; +import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -226,6 +227,26 @@ void textStream() { assertThat(generationTextFromStream).isNotEmpty(); } + @Test + void logprobs() { + VertexAiGeminiChatOptions chatOptions = VertexAiGeminiChatOptions.builder() + .logprobs(1) + .responseLogprobs(true) + .build(); + + var logprobs = (VertexAiGeminiApi.LogProbs) this.chatModel + .call(new Prompt("Explain Bulgaria? Answer in 10 paragraphs.", chatOptions)) + .getResult() + .getOutput() + .getMetadata() + .get("logprobs"); + + assertThat(logprobs).isNotNull(); + assertThat(logprobs.avgLogprobs()).isNotZero(); + assertThat(logprobs.topCandidates()).isNotEmpty(); + assertThat(logprobs.chosenCandidates()).isNotEmpty(); + } + @Test void beanStreamOutputConverterRecords() { From 4965fc99e22b7ae865c0949956dba1dc5a3ed70f Mon Sep 17 00:00:00 2001 From: Rodrigo Malara Date: Wed, 2 Jul 2025 11:16:28 -0300 Subject: [PATCH 2/3] [GH-3723] Vertex AI Gemini logprobs support Closes GH-3723 Signed-off-by: Rodrigo Malara --- .../ai/vertexai/gemini/VertexAiGeminiChatModel.java | 1 - 1 file changed, 1 deletion(-) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index d1c1a1e327f..8e494deb210 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -592,7 +592,6 @@ protected List responseCandidateToGeneration(Candidate candidate) { .toList())) .toList(); - // Convert from VertexAI protobuf to VertexAiGeminiApi DTOs List chosenCandidates = candidate.getLogprobsResult() .getChosenCandidatesList() .stream() From 13f3a6c1ffda2fa7595ef0383b65a8556c38eb3e Mon Sep 17 00:00:00 2001 From: Rodrigo Malara Date: Wed, 2 Jul 2025 12:59:28 -0300 Subject: [PATCH 3/3] [GH-3723] addressing review comments Closes GH-3723 Signed-off-by: Rodrigo Malara --- .../vertexai/gemini/VertexAiGeminiChatModel.java | 2 +- .../gemini/VertexAiGeminiChatOptions.java | 3 +-- .../ai/vertexai/gemini/api/VertexAiGeminiApi.java | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 8e494deb210..ebb1c5c40bd 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -49,7 +49,6 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -85,6 +84,7 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 06cfcb97393..69f32c8440c 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -387,8 +387,7 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, - this.internalToolExecutionEnabled, this.toolContext, this.toolContext, this.logprobs, - this.responseLogprobs); + this.internalToolExecutionEnabled, this.toolContext, this.logprobs, this.responseLogprobs); } @Override diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java index 5ce9f93265f..7bb3e1b4da9 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/api/VertexAiGeminiApi.java @@ -1,3 +1,18 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.springframework.ai.vertexai.gemini.api; import java.util.List;