diff --git a/README.md b/README.md index dba0e71d..d1390ca6 100644 --- a/README.md +++ b/README.md @@ -533,6 +533,200 @@ If you use `@JsonProperty(required = false)`, the `false` value will be ignored. must mark all properties as _required_, so the schema generated from your Java classes will respect that restriction and ignore any annotation that would violate it. +## Function calling with JSON schemas + +OpenAI [Function Calling](https://platform.openai.com/docs/guides/function-calling?api-mode=chat) +lets you integrate external functions directly into the language model's responses. Instead of +producing plain text, the model can output instructions (with parameters) for calling a function +when appropriate. You define a [JSON schema](https://json-schema.org/overview/what-is-jsonschema) +for functions, and the model uses it to decide when and how to trigger these calls, enabling more +interactive, data-driven applications. + +A JSON schema describing a function's parameters can be defined via the API by building a +[`ChatCompletionTool`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionTool.kt) +containing a +[`FunctionDefinition`](openai-java-core/src/main/kotlin/com/openai/models/FunctionDefinition.kt) +and then using `addTool` to set it on the input parameters. The response from the AI model may then +contain requests to call your functions, detailing the functions' names and their parameter values +as JSON data that conforms to the JSON schema from the function definition. You can then parse the +parameter values from this JSON, invoke your functions, and pass your functions' results back to the +AI model. A full, working example of _Function Calling_ using the low-level API can be seen in +[`FunctionCallingRawExample`](openai-java-example/src/main/java/com/openai/example/FunctionCallingRawExample.java). + +However, for greater convenience, the SDK can derive a function and its parameters automatically +from the structure of an arbitrary Java class: the class's name provides the function name, and the +class's fields define the function's parameters. When the AI model responds with the parameter +values in JSON form, you can then easily convert that JSON to an instance of your Java class and +use the parameter values to invoke your custom function. A full, working example of the use of +_Function Calling_ with Java classes to define function parameters can be seen in +[`FunctionCallingExample`](openai-java-example/src/main/java/com/openai/example/FunctionCallingExample.java). + +Like for [Structured Outputs](#structured-outputs-with-json-schemas), Java classes can contain +fields declared to be instances of other classes and can use collections. Optionally, annotations +can be used to set the descriptions of the function (class) and its parameters (fields) to assist +the AI model in understanding the purpose of the function and the possible values of its parameters. + +```java +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +@JsonClassDescription("Gets the quality of the given SDK.") +static class GetSdkQuality { + @JsonPropertyDescription("The name of the SDK.") + public String name; + + public SdkQuality execute() { + return new SdkQuality( + name, name.contains("OpenAI") ? "It's robust and polished!" : "*shrug*"); + } +} + +static class SdkQuality { + public String quality; + + public SdkQuality(String name, String evaluation) { + quality = name + ": " + evaluation; + } +} + +@JsonClassDescription("Gets the review score (out of 10) for the named SDK.") +static class GetSdkScore { + public String name; + + public int execute() { + return name.contains("OpenAI") ? 10 : 3; + } +} +``` + +When your functions are defined, add them to the input parameters using `addTool(Class)` and then +call them if requested to do so in the AI model's response. `Function.argments(Class)` can be +used to parse a function's parameters in JSON form to an instance of your function-defining class. +The fields of that instance will be set to the values of the parameters to the function call. + +After calling the function, use `ChatCompletionToolMessageParam.Builder.contentAsJson(Object)` to +pass the function's result back to the AI model. The method will convert the result to JSON form +for consumption by the model. The `Object` can be any object, including simple `String` instances +and boxed primitive types. + +```java +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.*; +import java.util.Collection; + +OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + +ChatCompletionCreateParams.Builder createParamsBuilder = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_3_5_TURBO) + .maxCompletionTokens(2048) + .addTool(GetSdkQuality.class) + .addTool(GetSdkScore.class) + .addUserMessage("How good are the following SDKs and what do reviewers say: " + + "OpenAI Java SDK, Unknown Company SDK."); + +client.chat().completions().create(createParamsBuilder.build()).choices().stream() + .map(ChatCompletion.Choice::message) + // Add each assistant message onto the builder so that we keep track of the + // conversation for asking a follow-up question later. + .peek(createParamsBuilder::addMessage) + .flatMap(message -> { + message.content().ifPresent(System.out::println); + return message.toolCalls().stream().flatMap(Collection::stream); + }) + .forEach(toolCall -> { + Object result = callFunction(toolCall.function()); + // Add the tool call result to the conversation. + createParamsBuilder.addMessage(ChatCompletionToolMessageParam.builder() + .toolCallId(toolCall.id()) + .contentAsJson(result) + .build()); + }); + +// Ask a follow-up question about the function call result. +createParamsBuilder.addUserMessage("Why do you say that?"); +client.chat().completions().create(createParamsBuilder.build()).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .forEach(System.out::println); + +static Object callFunction(ChatCompletionMessageToolCall.Function function) { + switch (function.name()) { + case "GetSdkQuality": + return function.arguments(GetSdkQuality.class).execute(); + case "GetSdkScore": + return function.arguments(GetSdkScore.class).execute(); + default: + throw new IllegalArgumentException("Unknown function: " + function.name()); + } +} +``` + +In the code above, an `execute()` method encapsulates each function's logic. However, there is no +requirement to follow that pattern. You are free to implement your function's logic in any way that +best suits your use case. The pattern above is only intended to _suggest_ that a suitable pattern +may make the process of function calling simpler to understand and implement. + +### Usage with the Responses API + +_Function Calling_ is also supported for the Responses API. The usage is the same as described +except where the Responses API differs slightly from the Chat Completions API. Pass the top-level +class to `addTool(Class)` when building the parameters. In the response, look for +[`RepoonseOutputItem`](openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseOutputItem.kt) +instances that are function calls. Parse the parameters to each function call to an instance of the +class using +[`ResponseFunctionToolCall.arguments(Class)`](openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseFunctionToolCall.kt). +Finally, pass the result of each call back to the model. + +For a full example of the usage of _Function Calling_ with the Responses API using the low-level +API to define and parse function parameters, see +[`ResponsesFunctionCallingRawExample`](openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingRawExample.java). + +For a full example of the usage of _Function Calling_ with the Responses API using Java classes to +define and parse function parameters, see +[`ResponsesFunctionCallingExample`](openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingExample.java). + +### Local function JSON schema validation + +Like for _Structured Outputs_, you can perform local validation to check that the JSON schema +derived from your function class respects the restrictions imposed by OpenAI on such schemas. Local +validation is enabled by default, but it can be disabled by adding `JsonSchemaLocalValidation.NO` to +the call to `addTool`. + +```java +ChatCompletionCreateParams.Builder createParamsBuilder = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_3_5_TURBO) + .maxCompletionTokens(2048) + .addTool(GetSdkQuality.class, JsonSchemaLocalValidation.NO) + .addTool(GetSdkScore.class, JsonSchemaLocalValidation.NO) + .addUserMessage("How good are the following SDKs and what do reviewers say: " + + "OpenAI Java SDK, Unknown Company SDK."); +``` + +See [Local JSON schema validation](#local-json-schema-validation) for more details on local schema +validation and under what circumstances you might want to disable it. + +### Annotating function classes + +You can use annotations to add further information about functions to the JSON schemas that are +derived from your function classes, or to exclude individual fields from the parameters to the +function. Details from annotations captured in the JSON schema may be used by the AI model to +improve its response. The SDK supports the use of +[Jackson Databind](https://github.com/FasterXML/jackson-databind) annotations. + +- Use `@JsonClassDescription` to add a description to a function class detailing when and how to use + that function. +- Use `@JsonTypeName` to set the function name to something other than the simple name of the class, + which is used by default. +- Use `@JsonPropertyDescription` to add a detailed description to function parameter (a field of + a function class). +- Use `@JsonIgnore` to omit a field of a class from the generated JSON schema for a function's + parameters. + +OpenAI provides some +[Best practices for defining functions](https://platform.openai.com/docs/guides/function-calling#best-practices-for-defining-functions) +that may help you to understand how to use the above annotations effectively for your functions. + ## File uploads The SDK defines methods that accept files. diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt index 85c20b43..f3d74364 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt @@ -392,8 +392,7 @@ internal class JsonSchemaValidator private constructor() { // The schema must declare that additional properties are not allowed. For this check, it // does not matter if there are no "properties" in the schema. verify( - schema.get(ADDITIONAL_PROPS) != null && - schema.get(ADDITIONAL_PROPS).asBoolean() == false, + schema.get(ADDITIONAL_PROPS) != null && !schema.get(ADDITIONAL_PROPS).asBoolean(), path, ) { "'$ADDITIONAL_PROPS' field is missing or is not set to 'false'." @@ -401,27 +400,37 @@ internal class JsonSchemaValidator private constructor() { val properties = schema.get(PROPS) - // The "properties" field may be missing (there may be no properties to declare), but if it - // is present, it must be a non-empty object, or validation cannot continue. - // TODO: Decide if a missing or empty "properties" field is OK or not. + // An object schema _must_ have a `"properties"` field, and it must contain at least one + // property. The AI model will report an error relating to a missing or empty `"required"` + // array if the "properties" field is missing or empty (and therefore the `"required"` array + // will also be missing or empty). This condition can arise if a `Map` is used as the field + // type: it will cause the generation of an object schema with no defined properties. If not + // present or empty, validation cannot continue. verify( - properties == null || (properties.isObject && !properties.isEmpty), + properties != null && properties.isObject && !properties.isEmpty, path, - { "'$PROPS' field is not a non-empty object." }, + { "'$PROPS' field is missing, empty or not an object." }, ) { return } - if (properties != null) { // Must be an object. - // If a "properties" field is present, there must also be a "required" field. All - // properties must be named in the list of required properties. - validatePropertiesRequired( - properties.fieldNames().asSequence().toSet(), - schema.get(REQUIRED), - "$path/$REQUIRED", - ) - validateProperties(properties, "$path/$PROPS", depth) + // Similarly, insist that the `"required"` array is present or stop validation. + val required = schema.get(REQUIRED) + + verify( + required != null && required.isArray && !required.isEmpty, + path, + { "'$REQUIRED' field is missing, empty or not an array." }, + ) { + return } + + validatePropertiesRequired( + properties.fieldNames().asSequence().toSet(), + required, + "$path/$REQUIRED", + ) + validateProperties(properties, "$path/$PROPS", depth) } /** @@ -554,10 +563,10 @@ internal class JsonSchemaValidator private constructor() { */ private fun validatePropertiesRequired( propertyNames: Collection, - required: JsonNode?, + required: JsonNode, path: String, ) { - val requiredNames = required?.map { it.asText() }?.toSet() ?: emptySet() + val requiredNames = required.map { it.asText() }.toSet() propertyNames.forEach { propertyName -> verify(propertyName in requiredNames, path) { diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index df48c1af..611e6b07 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -1,7 +1,9 @@ package com.openai.core +import com.fasterxml.jackson.annotation.JsonTypeName import com.fasterxml.jackson.databind.JsonNode import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.databind.node.ObjectNode import com.fasterxml.jackson.datatype.jdk8.Jdk8Module import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule import com.fasterxml.jackson.module.kotlin.kotlinModule @@ -11,7 +13,10 @@ import com.github.victools.jsonschema.generator.SchemaGenerator import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder import com.github.victools.jsonschema.module.jackson.JacksonModule import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.FunctionDefinition import com.openai.models.ResponseFormatJsonSchema +import com.openai.models.chat.completions.ChatCompletionTool +import com.openai.models.responses.FunctionTool import com.openai.models.responses.ResponseFormatTextJsonSchemaConfig import com.openai.models.responses.ResponseTextConfig @@ -30,15 +35,19 @@ private val MAPPER = * class. */ @JvmSynthetic -internal fun responseFormatFromClass( - type: Class, +internal fun responseFormatFromClass( + type: Class<*>, localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, ): ResponseFormatJsonSchema = ResponseFormatJsonSchema.builder() .jsonSchema( ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.fromJsonNode(extractAndValidateSchema(type, localValidation))) + .schema( + JsonValue.fromJsonNode( + validateSchema(extractSchema(type), type, localValidation) + ) + ) // Ensure the model's output strictly adheres to this JSON schema. This is the // essential "ON switch" for Structured Outputs. .strict(true) @@ -46,21 +55,29 @@ internal fun responseFormatFromClass( ) .build() -private fun extractAndValidateSchema( - type: Class, +/** + * Validates the given JSON schema with respect to OpenAI's JSON schema restrictions. + * + * @param schema The JSON schema to be validated. + * @param sourceType The class from which the JSON schema was derived. This is only used in error + * messages. + * @param localValidation Set to [JsonSchemaLocalValidation.YES] to perform the validation. Other + * values will cause validation to be skipped. + */ +@JvmSynthetic +internal fun validateSchema( + schema: ObjectNode, + sourceType: Class<*>, localValidation: JsonSchemaLocalValidation, -): JsonNode { - val schema = extractSchema(type) - +): ObjectNode { if (localValidation == JsonSchemaLocalValidation.YES) { val validator = JsonSchemaValidator.create().validate(schema) require(validator.isValid()) { - "Local validation failed for JSON schema derived from $type:\n" + + "Local validation failed for JSON schema derived from $sourceType:\n" + validator.errors().joinToString("\n") { " - $it" } } } - return schema } @@ -69,15 +86,19 @@ private fun extractAndValidateSchema( * arbitrary Java class. */ @JvmSynthetic -internal fun textConfigFromClass( - type: Class, +internal fun textConfigFromClass( + type: Class<*>, localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, ): ResponseTextConfig = ResponseTextConfig.builder() .format( ResponseFormatTextJsonSchemaConfig.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.fromJsonNode(extractAndValidateSchema(type, localValidation))) + .schema( + JsonValue.fromJsonNode( + validateSchema(extractSchema(type), type, localValidation) + ) + ) // Ensure the model's output strictly adheres to this JSON schema. This is the // essential "ON switch" for Structured Outputs. .strict(true) @@ -85,6 +106,82 @@ internal fun textConfigFromClass( ) .build() +// "internal" instead of "private" for testing purposes. +internal data class FunctionInfo( + val name: String, + val description: String?, + val schema: ObjectNode, +) + +@JvmSynthetic +// "internal" instead of "private" for testing purposes. +internal fun extractFunctionInfo( + parametersType: Class<*>, + localValidation: JsonSchemaLocalValidation, +): FunctionInfo { + val schema = extractSchema(parametersType) + + validateSchema(schema, parametersType, localValidation) + + // The JSON schema generator ignores the `@JsonTypeName` annotation, so it never sets the "name" + // field at the root of the schema. Respect that annotation here and use it to set the name + // (outside the schema). Fall back to using the simple name of the class. + val name = + parametersType.getAnnotation(JsonTypeName::class.java)?.value ?: parametersType.simpleName + + // The JSON schema generator will copy the `@JsonClassDescription` into the schema. If present, + // remove it from the schema so it can be set on the function definition/tool. + val descriptionNode: JsonNode? = schema.remove("description") + val description: String? = descriptionNode?.textValue() + + return FunctionInfo(name, description, schema) +} + +/** + * Creates a Chat Completions API tool defining a function whose input parameters are derived from + * the fields of a class. + */ +@JvmSynthetic +internal fun functionToolFromClass( + parametersType: Class<*>, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, +): ChatCompletionTool { + val functionInfo = extractFunctionInfo(parametersType, localValidation) + + return ChatCompletionTool.builder() + .function( + FunctionDefinition.builder() + .name(functionInfo.name) + .apply { functionInfo.description?.let(::description) } + .parameters(JsonValue.fromJsonNode(functionInfo.schema)) + // OpenAI: "Setting strict to true will ensure function calls reliably adhere to the + // function schema, instead of being best effort. We recommend always enabling + // strict mode." + .strict(true) + .build() + ) + .build() +} + +/** + * Creates a Responses API function tool defining a function whose input parameters are derived from + * the fields of a class. + */ +@JvmSynthetic +internal fun responseFunctionToolFromClass( + parametersType: Class<*>, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, +): FunctionTool { + val functionInfo = extractFunctionInfo(parametersType, localValidation) + + return FunctionTool.builder() + .name(functionInfo.name) + .apply { functionInfo.description?.let(::description) } + .parameters(JsonValue.fromJsonNode(functionInfo.schema)) + .strict(true) + .build() +} + /** * Derives a JSON schema from the structure of an arbitrary Java class. * @@ -93,7 +190,7 @@ internal fun textConfigFromClass( * thrown and any recorded validation errors can be inspected at leisure by the tests. */ @JvmSynthetic -internal fun extractSchema(type: Class): JsonNode { +internal fun extractSchema(type: Class<*>): ObjectNode { val configBuilder = SchemaGeneratorConfigBuilder( com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, @@ -119,6 +216,9 @@ internal fun extractSchema(type: Class): JsonNode { /** * Creates an instance of a Java class using data from a JSON. The JSON data should conform to the * JSON schema previously extracted from the Java class. + * + * @throws OpenAIInvalidDataException If the JSON data cannot be parsed to an instance of the + * [responseType] class. */ @JvmSynthetic internal fun responseTypeFromJson(json: String, responseType: Class): T = @@ -130,3 +230,10 @@ internal fun responseTypeFromJson(json: String, responseType: Class): T = // sensitive data are not exposed in logs. throw OpenAIInvalidDataException("Error parsing JSON: $json", e) } + +/** + * Converts any object into a JSON-formatted string. For `Object` types (other than strings and + * boxed primitives) a JSON object is created with its fields and values set from the fields of the + * object. + */ +@JvmSynthetic internal fun toJsonString(obj: Any): String = MAPPER.writeValueAsString(obj) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index 74065589..c0d0793b 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -25,6 +25,7 @@ import com.openai.core.Params import com.openai.core.allMaxBy import com.openai.core.checkKnown import com.openai.core.checkRequired +import com.openai.core.functionToolFromClass import com.openai.core.getOrThrow import com.openai.core.http.Headers import com.openai.core.http.QueryParams @@ -1536,6 +1537,21 @@ private constructor( */ fun addTool(tool: ChatCompletionTool) = apply { body.addTool(tool) } + /** + * Adds a single [ChatCompletionTool] to [tools] where the JSON schema describing the + * function parameters is derived from the fields of a given class. Local validation of that + * JSON schema can be performed to check if the schema is likely to pass remote validation + * by the AI model. By default, local validation is enabled; disable it by setting + * [localValidation] to [JsonSchemaLocalValidation.NO]. + * + * @see addTool + */ + @JvmOverloads + fun addTool( + functionParametersType: Class<*>, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { addTool(functionToolFromClass(functionParametersType, localValidation)) } + /** * An integer between 0 and 20 specifying the number of most likely tokens to return at each * token position, each with an associated log probability. `logprobs` must be set to `true` diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionMessageToolCall.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionMessageToolCall.kt index d240a9d9..bf089cdc 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionMessageToolCall.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionMessageToolCall.kt @@ -11,6 +11,7 @@ import com.openai.core.JsonField import com.openai.core.JsonMissing import com.openai.core.JsonValue import com.openai.core.checkRequired +import com.openai.core.responseTypeFromJson import com.openai.errors.OpenAIInvalidDataException import java.util.Collections import java.util.Objects @@ -258,6 +259,22 @@ private constructor( */ fun arguments(): String = arguments.getRequired("arguments") + /** + * Gets the arguments to the function call, converting the values from the model in JSON + * format to an instance of a class that holds those values. The class must previously have + * been used to define the JSON schema for the function definition's parameters, so that the + * JSON corresponds to structure of the given class. + * + * @throws OpenAIInvalidDataException If the JSON data is missing, `null`, or cannot be + * parsed to an instance of the [functionParametersType] class. This might occur if the + * class is not the same as the class that was originally used to define the arguments, or + * if the data from the AI model is invalid or incomplete (e.g., truncated). + * @see ChatCompletionCreateParams.Builder.addTool + * @see arguments + */ + fun arguments(functionParametersType: Class): T = + responseTypeFromJson(arguments(), functionParametersType) + /** * The name of the function to call. * diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionToolMessageParam.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionToolMessageParam.kt index 01df56d8..2e01b766 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionToolMessageParam.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionToolMessageParam.kt @@ -22,6 +22,7 @@ import com.openai.core.JsonValue import com.openai.core.allMaxBy import com.openai.core.checkRequired import com.openai.core.getOrThrow +import com.openai.core.toJsonString import com.openai.errors.OpenAIInvalidDataException import java.util.Collections import java.util.Objects @@ -146,6 +147,14 @@ private constructor( /** Alias for calling [content] with `Content.ofText(text)`. */ fun content(text: String) = content(Content.ofText(text)) + /** + * Sets the content to text representing the JSON serialized form of a given object. This is + * useful when passing data that is the result of a function call. + * + * @see content + */ + fun contentAsJson(functionResult: Any) = content(toJsonString(functionResult)) + /** * Alias for calling [content] with `Content.ofArrayOfContentParts(arrayOfContentParts)`. */ diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt index 73741b58..8c67696f 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -532,6 +532,13 @@ internal constructor( /** @see ChatCompletionCreateParams.Builder.addTool */ fun addTool(tool: ChatCompletionTool) = apply { paramsBuilder.addTool(tool) } + /** @see ChatCompletionCreateParams.Builder.addTool */ + @JvmOverloads + fun addTool( + functionParametersType: Class<*>, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { paramsBuilder.addTool(functionParametersType, localValidation) } + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ fun topLogprobs(topLogprobs: Long?) = apply { paramsBuilder.topLogprobs(topLogprobs) } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt index bd83a710..3ab027db 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseCreateParams.kt @@ -28,6 +28,7 @@ import com.openai.core.checkRequired import com.openai.core.getOrThrow import com.openai.core.http.Headers import com.openai.core.http.QueryParams +import com.openai.core.responseFunctionToolFromClass import com.openai.core.toImmutable import com.openai.errors.OpenAIInvalidDataException import com.openai.models.ChatModel @@ -911,6 +912,23 @@ private constructor( /** Alias for calling [addTool] with `Tool.ofFunction(function)`. */ fun addTool(function: FunctionTool) = apply { body.addTool(function) } + /** + * Adds a single [FunctionTool] where the JSON schema describing the function parameters is + * derived from the fields of a given class. Local validation of that JSON schema can be + * performed to check if the schema is likely to pass remote validation by the AI model. By + * default, local validation is enabled; disable it by setting [localValidation] to + * [JsonSchemaLocalValidation.NO]. + * + * @see addTool + */ + @JvmOverloads + fun addTool( + functionParametersType: Class<*>, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { + body.addTool(responseFunctionToolFromClass(functionParametersType, localValidation)) + } + /** Alias for calling [addTool] with `Tool.ofFileSearch(fileSearch)`. */ fun addTool(fileSearch: FileSearchTool) = apply { body.addTool(fileSearch) } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseFunctionToolCall.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseFunctionToolCall.kt index 223b1a64..a5da8484 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseFunctionToolCall.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseFunctionToolCall.kt @@ -12,6 +12,7 @@ import com.openai.core.JsonField import com.openai.core.JsonMissing import com.openai.core.JsonValue import com.openai.core.checkRequired +import com.openai.core.responseTypeFromJson import com.openai.errors.OpenAIInvalidDataException import java.util.Collections import java.util.Objects @@ -52,6 +53,21 @@ private constructor( */ fun arguments(): String = arguments.getRequired("arguments") + /** + * Parses the JSON string defining the arguments to pass to the function to an instance of a + * class. The class must be the same as the class that was used to define the function's + * parameters when the function was defined. + * + * @throws OpenAIInvalidDataException If the JSON data is missing, `null`, or cannot be parsed + * to an instance of the [functionParametersType] class. This might occur if the class is not + * the same as the class that was originally used to define the arguments, or if the data from + * the AI model is invalid or incomplete (e.g., truncated). + * @see ResponseCreateParams.Builder.addTool + * @see arguments + */ + fun arguments(functionParametersType: Class): T = + responseTypeFromJson(arguments(), functionParametersType) + /** * The unique ID of the function tool call generated by the model. * diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseInputItem.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseInputItem.kt index b97bcb46..8a5ed5c6 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseInputItem.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/ResponseInputItem.kt @@ -25,6 +25,7 @@ import com.openai.core.checkKnown import com.openai.core.checkRequired import com.openai.core.getOrThrow import com.openai.core.toImmutable +import com.openai.core.toJsonString import com.openai.errors.OpenAIInvalidDataException import java.util.Collections import java.util.Objects @@ -2574,6 +2575,14 @@ private constructor( /** A JSON string of the output of the function tool call. */ fun output(output: String) = output(JsonField.of(output)) + /** + * Sets the output to text representing the JSON serialized form of a given object. This + * is useful when passing data that is the result of a function call. + * + * @see output + */ + fun outputAsJson(functionResult: Any) = apply { output(toJsonString(functionResult)) } + /** * Sets [Builder.output] to an arbitrary JSON value. * diff --git a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt index aae191ac..e41ee85e 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/responses/StructuredResponseCreateParams.kt @@ -90,6 +90,22 @@ class StructuredResponseCreateParams( /** @see ResponseCreateParams.Builder.model */ fun model(only: ResponsesModel.ResponsesOnlyModel) = apply { paramsBuilder.model(only) } + /** @see ResponseCreateParams.Builder.background */ + fun background(background: Boolean?) = apply { paramsBuilder.background(background) } + + /** @see ResponseCreateParams.Builder.background */ + fun background(background: Boolean) = apply { paramsBuilder.background(background) } + + /** @see ResponseCreateParams.Builder.background */ + fun background(background: Optional) = apply { + paramsBuilder.background(background) + } + + /** @see ResponseCreateParams.Builder.background */ + fun background(background: JsonField) = apply { + paramsBuilder.background(background) + } + /** @see ResponseCreateParams.Builder.include */ fun include(include: List?) = apply { paramsBuilder.include(include) } @@ -286,6 +302,16 @@ class StructuredResponseCreateParams( /** @see ResponseCreateParams.Builder.addTool */ fun addTool(tool: Tool) = apply { paramsBuilder.addTool(tool) } + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(function: FunctionTool) = apply { paramsBuilder.addTool(function) } + + /** @see ResponseCreateParams.Builder.addTool */ + @JvmOverloads + fun addTool( + functionParametersType: Class<*>, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { paramsBuilder.addTool(functionParametersType, localValidation) } + /** @see ResponseCreateParams.Builder.addTool */ fun addTool(fileSearch: FileSearchTool) = apply { paramsBuilder.addTool(fileSearch) } @@ -294,9 +320,6 @@ class StructuredResponseCreateParams( paramsBuilder.addFileSearchTool(vectorStoreIds) } - /** @see ResponseCreateParams.Builder.addTool */ - fun addTool(function: FunctionTool) = apply { paramsBuilder.addTool(function) } - /** @see ResponseCreateParams.Builder.addTool */ fun addTool(webSearch: WebSearchTool) = apply { paramsBuilder.addTool(webSearch) } @@ -305,6 +328,37 @@ class StructuredResponseCreateParams( paramsBuilder.addTool(computerUsePreview) } + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(mcp: Tool.Mcp) = apply { paramsBuilder.addTool(mcp) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(codeInterpreter: Tool.CodeInterpreter) = apply { + paramsBuilder.addTool(codeInterpreter) + } + + /** @see ResponseCreateParams.Builder.addCodeInterpreterTool */ + fun addCodeInterpreterTool(container: Tool.CodeInterpreter.Container) = apply { + paramsBuilder.addCodeInterpreterTool(container) + } + + /** @see ResponseCreateParams.Builder.addCodeInterpreterTool */ + fun addCodeInterpreterTool(string: String) = apply { + paramsBuilder.addCodeInterpreterTool(string) + } + + /** @see ResponseCreateParams.Builder.addCodeInterpreterTool */ + fun addCodeInterpreterTool( + codeInterpreterToolAuto: Tool.CodeInterpreter.Container.CodeInterpreterToolAuto + ) = apply { paramsBuilder.addCodeInterpreterTool(codeInterpreterToolAuto) } + + /** @see ResponseCreateParams.Builder.addTool */ + fun addTool(imageGeneration: Tool.ImageGeneration) = apply { + paramsBuilder.addTool(imageGeneration) + } + + /** @see ResponseCreateParams.Builder.addToolLocalShell */ + fun addToolLocalShell() = apply { paramsBuilder.addToolLocalShell() } + /** @see ResponseCreateParams.Builder.topP */ fun topP(topP: Double?) = apply { paramsBuilder.topP(topP) } diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index dd5cdd57..970b3dbe 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonClassDescription import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.annotation.JsonProperty import com.fasterxml.jackson.annotation.JsonPropertyDescription +import com.fasterxml.jackson.annotation.JsonTypeName import com.fasterxml.jackson.databind.JsonNode import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.ObjectNode @@ -14,7 +15,6 @@ import org.assertj.core.api.Assertions.assertThatNoException import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.AfterTestExecutionCallback -import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.extension.RegisterExtension /** Tests for the `StructuredOutputs` functions and the [JsonSchemaValidator]. */ @@ -59,17 +59,14 @@ internal class StructuredOutputsTest { @Suppress("unused") @RegisterExtension val printValidationErrorsOnFailure: AfterTestExecutionCallback = - object : AfterTestExecutionCallback { - @Throws(Exception::class) - override fun afterTestExecution(context: ExtensionContext) { - if ( - context.displayName.startsWith("schemaTest_") && - (VERBOSE_MODE || context.executionException.isPresent) - ) { - // Test failed. - println("Schema: ${schema.toPrettyString()}\n") - println("$validator\n") - } + AfterTestExecutionCallback { context -> + if ( + context.displayName.startsWith("schemaTest_") && + (VERBOSE_MODE || context.executionException.isPresent) + ) { + // Test failed. + println("Schema: ${schema.toPrettyString()}\n") + println("$validator\n") } } @@ -83,7 +80,34 @@ internal class StructuredOutputsTest { schema = extractSchema(X::class.java) validator.validate(schema) - assertThat(validator.isValid()).isTrue + // Expect a failure. If a class has no properties, then the schema is meaningless to the + // AI. It can only reply with values to _named_ properties, so there must be at least one. + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is missing, empty or not an object.") + } + + @Test + fun schemaTest_mapHasNoNamedProperties() { + @Suppress("unused") class X(val m: Map) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // A map results in a schema that declares an "object" sub-schema, but that sub-schema has + // no named `"properties"` and no `"required"` array. Only the first problem is reported. + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/m: 'properties' field is missing, empty or not an object.") + + // Do this check of `toString()` once for a validation failure, but do not repeat it in + // other tests. + assertThat(validator.toString()) + .isEqualTo( + "JsonSchemaValidator{isValidationComplete=true, totalStringLength=1, " + + "totalObjectProperties=1, totalEnumValues=0, errors=[" + + "#/properties/m: 'properties' field is missing, empty or not an object.]}" + ) } @Test @@ -115,6 +139,11 @@ internal class StructuredOutputsTest { // The reason for the failure is that generic type information is erased for scopes like // local variables, but generic type information for fields is retained as part of the class // metadata. This is the expected behavior in Java, so this test expects an invalid schema. + // + // The `extractSchema` function could be defined to accept type parameters and these could + // be passed to the schema generator (which accepts them) and the above would work. However, + // there would be no simple way to deserialize the JSON response back to a parameterized + // type like `List` without again providing the type parameters. assertThat(validator.isValid()).isFalse assertThat(validator.errors()).hasSize(2) assertThat(validator.errors()[0]).isEqualTo("#/items: Schema or sub-schema is empty.") @@ -161,6 +190,14 @@ internal class StructuredOutputsTest { validator.validate(schema) assertThat(validator.isValid()).isTrue + + // Do this check of `toString()` once for a validation success, but do not repeat it in + // other tests. + assertThat(validator.toString()) + .isEqualTo( + "JsonSchemaValidator{isValidationComplete=true, totalStringLength=10, " + + "totalObjectProperties=0, totalEnumValues=2, errors=[]}" + ) } @Test @@ -517,6 +554,32 @@ internal class StructuredOutputsTest { @Test fun schemaTest_propertyNotMarkedRequired() { + // Use two properties, so the `"required"` array is not empty, but is still not listing + // _all_ of the properties. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { "type" : "string" }, + "address" : { "type" : "string" } + }, + "additionalProperties" : false, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'address' is not listed as 'required'.") + } + + @Test + fun schemaTest_requiredArrayEmpty() { schema = parseJson( """ @@ -533,7 +596,7 @@ internal class StructuredOutputsTest { assertThat(validator.errors()).hasSize(1) assertThat(validator.errors()[0]) - .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + .isEqualTo("#: 'required' field is missing, empty or not an array.") } @Test @@ -554,7 +617,7 @@ internal class StructuredOutputsTest { assertThat(validator.errors()).hasSize(1) assertThat(validator.errors()[0]) - .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + .isEqualTo("#: 'required' field is missing, empty or not an array.") } @Test @@ -574,7 +637,7 @@ internal class StructuredOutputsTest { assertThat(validator.errors()).hasSize(1) assertThat(validator.errors()[0]) - .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + .isEqualTo("#: 'required' field is missing, empty or not an array.") } @Test @@ -633,8 +696,11 @@ internal class StructuredOutputsTest { ) validator.validate(schema) - // For now, allow that an object may have no properties. Update this if that is revised. - assertThat(validator.isValid()).isTrue() + // An object must explicitly declare some properties, as no `"additionalProperties"` will + // be allowed and the AI model will have nothing it can populate. + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is missing, empty or not an object.") } @Test @@ -655,7 +721,7 @@ internal class StructuredOutputsTest { assertThat(validator.errors()).hasSize(1) assertThat(validator.errors()[0]) - .isEqualTo("#: 'properties' field is not a non-empty object.") + .isEqualTo("#: 'properties' field is missing, empty or not an object.") } @Test @@ -676,7 +742,7 @@ internal class StructuredOutputsTest { assertThat(validator.errors()).hasSize(1) assertThat(validator.errors()[0]) - .isEqualTo("#: 'properties' field is not a non-empty object.") + .isEqualTo("#: 'properties' field is missing, empty or not an object.") } @Test @@ -1200,7 +1266,7 @@ internal class StructuredOutputsTest { @Test fun schemaTest_annotatedWithJsonClassDescription() { // Add a "description" to the root schema using an annotation. - @JsonClassDescription("A simple schema.") class X() + @Suppress("unused") @JsonClassDescription("A simple schema.") class X(val s: String) schema = extractSchema(X::class.java) validator.validate(schema) @@ -1352,6 +1418,14 @@ internal class StructuredOutputsTest { fun validatorBeforeValidation() { assertThat(validator.errors()).isEmpty() assertThat(validator.isValid()).isFalse + + // Do this check of `toString()` once for an unused validator, but do not repeat it in other + // tests. + assertThat(validator.toString()) + .isEqualTo( + "JsonSchemaValidator{isValidationComplete=false, totalStringLength=0, " + + "totalObjectProperties=0, totalEnumValues=0, errors=[]}" + ) } @Test @@ -1440,7 +1514,7 @@ internal class StructuredOutputsTest { } @Test - fun fromClassEnablesStrictAdherenceToSchema() { + fun responseFormatFromClassEnablesStrictAdherenceToSchema() { @Suppress("unused") class X(val s: String) val jsonSchema = responseFormatFromClass(X::class.java) @@ -1449,11 +1523,21 @@ internal class StructuredOutputsTest { // to the JSON schema. assertThat(jsonSchema.jsonSchema().strict()).isPresent assertThat(jsonSchema.jsonSchema().strict().get()).isTrue + + // The `schema()` accessor cannot be called successfully because of the way the field was + // set to a schema. This is OK, as the serialization will still work. Just confirm the + // expected failure, so if the conditions change, they will be noticed. + assertThatThrownBy { jsonSchema.jsonSchema().schema() } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + + // Use the `_schema()` accessor instead and check that the value is not null or missing. + assertThat(jsonSchema.jsonSchema()._schema()) + .isNotInstanceOfAny(JsonMissing::class.java, JsonNull::class.java) } @Test @Suppress("unused") - fun fromClassSuccessWithoutValidation() { + fun responseFormatFromClassSuccessWithoutValidation() { // Exceed the maximum nesting depth, but do not enable validation. class U(val s: String) class V(val u: U) @@ -1468,7 +1552,7 @@ internal class StructuredOutputsTest { } @Test - fun fromClassSuccessWithValidation() { + fun responseFormatFromClassSuccessWithValidation() { @Suppress("unused") class X(val s: String) assertThatNoException().isThrownBy { @@ -1478,7 +1562,7 @@ internal class StructuredOutputsTest { @Test @Suppress("unused") - fun fromClassFailureWithValidation() { + fun responseFormatFromClassFailureWithValidation() { // Exceed the maximum nesting depth and enable validation. class U(val s: String) class V(val u: U) @@ -1498,8 +1582,8 @@ internal class StructuredOutputsTest { @Test @Suppress("unused") - fun fromClassFailureWithValidationDefault() { - // Confirm that the default value of the `localValidation` argument is `true` by expecting + fun responseFormatFromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `YES` by expecting // a validation error when that argument is not given an explicit value. class U(val s: String) class V(val u: U) @@ -1517,4 +1601,338 @@ internal class StructuredOutputsTest { "/properties/s: Current nesting depth is 6, but maximum is 5." ) } + + @Test + fun textConfigFromClassEnablesStrictAdherenceToSchema() { + @Suppress("unused") class X(val s: String) + + val textConfig = textConfigFromClass(X::class.java) + val jsonSchema = textConfig.format().get().jsonSchema().get() + + // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform + // to the JSON schema. + assertThat(jsonSchema.strict()).isPresent + assertThat(jsonSchema.strict().get()).isTrue + + // The `schema()` accessor cannot be called successfully because of the way the field was + // set to a schema. This is OK, as the serialization will still work. Just confirm the + // expected failure, so if the conditions change, they will be noticed. + assertThatThrownBy { jsonSchema.schema() } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + + // Use the `_schema()` accessor instead and check that the value is not null or missing. + assertThat(jsonSchema._schema()) + .isNotInstanceOfAny(JsonMissing::class.java, JsonNull::class.java) + } + + @Test + @Suppress("unused") + fun textConfigFromClassSuccessWithoutValidation() { + // Exceed the maximum nesting depth, but do not enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatNoException().isThrownBy { + textConfigFromClass(Z::class.java, JsonSchemaLocalValidation.NO) + } + } + + @Test + fun textConfigFromClassSuccessWithValidation() { + @Suppress("unused") class X(val s: String) + + assertThatNoException().isThrownBy { + textConfigFromClass(X::class.java, JsonSchemaLocalValidation.YES) + } + } + + @Test + @Suppress("unused") + fun textConfigFromClassFailureWithValidation() { + // Exceed the maximum nesting depth and enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { textConfigFromClass(Z::class.java, JsonSchemaLocalValidation.YES) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + @Suppress("unused") + fun textConfigFromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `YES` by expecting + // a validation error when that argument is not given an explicit value. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + // Use default for `localValidation` flag. + assertThatThrownBy { textConfigFromClass(Z::class.java) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + fun extractFunctionInfoUsingClassNameAndNoDescription() { + @Suppress("unused") class X(val s: String) + + val info = extractFunctionInfo(X::class.java, JsonSchemaLocalValidation.NO) + + assertThat(info.name).isEqualTo("X") + assertThat(info.description).isNull() + } + + @Test + fun extractFunctionInfoUsingAnnotationNameAndNoDescription() { + @Suppress("unused") @JsonTypeName("fnX") class X(val s: String) + + val info = extractFunctionInfo(X::class.java, JsonSchemaLocalValidation.NO) + + assertThat(info.name).isEqualTo("fnX") + assertThat(info.description).isNull() + } + + @Test + fun extractFunctionInfoUsingClassNameAndAnnotationDescription() { + @Suppress("unused") @JsonClassDescription("Something about X") class X(val s: String) + + val info = extractFunctionInfo(X::class.java, JsonSchemaLocalValidation.NO) + + assertThat(info.name).isEqualTo("X") + assertThat(info.description).isEqualTo("Something about X") + // If the description annotation is set, it will be added to the schema by the generator, + // but should them be moved out by `extractFunctionInfo` into the function info. + assertThat(info.schema.get("description")).isNull() + } + + @Test + fun functionToolFromClassEnablesStrictAdherenceToSchema() { + @Suppress("unused") @JsonClassDescription("Something about X") class X(val s: String) + + val functionTool = functionToolFromClass(X::class.java) + val fnDef = functionTool.function() + + // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform + // to the JSON schema. + assertThat(fnDef.strict()).isPresent + assertThat(fnDef.strict().get()).isTrue + // Test here that the name, description and parameters (schema) are applied. There is no + // need to test these again for the other cases. + assertThat(fnDef.name()).isEqualTo("X") + assertThat(fnDef.description()).isPresent + assertThat(fnDef.description().get()).isEqualTo("Something about X") + + // The `parameters()` accessor cannot be called successfully because of the way the field + // was set to a schema. This is OK, as the serialization will still work. Just confirm the + // expected failure, so if the conditions change, they will be noticed. + assertThatThrownBy { fnDef.parameters() } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + + // Use the `_parameters()` accessor instead and check that the value is not null or missing. + assertThat(fnDef._parameters()) + .isNotInstanceOfAny(JsonMissing::class.java, JsonNull::class.java) + } + + @Test + @Suppress("unused") + fun functionToolFromClassSuccessWithoutValidation() { + // Exceed the maximum nesting depth, but do not enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatNoException().isThrownBy { + functionToolFromClass(Z::class.java, JsonSchemaLocalValidation.NO) + } + } + + @Test + fun functionToolFromClassSuccessWithValidation() { + @Suppress("unused") class X(val s: String) + + assertThatNoException().isThrownBy { + functionToolFromClass(X::class.java, JsonSchemaLocalValidation.YES) + } + } + + @Test + @Suppress("unused") + fun functionToolFromClassFailureWithValidation() { + // Exceed the maximum nesting depth and enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { functionToolFromClass(Z::class.java, JsonSchemaLocalValidation.YES) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + @Suppress("unused") + fun functionToolFromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `YES` by expecting a + // validation error when that argument is not given an explicit value. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + // Use default for `localValidation` flag. + assertThatThrownBy { functionToolFromClass(Z::class.java) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + fun responseFunctionToolFromClassEnablesStrictAdherenceToSchema() { + @Suppress("unused") @JsonClassDescription("Something about X") class X(val s: String) + + val fnTool = responseFunctionToolFromClass(X::class.java) + + // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform + // to the JSON schema. + assertThat(fnTool.strict()).isPresent + assertThat(fnTool.strict().get()).isTrue + // Test here that the name, description and parameters (schema) are applied. There is no + // need to test these again for the other cases. + assertThat(fnTool.name()).isEqualTo("X") + assertThat(fnTool.description()).isPresent + assertThat(fnTool.description().get()).isEqualTo("Something about X") + + // The `parameters()` accessor cannot be called successfully because of the way the field + // was set to a schema. This is OK, as the serialization will still work. Just confirm the + // expected failure, so if the conditions change, they will be noticed. + assertThatThrownBy { fnTool.parameters() } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + + // Use the `_parameters()` accessor instead and check that the value is not null or missing. + assertThat(fnTool._parameters()) + .isNotInstanceOfAny(JsonMissing::class.java, JsonNull::class.java) + } + + @Test + @Suppress("unused") + fun responseFunctionToolFromClassSuccessWithoutValidation() { + // Exceed the maximum nesting depth, but do not enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatNoException().isThrownBy { + responseFunctionToolFromClass(Z::class.java, JsonSchemaLocalValidation.NO) + } + } + + @Test + fun responseFunctionToolFromClassSuccessWithValidation() { + @Suppress("unused") class X(val s: String) + + assertThatNoException().isThrownBy { + responseFunctionToolFromClass(X::class.java, JsonSchemaLocalValidation.YES) + } + } + + @Test + @Suppress("unused") + fun responseFunctionToolFromClassFailureWithValidation() { + // Exceed the maximum nesting depth and enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { + responseFunctionToolFromClass(Z::class.java, JsonSchemaLocalValidation.YES) + } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + @Suppress("unused") + fun responseFunctionToolFromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `YES` by expecting + // a validation error when that argument is not given an explicit value. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + // Use default for `localValidation` flag. + assertThatThrownBy { responseFunctionToolFromClass(Z::class.java) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + fun toJsonString() { + val boolPrimitive = toJsonString(true) + val boolObject = toJsonString(java.lang.Boolean.TRUE) + val numberPrimitive = toJsonString(42) + val numberObject = toJsonString(Integer.valueOf(42)) + val stringObject = toJsonString("Hello, World!") + val optional = toJsonString(Optional.of("optional")) + val optionalNullable = toJsonString(Optional.ofNullable(null)) + + assertThat(boolPrimitive).isEqualTo("true") + assertThat(boolObject).isEqualTo("true") + assertThat(numberPrimitive).isEqualTo("42") + assertThat(numberObject).isEqualTo("42") + // String values should be in quotes. + assertThat(stringObject).isEqualTo("\"Hello, World!\"") + assertThat(optional).isEqualTo("\"optional\"") + assertThat(optionalNullable).isEqualTo("null") + } } diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt index c62eb8ed..267e17d4 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTestUtils.kt @@ -33,6 +33,7 @@ internal val NULLABLE_DOUBLE: Double? = null internal val LIST = listOf(STRING) internal val SET = setOf(STRING) internal val MAP = mapOf(STRING to STRING) +internal val CLASS = X::class.java /** * Defines a test case where a function in a delegator returns a value from a corresponding function @@ -104,12 +105,15 @@ internal fun checkAllDelegation( } // Drop the first parameter from each function, as it is the implicit "this" object and has - // the type of the class declaring the function, which will never match. + // the type of the class declaring the function, which will never match. Compare only the + // "classifiers" of the types, so that generic type parameters are ignored. For example, + // one `java.lang.Class` is then considered equal to another `java.lang.Class`. For + // the data set being processed, this simplification does not cause any problems. val supersetFunction = supersetClass.declaredFunctions.find { it.name == subsetFunction.name && - it.parameters.drop(1).map { it.type } == - subsetFunction.parameters.drop(1).map { it.type } + it.parameters.drop(1).map { it.type.classifier } == + subsetFunction.parameters.drop(1).map { it.type.classifier } } if (supersetFunction == null) { @@ -171,12 +175,11 @@ internal fun checkOneDelegationWrite( private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteTestCase) { val numParams = testCase.inputValues.size - val inputValue1 = testCase.inputValues[0] - val inputValue2 = testCase.inputValues.getOrNull(1) when (numParams) { - 1 -> method.invoke(target, inputValue1) - 2 -> method.invoke(target, inputValue1, inputValue2) + 0 -> method.invoke(target) + 1 -> method.invoke(target, testCase.inputValues[0]) + 2 -> method.invoke(target, testCase.inputValues[0], testCase.inputValues.getOrNull(1)) else -> fail { "Unexpected number of function parameters ($numParams)." } } } @@ -187,11 +190,12 @@ private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteT */ internal fun findDelegationMethod(target: Any, testCase: DelegationWriteTestCase): Method { val numParams = testCase.inputValues.size - val inputValue1: Any? = testCase.inputValues[0] + val inputValue1: Any? = if (numParams > 0) testCase.inputValues[0] else null val inputValue2 = if (numParams > 1) testCase.inputValues[1] else null val method = when (numParams) { + 0 -> findJavaMethod(target.javaClass, testCase.functionName) 1 -> if (inputValue1 != null) { findJavaMethod( diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt index 4b289198..287f9c00 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt @@ -1,10 +1,12 @@ package com.openai.models.chat.completions import com.openai.core.BOOLEAN +import com.openai.core.CLASS import com.openai.core.DOUBLE import com.openai.core.DelegationWriteTestCase import com.openai.core.JSON_FIELD import com.openai.core.JSON_VALUE +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.LIST import com.openai.core.LONG import com.openai.core.MAP @@ -105,6 +107,8 @@ internal class StructuredChatCompletionCreateParamsTest { private val HEADERS = Headers.builder().build() private val QUERY_PARAMS = QueryParams.builder().build() + private val VALIDATION = JsonSchemaLocalValidation.NO + // The list order follows the declaration order in `ChatCompletionCreateParams.Builder` for // easier maintenance. @JvmStatic @@ -216,6 +220,7 @@ internal class StructuredChatCompletionCreateParamsTest { DelegationWriteTestCase("tools", LIST), DelegationWriteTestCase("tools", JSON_FIELD), DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("addTool", CLASS, VALIDATION), DelegationWriteTestCase("topLogprobs", NULLABLE_LONG), DelegationWriteTestCase("topLogprobs", LONG), DelegationWriteTestCase("topLogprobs", OPTIONAL), diff --git a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt index cf7349dc..9d5136d4 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/responses/StructuredResponseCreateParamsTest.kt @@ -1,10 +1,12 @@ package com.openai.models.responses import com.openai.core.BOOLEAN +import com.openai.core.CLASS import com.openai.core.DOUBLE import com.openai.core.DelegationWriteTestCase import com.openai.core.JSON_FIELD import com.openai.core.JSON_VALUE +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.LIST import com.openai.core.LONG import com.openai.core.MAP @@ -76,10 +78,20 @@ internal class StructuredResponseCreateParamsTest { .environment(ComputerTool.Environment.LINUX) .build() private val TOOL = Tool.ofFunction(FUNCTION_TOOL) + private val MCP_TOOL = Tool.Mcp.builder().serverLabel(STRING).serverUrl(STRING).build() + private val CODE_INTERPRETER_TOOL_AUTO = + Tool.CodeInterpreter.Container.CodeInterpreterToolAuto.builder().build() + private val CODE_INTERPRETER_CONTAINER = + Tool.CodeInterpreter.Container.ofCodeInterpreterToolAuto(CODE_INTERPRETER_TOOL_AUTO) + private val CODE_INTERPRETER_TOOL = + Tool.CodeInterpreter.builder().container(CODE_INTERPRETER_CONTAINER).build() + private val IMAGE_GENERATION_TOOL = Tool.ImageGeneration.builder().build() private val HEADERS = Headers.builder().build() private val QUERY_PARAMS = QueryParams.builder().build() + private val VALIDATION = JsonSchemaLocalValidation.NO + // The list order follows the declaration order in `ResponseCreateParams.Builder` for // easier maintenance. @JvmStatic @@ -95,6 +107,10 @@ internal class StructuredResponseCreateParamsTest { DelegationWriteTestCase("model", STRING), DelegationWriteTestCase("model", CHAT_MODEL), DelegationWriteTestCase("model", RESPONSES_ONLY_MODEL), + DelegationWriteTestCase("background", NULLABLE_BOOLEAN), + DelegationWriteTestCase("background", BOOLEAN), + DelegationWriteTestCase("background", OPTIONAL), + DelegationWriteTestCase("background", JSON_FIELD), DelegationWriteTestCase("include", LIST), DelegationWriteTestCase("include", OPTIONAL), DelegationWriteTestCase("include", JSON_FIELD), @@ -139,11 +155,19 @@ internal class StructuredResponseCreateParamsTest { DelegationWriteTestCase("tools", LIST), DelegationWriteTestCase("tools", JSON_FIELD), DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("addTool", FUNCTION_TOOL), + DelegationWriteTestCase("addTool", CLASS, VALIDATION), DelegationWriteTestCase("addTool", FILE_SEARCH_TOOL), DelegationWriteTestCase("addFileSearchTool", LIST), - DelegationWriteTestCase("addTool", FUNCTION_TOOL), DelegationWriteTestCase("addTool", WEB_SEARCH_TOOL), DelegationWriteTestCase("addTool", COMPUTER_TOOL), + DelegationWriteTestCase("addTool", MCP_TOOL), + DelegationWriteTestCase("addTool", CODE_INTERPRETER_TOOL), + DelegationWriteTestCase("addCodeInterpreterTool", CODE_INTERPRETER_CONTAINER), + DelegationWriteTestCase("addCodeInterpreterTool", STRING), + DelegationWriteTestCase("addCodeInterpreterTool", CODE_INTERPRETER_TOOL_AUTO), + DelegationWriteTestCase("addTool", IMAGE_GENERATION_TOOL), + DelegationWriteTestCase("addToolLocalShell"), DelegationWriteTestCase("topP", NULLABLE_DOUBLE), DelegationWriteTestCase("topP", DOUBLE), DelegationWriteTestCase("topP", OPTIONAL), @@ -194,20 +218,30 @@ internal class StructuredResponseCreateParamsTest { @Test fun allBuilderDelegateFunctionsExistInDelegator() { - // The delegator class does not implement various functions from the delegate class: - // - text functions and body function - // - addCodeInterpreterTool methods - // - various tool-related methods (addTool variations, addToolLocalShell) - // - background-related methods checkAllDelegation( mockBuilderDelegate::class, builderDelegator::class, + // ************************************************************************************ + // NOTE: THIS TEST EXISTS TO ENSURE THAT WHEN NEW FUNCTIONS ARE ADDED MANUALLY OR VIA + // CODE GEN TO `ResponseCreateParams.Builder`, THAT THOSE FUNCTIONS ARE _ALSO_ ADDED + // _MANUALLY_ TO `StructuredResponseCreateParams.Builder`. FAILURE TO ADD THOSE + // FUNCTIONS RESULTS IN _MISSING_ FUNCTIONALITY WHEN USING STRUCTURED OUTPUTS. + // EXCEPTIONS ADDED TO THIS LIST ARE PRESENT BY DESIGN, NOT BECAUSE THE FUNCTIONS ARE + // SIMPLY NOT YET IMPLEMENTED IN THE DELEGATOR CLASS. + // + // DO NOT ADD EXCEPTIONS TO THIS LIST SIMPLY BECAUSE TESTS ARE FAILING. THE TESTS ARE + // SUPPOSED TO FAIL. ADD THE NEW FUNCTIONS TO `StructuredResponseCreateParams.Builder` + // AND ADD A PARAMETERIZED TEST TO `builderDelegationTestCases` (above) TO ENSURE + // CORRECT DELEGATION BEHAVIOR. + // ************************************************************************************ + + // For Structured Outputs, setting `body` would overwrite the previously set `text` + // property, which would break the Structured Outputs behavior. "body", + // For Structured Outputs, a new type-safe generic`text` function replaces all existing + // text functions, as they are mutually incompatible. This function has its own + // dedicated unit tests. "text", - "addCodeInterpreterTool", - "addTool", - "addToolLocalShell", - "background", ) } diff --git a/openai-java-example/src/main/java/com/openai/example/FunctionCallingExample.java b/openai-java-example/src/main/java/com/openai/example/FunctionCallingExample.java index 1557fcfa..7b8cff38 100644 --- a/openai-java-example/src/main/java/com/openai/example/FunctionCallingExample.java +++ b/openai-java-example/src/main/java/com/openai/example/FunctionCallingExample.java @@ -1,68 +1,76 @@ package com.openai.example; -import static com.openai.core.ObjectMappers.jsonMapper; - -import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; -import com.openai.core.JsonObject; -import com.openai.core.JsonValue; import com.openai.models.ChatModel; -import com.openai.models.FunctionDefinition; -import com.openai.models.FunctionParameters; import com.openai.models.chat.completions.*; import java.util.Collection; -import java.util.List; -import java.util.Map; public final class FunctionCallingExample { private FunctionCallingExample() {} + @JsonClassDescription("Gets the quality of the given SDK.") + static class GetSdkQuality { + @JsonPropertyDescription("The name of the SDK.") + public String name; + + public SdkQuality execute() { + return new SdkQuality(name, name.contains("OpenAI") ? "It's robust and polished!" : "*shrug*"); + } + } + + static class SdkQuality { + public String quality; + + public SdkQuality(String name, String evaluation) { + quality = name + ": " + evaluation; + } + } + + @JsonClassDescription("Gets the review score (out of 10) for the named SDK.") + static class GetSdkScore { + public String name; + + public int execute() { + return name.contains("OpenAI") ? 10 : 3; + } + } + public static void main(String[] args) { // Configures using one of: // - The `OPENAI_API_KEY` environment variable // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClient client = OpenAIOkHttpClient.fromEnv(); - // Use a builder so that we can append more messages to it below. - // Each time we call .build()` we get an immutable object that's unaffected by future mutations of the builder. + // Use a `Builder` so that more messages can be appended below. When `build()` is called, it + // creates an immutable object that is unaffected by future mutations of the builder. ChatCompletionCreateParams.Builder createParamsBuilder = ChatCompletionCreateParams.builder() .model(ChatModel.GPT_3_5_TURBO) .maxCompletionTokens(2048) - .addTool(ChatCompletionTool.builder() - .function(FunctionDefinition.builder() - .name("get-sdk-quality") - .description("Gets the quality of the given SDK.") - .parameters(FunctionParameters.builder() - .putAdditionalProperty("type", JsonValue.from("object")) - .putAdditionalProperty( - "properties", JsonValue.from(Map.of("name", Map.of("type", "string")))) - .putAdditionalProperty("required", JsonValue.from(List.of("name"))) - .putAdditionalProperty("additionalProperties", JsonValue.from(false)) - .build()) - .build()) - .build()) - .addUserMessage("How good are the following SDKs: OpenAI Java SDK, Unknown Company SDK"); + .addTool(GetSdkQuality.class) + .addTool(GetSdkScore.class) + .addUserMessage("How good are the following SDKs and what do reviewers say: " + + "OpenAI Java SDK, Unknown Company SDK."); client.chat().completions().create(createParamsBuilder.build()).choices().stream() .map(ChatCompletion.Choice::message) - // Add each assistant message onto the builder so that we keep track of the conversation for asking a - // follow-up question later. + // Add each assistant message onto the builder so that we keep track of the + // conversation for asking a follow-up question later. .peek(createParamsBuilder::addMessage) .flatMap(message -> { message.content().ifPresent(System.out::println); return message.toolCalls().stream().flatMap(Collection::stream); }) .forEach(toolCall -> { - String content = callFunction(toolCall.function()); + Object result = callFunction(toolCall.function()); // Add the tool call result to the conversation. createParamsBuilder.addMessage(ChatCompletionToolMessageParam.builder() .toolCallId(toolCall.id()) - .content(content) + .contentAsJson(result) .build()); - System.out.println(content); }); - System.out.println(); // Ask a follow-up question about the function call result. createParamsBuilder.addUserMessage("Why do you say that?"); @@ -71,23 +79,14 @@ public static void main(String[] args) { .forEach(System.out::println); } - private static String callFunction(ChatCompletionMessageToolCall.Function function) { - if (!function.name().equals("get-sdk-quality")) { - throw new IllegalArgumentException("Unknown function: " + function.name()); + private static Object callFunction(ChatCompletionMessageToolCall.Function function) { + switch (function.name()) { + case "GetSdkQuality": + return function.arguments(GetSdkQuality.class).execute(); + case "GetSdkScore": + return function.arguments(GetSdkScore.class).execute(); + default: + throw new IllegalArgumentException("Unknown function: " + function.name()); } - - JsonValue arguments; - try { - arguments = JsonValue.from(jsonMapper().readTree(function.arguments())); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException("Bad function arguments", e); - } - - String sdkName = ((JsonObject) arguments).values().get("name").asStringOrThrow(); - if (sdkName.contains("OpenAI")) { - return sdkName + ": It's robust and polished!"; - } - - return sdkName + ": *shrug*"; } } diff --git a/openai-java-example/src/main/java/com/openai/example/FunctionCallingRawExample.java b/openai-java-example/src/main/java/com/openai/example/FunctionCallingRawExample.java new file mode 100644 index 00000000..3ae451c6 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/FunctionCallingRawExample.java @@ -0,0 +1,93 @@ +package com.openai.example; + +import static com.openai.core.ObjectMappers.jsonMapper; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonObject; +import com.openai.core.JsonValue; +import com.openai.models.ChatModel; +import com.openai.models.FunctionDefinition; +import com.openai.models.FunctionParameters; +import com.openai.models.chat.completions.*; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +public final class FunctionCallingRawExample { + private FunctionCallingRawExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + // Use a builder so that we can append more messages to it below. + // Each time we call .build()` we get an immutable object that's unaffected by future mutations of the builder. + ChatCompletionCreateParams.Builder createParamsBuilder = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_3_5_TURBO) + .maxCompletionTokens(2048) + .addTool(ChatCompletionTool.builder() + .function(FunctionDefinition.builder() + .name("get-sdk-quality") + .description("Gets the quality of the given SDK.") + .parameters(FunctionParameters.builder() + .putAdditionalProperty("type", JsonValue.from("object")) + .putAdditionalProperty( + "properties", JsonValue.from(Map.of("name", Map.of("type", "string")))) + .putAdditionalProperty("required", JsonValue.from(List.of("name"))) + .putAdditionalProperty("additionalProperties", JsonValue.from(false)) + .build()) + .build()) + .build()) + .addUserMessage("How good are the following SDKs: OpenAI Java SDK, Unknown Company SDK"); + + client.chat().completions().create(createParamsBuilder.build()).choices().stream() + .map(ChatCompletion.Choice::message) + // Add each assistant message onto the builder so that we keep track of the conversation for asking a + // follow-up question later. + .peek(createParamsBuilder::addMessage) + .flatMap(message -> { + message.content().ifPresent(System.out::println); + return message.toolCalls().stream().flatMap(Collection::stream); + }) + .forEach(toolCall -> { + String result = callFunction(toolCall.function()); + // Add the tool call result to the conversation. + createParamsBuilder.addMessage(ChatCompletionToolMessageParam.builder() + .toolCallId(toolCall.id()) + .content(result) + .build()); + System.out.println(result); + }); + System.out.println(); + + // Ask a follow-up question about the function call result. + createParamsBuilder.addUserMessage("Why do you say that?"); + client.chat().completions().create(createParamsBuilder.build()).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .forEach(System.out::println); + } + + private static String callFunction(ChatCompletionMessageToolCall.Function function) { + if (!function.name().equals("get-sdk-quality")) { + throw new IllegalArgumentException("Unknown function: " + function.name()); + } + + JsonValue arguments; + try { + arguments = JsonValue.from(jsonMapper().readTree(function.arguments())); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Bad function arguments", e); + } + + String sdkName = ((JsonObject) arguments).values().get("name").asStringOrThrow(); + if (sdkName.contains("OpenAI")) { + return sdkName + ": It's robust and polished!"; + } + + return sdkName + ": *shrug*"; + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingExample.java b/openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingExample.java new file mode 100644 index 00000000..2e52e417 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingExample.java @@ -0,0 +1,98 @@ +package com.openai.example; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.ResponseFunctionToolCall; +import com.openai.models.responses.ResponseInputItem; +import java.util.ArrayList; +import java.util.List; + +public final class ResponsesFunctionCallingExample { + private ResponsesFunctionCallingExample() {} + + @JsonClassDescription("Gets the quality of the given SDK.") + static class GetSdkQuality { + @JsonPropertyDescription("The name of the SDK.") + public String name; + + public SdkQuality execute() { + return new SdkQuality(name, name.contains("OpenAI") ? "It's robust and polished!" : "*shrug*"); + } + } + + static class SdkQuality { + public String quality; + + public SdkQuality(String name, String evaluation) { + quality = name + ": " + evaluation; + } + } + + @JsonClassDescription("Gets the review score (out of 10) for the given SDK.") + static class GetSdkScore { + public String name; + + public int execute() { + return name.contains("OpenAI") ? 10 : 3; + } + } + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + List inputs = new ArrayList<>(); + + inputs.add(ResponseInputItem.ofMessage(ResponseInputItem.Message.builder() + .addInputTextContent("What is the quality of the following SDKs and what do reviewers say: " + + "OpenAI Java SDK, Unknown Company SDK.") + .role(ResponseInputItem.Message.Role.USER) + .build())); + + // Use a `Builder` so that more messages can be appended below. When `build()` is called, it + // creates an immutable object that is unaffected by future mutations of the builder. + ResponseCreateParams.Builder createParamsBuilder = ResponseCreateParams.builder() + .model(ChatModel.GPT_3_5_TURBO) + .addTool(GetSdkQuality.class) + .addTool(GetSdkScore.class) + .maxOutputTokens(2048) + .input(ResponseCreateParams.Input.ofResponse(inputs)); + + client.responses().create(createParamsBuilder.build()).output().forEach(item -> { + if (item.isFunctionCall()) { + ResponseFunctionToolCall functionCall = item.asFunctionCall(); + + inputs.add(ResponseInputItem.ofFunctionCall(functionCall)); + inputs.add(ResponseInputItem.ofFunctionCallOutput(ResponseInputItem.FunctionCallOutput.builder() + .callId(functionCall.callId()) + .outputAsJson(callFunction(functionCall)) + .build())); + } + }); + + // Pass the function call results back to the model to complete the process. + createParamsBuilder.input(ResponseCreateParams.Input.ofResponse(inputs)); + + client.responses().create(createParamsBuilder.build()).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .forEach(outputText -> System.out.println(outputText.text())); + } + + private static Object callFunction(ResponseFunctionToolCall function) { + switch (function.name()) { + case "GetSdkQuality": + return function.arguments(GetSdkQuality.class).execute(); + case "GetSdkScore": + return function.arguments(GetSdkScore.class).execute(); + default: + throw new IllegalArgumentException("Unknown function: " + function.name()); + } + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingRawExample.java b/openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingRawExample.java new file mode 100644 index 00000000..57a712a3 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/ResponsesFunctionCallingRawExample.java @@ -0,0 +1,130 @@ +package com.openai.example; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonObject; +import com.openai.core.JsonValue; +import com.openai.models.ChatModel; +import com.openai.models.responses.FunctionTool; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.ResponseFunctionToolCall; +import com.openai.models.responses.ResponseInputItem; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public final class ResponsesFunctionCallingRawExample { + private ResponsesFunctionCallingRawExample() {} + + static class SdkQuality { + public String quality; + + public SdkQuality(String name, String evaluation) { + quality = name + ": " + evaluation; + } + } + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + List inputs = new ArrayList<>(); + + inputs.add(ResponseInputItem.ofMessage(ResponseInputItem.Message.builder() + .addInputTextContent("What is the quality of the following SDKs and what do reviewers say: " + + "OpenAI Java SDK, Unknown Company SDK.") + .role(ResponseInputItem.Message.Role.USER) + .build())); + + // Use a `Builder` so that more messages can be appended below. When `build()` is called, it + // creates an immutable object that is unaffected by future mutations of the builder. + ResponseCreateParams.Builder createParamsBuilder = ResponseCreateParams.builder() + .model(ChatModel.GPT_3_5_TURBO) + .addTool(FunctionTool.builder() + .name("get-sdk-quality") + .description("Gets the quality of the given SDK.") + .parameters(FunctionTool.Parameters.builder() + .putAdditionalProperty("type", JsonValue.from("object")) + .putAdditionalProperty( + "properties", + JsonValue.from(Map.of( + "name", + Map.of("type", "string", "description", "The name of the SDK.")))) + .putAdditionalProperty("required", JsonValue.from(List.of("name"))) + .putAdditionalProperty("additionalProperties", JsonValue.from(false)) + .build()) + .strict(true) + .build()) + .addTool(FunctionTool.builder() + .name("get-sdk-score") + .description("Gets the review score (out of 10) for the given SDK.") + .parameters(FunctionTool.Parameters.builder() + .putAdditionalProperty("type", JsonValue.from("object")) + .putAdditionalProperty( + "properties", JsonValue.from(Map.of("name", Map.of("type", "string")))) + .putAdditionalProperty("required", JsonValue.from(List.of("name"))) + .putAdditionalProperty("additionalProperties", JsonValue.from(false)) + .build()) + .strict(true) + .build()) + .maxOutputTokens(2048) + .input(ResponseCreateParams.Input.ofResponse(inputs)); + + client.responses().create(createParamsBuilder.build()).output().forEach(item -> { + if (item.isFunctionCall()) { + ResponseFunctionToolCall functionCall = item.asFunctionCall(); + + inputs.add(ResponseInputItem.ofFunctionCall(functionCall)); + inputs.add(ResponseInputItem.ofFunctionCallOutput(ResponseInputItem.FunctionCallOutput.builder() + .callId(functionCall.callId()) + .output(callFunction(functionCall)) + .build())); + } + }); + + // Pass the function call results back to the model to complete the process. + createParamsBuilder.input(ResponseCreateParams.Input.ofResponse(inputs)); + + client.responses().create(createParamsBuilder.build()).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .forEach(outputText -> System.out.println(outputText.text())); + } + + private static String callFunction(ResponseFunctionToolCall function) { + ObjectMapper mapper = new ObjectMapper(); + JsonValue arguments; + + try { + arguments = JsonValue.from(mapper.readTree(function.arguments())); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Bad function arguments", e); + } + + String sdkName = ((JsonObject) arguments).values().get("name").asStringOrThrow(); + Object result; + + switch (function.name()) { + case "get-sdk-quality": + result = new SdkQuality(sdkName, sdkName.contains("OpenAI") ? "It's robust and polished!" : "*shrug*"); + break; + + case "get-sdk-score": + result = sdkName.contains("OpenAI") ? 10 : 3; + break; + + default: + throw new IllegalArgumentException("Unknown function: " + function.name()); + } + + try { + return mapper.writeValueAsString(result); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("Bad function result", e); + } + } +}