Skip to content

Commit b86aecd

Browse files
committed
Merge PR TheoKanning#424 [Chat completion API] Support tools and tool_choice by @Tudor44
1 parent be26563 commit b86aecd

File tree

8 files changed

+169
-7
lines changed

8 files changed

+169
-7
lines changed

api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,17 @@ public static ChatCompletionRequestFunctionCall of(String name) {
118118
}
119119

120120
}
121+
122+
/**
123+
* A list of tools the model may call. Currently, only functions are supported as a tool.
124+
*/
125+
List<ChatTool> tools;
126+
127+
/**
128+
* Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function.
129+
*/
130+
@JsonProperty("tool_choice")
131+
String toolChoice;
132+
133+
121134
}

api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import com.fasterxml.jackson.annotation.JsonProperty;
55
import lombok.*;
66

7+
import java.util.List;
8+
79
/**
810
* <p>Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.</p>
911
* <p>Typically, a conversation is formatted with a system message first, followed by alternating user and assistant messages.</p>
@@ -30,6 +32,10 @@ public class ChatMessage {
3032
String content;
3133
//name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
3234
String name;
35+
36+
@JsonProperty("tool_calls")
37+
List<ChatToolCalls> toolCalls;
38+
3339
@JsonProperty("function_call")
3440
ChatFunctionCall functionCall;
3541

api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ public enum ChatMessageRole {
77
SYSTEM("system"),
88
USER("user"),
99
ASSISTANT("assistant"),
10-
FUNCTION("function");
10+
FUNCTION("function"),
11+
TOOL("tool");
1112

1213
private final String value;
1314

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import com.fasterxml.jackson.annotation.JsonProperty;
4+
import lombok.Builder;
5+
import lombok.Data;
6+
7+
/**
8+
* <p>Chat Message specialization for tool system
9+
* </p>
10+
*
11+
* see here for more info <a href="https://platform.openai.com/docs/guides/function-calling">Function Calling</a>
12+
*/
13+
14+
@Data
15+
public class ChatMessageTool extends ChatMessage {
16+
17+
@JsonProperty("tool_call_id")
18+
private String toolCallId;
19+
20+
public ChatMessageTool(String toolCallId, String role, String content, String name) {
21+
super(role,content,name);
22+
this.toolCallId = toolCallId;
23+
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import lombok.Data;
4+
import lombok.NoArgsConstructor;
5+
import lombok.NonNull;
6+
7+
@Data
8+
@NoArgsConstructor
9+
public class ChatTool<T> {
10+
11+
12+
/**
13+
* The name of the tool being called, only function supported for now.
14+
*/
15+
@NonNull
16+
private String type = "function";
17+
18+
19+
@NonNull
20+
private T function;
21+
22+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import com.fasterxml.jackson.databind.JsonNode;
4+
import lombok.AllArgsConstructor;
5+
import lombok.Data;
6+
import lombok.NoArgsConstructor;
7+
8+
@Data
9+
@AllArgsConstructor
10+
@NoArgsConstructor
11+
public class ChatToolCalls {
12+
13+
/**
14+
* The ID of the tool call
15+
*/
16+
String id;
17+
18+
/**
19+
* The type of the tool. Currently, only function is supported.
20+
*/
21+
String type;
22+
23+
/**
24+
* The function that the model called.
25+
*/
26+
ChatFunctionCall function;
27+
28+
}

service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import com.fasterxml.jackson.databind.node.ObjectNode;
77
import com.fasterxml.jackson.databind.node.TextNode;
8-
import com.theokanning.openai.completion.chat.ChatFunction;
9-
import com.theokanning.openai.completion.chat.ChatFunctionCall;
10-
import com.theokanning.openai.completion.chat.ChatMessage;
11-
import com.theokanning.openai.completion.chat.ChatMessageRole;
8+
import com.theokanning.openai.completion.chat.*;
129

1310
import java.util.*;
1411

service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ void streamChatCompletion() {
8383
assertTrue(chunks.size() > 0);
8484
assertNotNull(chunks.get(0).getChoices().get(0));
8585
}
86-
8786
@Test
8887
void createChatCompletionWithFunctions() {
8988
final List<ChatFunction> functions = Collections.singletonList(ChatFunction.builder()
@@ -300,4 +299,76 @@ void streamChatCompletionWithDynamicFunctions() {
300299
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
301300
}
302301

303-
}
302+
@Test
303+
void createChatCompletionWithToolFunctions() {
304+
305+
final List<ChatFunction> functions = Collections.singletonList(ChatFunction.builder()
306+
.name("get_weather")
307+
.description("Get the current weather in a given location")
308+
.executor(Weather.class, w -> new WeatherResponse(w.location, w.unit, 25, "sunny"))
309+
.build());
310+
final FunctionExecutor functionExecutor = new FunctionExecutor(functions);
311+
final ChatTool tool = new ChatTool();
312+
tool.setFunction(functionExecutor.getFunctions().get(0));
313+
final List<ChatMessage> messages = new ArrayList<>();
314+
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant.");
315+
final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?");
316+
messages.add(systemMessage);
317+
messages.add(userMessage);
318+
319+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
320+
.builder()
321+
.model("gpt-3.5-turbo-0613")
322+
.messages(messages)
323+
.tools(List.of(tool))
324+
.toolChoice("auto")
325+
.n(1)
326+
.maxTokens(100)
327+
.logitBias(new HashMap<>())
328+
.build();
329+
330+
ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0);
331+
assertEquals("tool_calls", choice.getFinishReason());
332+
333+
assertEquals("get_weather", choice.getMessage().getToolCalls().get(0).getFunction().getName());
334+
assertInstanceOf(ObjectNode.class, choice.getMessage().getToolCalls().get(0).getFunction().getArguments());
335+
336+
ChatMessage callResponse = functionExecutor.executeAndConvertToMessageHandlingExceptions(choice.getMessage().getToolCalls().get(0).getFunction());
337+
assertNotEquals("error", callResponse.getName());
338+
339+
// this performs an unchecked cast
340+
WeatherResponse functionExecutionResponse = functionExecutor.execute(choice.getMessage().getToolCalls().get(0).getFunction());
341+
assertInstanceOf(WeatherResponse.class, functionExecutionResponse);
342+
assertEquals(25, functionExecutionResponse.temperature);
343+
344+
JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(choice.getMessage().getToolCalls().get(0).getFunction());
345+
assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse);
346+
assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText());
347+
348+
//Construct message for tool_calls
349+
ChatMessageTool chatMessageTool = new ChatMessageTool(choice.getMessage().getToolCalls().get(0).getId(),
350+
ChatMessageRole.TOOL.value(),
351+
jsonFunctionExecutionResponse.toString(),
352+
choice.getMessage().getToolCalls().get(0).getFunction().getName());
353+
354+
messages.add(choice.getMessage());
355+
messages.add(chatMessageTool);
356+
357+
ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest
358+
.builder()
359+
.model("gpt-3.5-turbo-0613")
360+
.messages(messages)
361+
.tools(List.of(tool))
362+
.toolChoice("auto")
363+
.n(1)
364+
.maxTokens(100)
365+
.logitBias(new HashMap<>())
366+
.build();
367+
368+
ChatCompletionChoice choice2 = service.createChatCompletion(chatCompletionRequest2).getChoices().get(0);
369+
assertNotEquals("tool_calls", choice2.getFinishReason()); // could be stop or length, but should not be function_call
370+
assertNull(choice2.getMessage().getFunctionCall());
371+
assertNotNull(choice2.getMessage().getContent());
372+
}
373+
374+
}

0 commit comments

Comments
 (0)