From 778a478c0df90ea660a5becec0b331d2c07e3b94 Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Thu, 12 Dec 2024 17:47:50 -0700 Subject: [PATCH 1/7] Tests introduction --- ...reternaturalAI_Tests+FunctionCalling.swift | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift new file mode 100644 index 0000000..bd927fb --- /dev/null +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift @@ -0,0 +1,85 @@ +// +// PreternaturalAI_Tests+FunctionCalling.swift +// Gemini +// +// Created by Jared Davidson on 12/12/24. +// + +import AI +import CorePersistence +import Gemini +import Testing + +@Suite struct GeminiFunctionCallingTests { + let llm: any LLMRequestHandling = Gemini.Client(apiKey: "AIzaSyBmz1E3wsIm93XpxSByWVurLiWqNXLZ_hQ") + + @Test + func testFunctionCalling() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a Meteorologist Expert accurately giving weather data in fahrenheit at any given city around the world" + }, + .user { + "What is the weather in San Francisco, CA?" + } + ] + + let functionCall: AbstractLLM.ChatFunctionCall = try await llm.complete( + messages, + functions: [makeGetWeatherFunction()], + model: Gemini.Model.gemini_1_5_pro_latest, + as: .functionCall + ) + + dump(functionCall) + + let result = try functionCall.decode(GetWeatherParameters.self) + + #expect(result.weather.first?.unit_fahrenheit != nil, "Weather temperature should not be nil") + print(result) + } + + private func makeGetWeatherFunction() -> AbstractLLM.ChatFunctionDefinition { + let weatherObjectSchema = JSONSchema( + type: .object, + description: "Weather in a certain location", + properties: [ + "location": JSONSchema( + type: .string, + description: "The city and state, e.g. San Francisco, CA" + ), + "unit_fahrenheit": JSONSchema( + type: .number, + description: "The unit of temperature in 'fahrenheit'" + ) + ], + required: [ + "location", + "unit_fahrenheit" + ] + ) + + let getWeatherFunction: AbstractLLM.ChatFunctionDefinition = AbstractLLM.ChatFunctionDefinition( + name: "get_weather", + context: "Get the current weather in a given location", + parameters: JSONSchema( + type: .object, + properties: [ + "weather": .array(weatherObjectSchema) + ], + required: ["weather"] + ) + ) + + return getWeatherFunction + } + + struct GetWeatherParameters: Codable, Hashable, Sendable { + let weather: [WeatherObject] + } + + struct WeatherObject: Codable, Hashable, Sendable { + let location: String + let unit_fahrenheit: Double + } +} From 85681677f378d449cf8c5ec350be76f09b525325 Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Thu, 12 Dec 2024 20:40:58 -0700 Subject: [PATCH 2/7] ChatFunctionCalls --- .../Gemini.Client+LLMRequestHandling.swift | 136 ++++++++++++- .../ModelContent+LargeLanguageModels.swift | 8 +- ...reternaturalAI_Tests+FunctionCalling.swift | 184 +++++++++++++----- 3 files changed, 269 insertions(+), 59 deletions(-) diff --git a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift index dfa5184..1aa37be 100644 --- a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift +++ b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift @@ -36,7 +36,7 @@ extension Gemini.Client: LLMRequestHandling { return try cast(_completion) } - + private func _complete( prompt: AbstractLLM.TextPrompt, parameters: AbstractLLM.TextCompletionParameters @@ -80,9 +80,79 @@ extension Gemini.Client: LLMRequestHandling { message: try AbstractLLM.ChatMessage(_from: firstCandidate.content), stopReason: try AbstractLLM.ChatCompletion.StopReason(_from: firstCandidate.finishReason.unwrap()) ) - + return completion } + + public func _complete( + _ messages: [AbstractLLM.ChatMessage], + functions: [AbstractLLM.ChatFunctionDefinition], + model: Gemini.Model, + as type: AbstractLLM.ChatFunctionCall.Type + ) async throws -> [FunctionCall] { + + //FIXME: This should ideally be AbstractLLM.ChatFunctionCall. + + let service = GenerativeAIService( + apiKey: configuration.apiKey, + urlSession: .shared + ) + + let functionDeclarations: [FunctionDeclaration] = functions.map { function in + FunctionDeclaration( + name: function.name.rawValue, + description: function.context, + parameters: [ + function.name.rawValue == "set_light_color" ? "rgb_hex" : "dummy": Schema( + type: .string, + description: function.parameters.properties?.first?.value.description ?? "Placeholder parameter" + ) + ], + requiredParameters: function.name.rawValue == "set_light_color" ? ["rgb_hex"] : nil + ) + } + + let systemMessage = messages.first { $0.role == .system } + let systemInstruction = ModelContent( + role: "system", + parts: [.text(try systemMessage?.content._stripToText() ?? "")] + ) + + let userMessages = messages.filter { $0.role != .system } + let userContent = userMessages.map { message in + ModelContent( + role: "user", + parts: [.text(try! message.content._stripToText())] + ) + } + + let request = GenerateContentRequest( + model: "models/" + model.rawValue, + contents: userContent, + generationConfig: nil, + safetySettings: nil, + tools: [Tool(functionDeclarations: functionDeclarations)], + toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)), + systemInstruction: systemInstruction, + isStreaming: false, + options: RequestOptions() + ) + + //FIXME: This should ideally be AbstractLLM.ChatFunctionCall. + + let response = try await service.loadRequest(request: request) + + dump(response) + + let functionCalls = response.candidates.first?.content.parts.compactMap { part -> FunctionCall? in + if case .functionCall(let functionCall) = part { + return functionCall + } + return nil + } ?? [] + + return functionCalls + } } extension Gemini.Client { @@ -101,7 +171,7 @@ extension Gemini.Client { stopSequences: parameters.stops ) } - + private func _makeSystemInstructionAndModelContent( messages: [AbstractLLM.ChatMessage] ) async throws -> (systemInstruction: ModelContent?, content: [ModelContent]) { @@ -117,19 +187,19 @@ extension Gemini.Client { var content: [ModelContent] = [] for message in messages { - try _tryAssert(message.role != .system) + try _tryAssert(message.role != .system) content.append(try await ModelContent(_from: message)) } return (systemInstruction, content) } - + private func _modelContent( from prompt: AbstractLLM.TextPrompt ) throws -> [ModelContent] { let promptText = try prompt.prefix.promptLiteral._stripToText() - + return [ModelContent(role: "user", parts: promptText)] } @@ -166,3 +236,57 @@ extension Gemini.Client { } } } + +extension AbstractLLM.ChatFunctionDefinition { + func toGeminiFunctionDeclaration() -> FunctionDeclaration { + func schemaToGeminiSchema(_ schema: JSONSchema) -> Schema { + switch schema.type { + case .string: + return Schema( + type: .string, + description: schema.description + ) + case .object: + var parameters: [String: Schema] = [:] + if let properties = schema.properties { + for (key, value) in properties { + parameters[key] = schemaToGeminiSchema(value) + } + } + return Schema( + type: .object, + description: schema.description, + properties: parameters, + requiredProperties: schema.required + ) + // Add other type conversions as needed + default: + return Schema(type: .string, description: "Fallback type") // Default fallback + } + } + + return FunctionDeclaration( + name: name.rawValue, + description: context, + parameters: schemaToGeminiSchema(parameters).properties, + requiredParameters: parameters.required + ) + } +} + +extension AbstractLLM.ChatRole { + // Convert AbstractLLM role to Gemini role string + func toGeminiRole() -> String { + switch self { + case .system: + return "systemInstruction" // Special case for Gemini + case .user: + return "user" + case .assistant: + return "model" // Gemini uses "model" instead of "assistant" + case .other(let value): + return value.rawValue + } + } +} + diff --git a/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift b/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift index 9325224..48da061 100644 --- a/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift +++ b/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift @@ -36,10 +36,10 @@ extension ModelContent { case .base64DataURL: TODO.unimplemented } - case .functionCall(_): - TODO.unimplemented - case .resultOfFunctionCall(_): - TODO.unimplemented + case .functionCall(let function): + print(function) + case .resultOfFunctionCall(let result): + print(result) } } diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift index bd927fb..ec52b25 100644 --- a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift @@ -10,76 +10,162 @@ import CorePersistence import Gemini import Testing -@Suite struct GeminiFunctionCallingTests { - let llm: any LLMRequestHandling = Gemini.Client(apiKey: "AIzaSyBmz1E3wsIm93XpxSByWVurLiWqNXLZ_hQ") +@Suite struct GeminiLightingSystemTests { + let apiKey = "" + let llm = Gemini.Client(apiKey: "") @Test - func testFunctionCalling() async throws { + func testLightingSystem() async throws { let messages: [AbstractLLM.ChatMessage] = [ .system { - "You are a Meteorologist Expert accurately giving weather data in fahrenheit at any given city around the world" + "You are a helpful lighting system bot. You can turn lights on and off, and you can set the color. Do not perform any other tasks." }, .user { - "What is the weather in San Francisco, CA?" + "Turn on the lights and set them to red." } ] - let functionCall: AbstractLLM.ChatFunctionCall = try await llm.complete( + // Define functions with correct structure + let functions = [ + AbstractLLM.ChatFunctionDefinition( + name: "enable_lights", + context: "Turn on the lighting system.", + parameters: JSONSchema( + type: .object, + properties: [ + "dummy": JSONSchema( + type: .string, + description: "Placeholder parameter" + ) + ] + ) + ), + AbstractLLM.ChatFunctionDefinition( + name: "set_light_color", + context: "Set the light color. Lights must be enabled for this to work.", + parameters: JSONSchema( + type: .object, + properties: [ + "rgb_hex": JSONSchema( + type: .string, + description: "The light color as a 6-digit hex string, e.g. ff0000 for red." + ) + ], + required: ["rgb_hex"] + ) + ), + AbstractLLM.ChatFunctionDefinition( + name: "stop_lights", + context: "Turn off the lighting system.", + parameters: JSONSchema( + type: .object, + properties: [ + "dummy": JSONSchema( + type: .string, + description: "Placeholder parameter" + ) + ] + ) + ) + ] + + //FIXME: This should ideally be returning or working with something from AbstractLLM. + + let functionCalls: [FunctionCall] = try await llm._complete( messages, - functions: [makeGetWeatherFunction()], + functions: functions, model: Gemini.Model.gemini_1_5_pro_latest, - as: .functionCall + as: AbstractLLM.ChatFunctionCall.self ) + + guard let lastFunctionCall = functionCalls.last?.args else { return } - dump(functionCall) - - let result = try functionCall.decode(GetWeatherParameters.self) - - #expect(result.weather.first?.unit_fahrenheit != nil, "Weather temperature should not be nil") - print(result) + let result = try lastFunctionCall.toJSONData().decode(LightingCommandParameters.self) + + #expect(result.rgb_hex != nil, "Light color parameter should not be nil") } - private func makeGetWeatherFunction() -> AbstractLLM.ChatFunctionDefinition { - let weatherObjectSchema = JSONSchema( - type: .object, - description: "Weather in a certain location", - properties: [ - "location": JSONSchema( - type: .string, - description: "The city and state, e.g. San Francisco, CA" - ), - "unit_fahrenheit": JSONSchema( - type: .number, - description: "The unit of temperature in 'fahrenheit'" - ) - ], - required: [ - "location", - "unit_fahrenheit" - ] - ) - - let getWeatherFunction: AbstractLLM.ChatFunctionDefinition = AbstractLLM.ChatFunctionDefinition( - name: "get_weather", - context: "Get the current weather in a given location", - parameters: JSONSchema( - type: .object, - properties: [ - "weather": .array(weatherObjectSchema) + @Test + func testDirectFunctionCalling() async throws { + // Define function declarations with proper empty object structure + let functionDeclarations = [ + FunctionDeclaration( + name: "enable_lights", + description: "Turn on the lighting system.", + parameters: [ + "dummy": Schema( + type: .string, // Changed from .object to .string + description: "Placeholder parameter" + ) + ] + ), + FunctionDeclaration( + name: "set_light_color", + description: "Set the light color. Lights must be enabled for this to work.", + parameters: [ + "rgb_hex": Schema( + type: .string, + description: "The light color as a 6-digit hex string, e.g. ff0000 for red." + ) ], - required: ["weather"] + requiredParameters: ["rgb_hex"] + ), + FunctionDeclaration( + name: "stop_lights", + description: "Turn off the lighting system.", + parameters: [ + "dummy": Schema( + type: .string, + description: "Placeholder parameter" + ) + ] ) + ] + + let tools = Tool(functionDeclarations: functionDeclarations) + let toolConfig = ToolConfig( + functionCallingConfig: FunctionCallingConfig(mode: .auto) + ) + + let systemInstruction = ModelContent( + role: "system", + parts: [.text("You are a helpful lighting system bot. You can turn lights on and off, and you can set the color. Do not perform any other tasks.")] + ) + + let userContent = [ModelContent( + role: "user", + parts: [.text("Turn on the lights and set them to red.")] + )] + + let model = GenerativeModel( + name: "gemini-1.5-pro-latest", + apiKey: apiKey, + tools: [tools], + toolConfig: toolConfig, + systemInstruction: systemInstruction ) - return getWeatherFunction + let response = try await model.generateContent(userContent) + + dump(response) + + #expect(!response.functionCalls.isEmpty, "Should have function calls") + + for functionCall in response.functionCalls { + #expect( + ["enable_lights", "set_light_color", "stop_lights"].contains(functionCall.name), + "Function call should be one of the defined functions" + ) + + if functionCall.name == "set_light_color", + let args = functionCall.args["rgb_hex"] as? String { + #expect(args == "ff0000", "Light color should be red (ff0000)") + } + } } - struct GetWeatherParameters: Codable, Hashable, Sendable { - let weather: [WeatherObject] - } - struct WeatherObject: Codable, Hashable, Sendable { - let location: String - let unit_fahrenheit: Double + struct LightingCommandParameters: Codable, Hashable, Sendable { + let rgb_hex: String? } } From 19bd6e0aecff10dbc00d75f9b052ea418a48bbd4 Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Thu, 12 Dec 2024 20:51:34 -0700 Subject: [PATCH 3/7] removed hardcoding --- .../Gemini.Client+LLMRequestHandling.swift | 25 ++++++++----------- ...reternaturalAI_Tests+FunctionCalling.swift | 4 +-- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift index 1aa37be..7bf6b95 100644 --- a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift +++ b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift @@ -90,25 +90,24 @@ extension Gemini.Client: LLMRequestHandling { model: Gemini.Model, as type: AbstractLLM.ChatFunctionCall.Type ) async throws -> [FunctionCall] { - - //FIXME: This should ideally be AbstractLLM.ChatFunctionCall. - let service = GenerativeAIService( apiKey: configuration.apiKey, urlSession: .shared ) let functionDeclarations: [FunctionDeclaration] = functions.map { function in - FunctionDeclaration( + let parameterSchemas = function.parameters.properties?.reduce(into: [String: Schema]()) { result, property in + result[property.key] = Schema( + type: .string, + description: property.value.description + ) + } ?? [:] + + return FunctionDeclaration( name: function.name.rawValue, description: function.context, - parameters: [ - function.name.rawValue == "set_light_color" ? "rgb_hex" : "dummy": Schema( - type: .string, - description: function.parameters.properties?.first?.value.description ?? "Placeholder parameter" - ) - ], - requiredParameters: function.name.rawValue == "set_light_color" ? ["rgb_hex"] : nil + parameters: parameterSchemas, + requiredParameters: function.parameters.required ) } @@ -138,12 +137,8 @@ extension Gemini.Client: LLMRequestHandling { options: RequestOptions() ) - //FIXME: This should ideally be AbstractLLM.ChatFunctionCall. - let response = try await service.loadRequest(request: request) - dump(response) - let functionCalls = response.candidates.first?.content.parts.compactMap { part -> FunctionCall? in if case .functionCall(let functionCall) = part { return functionCall diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift index ec52b25..55867ad 100644 --- a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift @@ -87,14 +87,14 @@ import Testing @Test func testDirectFunctionCalling() async throws { - // Define function declarations with proper empty object structure + let functionDeclarations = [ FunctionDeclaration( name: "enable_lights", description: "Turn on the lighting system.", parameters: [ "dummy": Schema( - type: .string, // Changed from .object to .string + type: .string, description: "Placeholder parameter" ) ] From d4637692eb4d573bb7ae688ea72cd1282f6e45ab Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Fri, 13 Dec 2024 09:27:35 -0700 Subject: [PATCH 4/7] Removed unused code --- .../Gemini.Client+LLMRequestHandling.swift | 54 ------------------- 1 file changed, 54 deletions(-) diff --git a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift index 7bf6b95..052764e 100644 --- a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift +++ b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift @@ -231,57 +231,3 @@ extension Gemini.Client { } } } - -extension AbstractLLM.ChatFunctionDefinition { - func toGeminiFunctionDeclaration() -> FunctionDeclaration { - func schemaToGeminiSchema(_ schema: JSONSchema) -> Schema { - switch schema.type { - case .string: - return Schema( - type: .string, - description: schema.description - ) - case .object: - var parameters: [String: Schema] = [:] - if let properties = schema.properties { - for (key, value) in properties { - parameters[key] = schemaToGeminiSchema(value) - } - } - return Schema( - type: .object, - description: schema.description, - properties: parameters, - requiredProperties: schema.required - ) - // Add other type conversions as needed - default: - return Schema(type: .string, description: "Fallback type") // Default fallback - } - } - - return FunctionDeclaration( - name: name.rawValue, - description: context, - parameters: schemaToGeminiSchema(parameters).properties, - requiredParameters: parameters.required - ) - } -} - -extension AbstractLLM.ChatRole { - // Convert AbstractLLM role to Gemini role string - func toGeminiRole() -> String { - switch self { - case .system: - return "systemInstruction" // Special case for Gemini - case .user: - return "user" - case .assistant: - return "model" // Gemini uses "model" instead of "assistant" - case .other(let value): - return value.rawValue - } - } -} - From 2444b132672e58055a635e0dacd6d913ad084a4d Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Fri, 13 Dec 2024 14:48:11 -0700 Subject: [PATCH 5/7] CodeExecution --- .../Preternatural/AbstractLLM+Gemini.swift | 42 +++-- .../Gemini.Client+LLMRequestHandling.swift | 30 ++++ .../GoogleAITests/Preternatural/Config.swift | 12 ++ .../PreternaturalAI_Tests+CodeExecution.swift | 157 ++++++++++++++++++ ...reternaturalAI_Tests+FunctionCalling.swift | 4 +- .../Preternatural/PreternaturalAI_Tests.swift | 2 - 6 files changed, 229 insertions(+), 18 deletions(-) create mode 100644 Tests/GoogleAITests/Preternatural/Config.swift create mode 100644 Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift diff --git a/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift b/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift index f59bce3..5f4288b 100644 --- a/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift +++ b/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift @@ -25,19 +25,35 @@ extension PromptLiteral { switch part { case .text(let value): self.init(stringLiteral: value) - case .data: - TODO.unimplemented - case .fileData: - TODO.unimplemented - case .functionCall: - TODO.unimplemented - case .functionResponse: - TODO.unimplemented - case .executableCode(_): - TODO.unimplemented - case .codeExecutionResult(_): - TODO.unimplemented - } + case .data(let mimetype, _): + + self.init(stringLiteral: "[Data: \(mimetype)]") + case .fileData(let mimetype, let uri): + + self.init(stringLiteral: "[File: \(mimetype) at \(uri)]") + case .functionCall(let functionCall): + + self.init(stringLiteral: "Function Call: \(functionCall.name)") + case .functionResponse(let functionResponse): + + self.init(stringLiteral: "Function Response: \(functionResponse.name)") + case .executableCode(let executableCode): + + self.init(stringLiteral: """ + ```\(executableCode.language.lowercased()) + \(executableCode.code) + ``` + """) + case .codeExecutionResult(let result): + + let status = result.outcome == .ok ? "Success" : "Error" + self.init(stringLiteral: """ + Execution Result (\(status)): + ``` + \(result.output) + ``` + """) + } } } diff --git a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift index 052764e..eaf8f7f 100644 --- a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift +++ b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift @@ -84,6 +84,8 @@ extension Gemini.Client: LLMRequestHandling { return completion } + + // Function Calling public func _complete( _ messages: [AbstractLLM.ChatMessage], functions: [AbstractLLM.ChatFunctionDefinition], @@ -148,6 +150,34 @@ extension Gemini.Client: LLMRequestHandling { return functionCalls } + + // Code Execution + public func _complete( + _ messages: [AbstractLLM.ChatMessage], + codeExecution: Bool, + model: Gemini.Model + ) async throws -> AbstractLLM.ChatCompletion { + let (systemInstruction, modelContent) = try await _makeSystemInstructionAndModelContent(messages: messages) + + let generativeModel = GenerativeModel( + name: model.rawValue, + apiKey: configuration.apiKey, + generationConfig: nil, + tools: codeExecution ? [Tool(codeExecution: CodeExecution())] : nil, + systemInstruction: systemInstruction + ) + + let response: GenerateContentResponse = try await generativeModel.generateContent(modelContent) + let firstCandidate: CandidateResponse = try response.candidates.toCollectionOfOne().value + + let completion = AbstractLLM.ChatCompletion( + prompt: AbstractLLM.ChatPrompt(messages: messages), + message: try AbstractLLM.ChatMessage(_from: firstCandidate.content), + stopReason: try AbstractLLM.ChatCompletion.StopReason(_from: firstCandidate.finishReason.unwrap()) + ) + + return completion + } } extension Gemini.Client { diff --git a/Tests/GoogleAITests/Preternatural/Config.swift b/Tests/GoogleAITests/Preternatural/Config.swift new file mode 100644 index 0000000..85ecbe1 --- /dev/null +++ b/Tests/GoogleAITests/Preternatural/Config.swift @@ -0,0 +1,12 @@ +// +// Config.swift +// Gemini +// +// Created by Jared Davidson on 12/13/24. +// + +import Gemini + +let apiKey = "" + +let client = Gemini.Client(apiKey: apiKey) diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift new file mode 100644 index 0000000..c68be76 --- /dev/null +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift @@ -0,0 +1,157 @@ +// +// Untitled.swift +// Gemini +// +// Created by Jared Davidson on 12/13/24. +// + +import AI +import CorePersistence +import Gemini +import Testing + +@Suite struct GeminiCodeExecutionTests { + + @Test + func testAbstractLLMCodeExecution() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer." + }, + .user { + "What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this." + } + ] + + let completion = try await client._complete( + messages, + codeExecution: true, + model: .gemini_1_5_pro_latest + ) + + let messageContent = try completion.message.content._stripToText() + + #expect(!messageContent.isEmpty, "Response should not be empty") + #expect(messageContent.contains("```python"), "Should contain Python code block") + #expect(messageContent.contains("```"), "Should have code block formatting") + #expect(messageContent.contains("5117"), "Should contain correct sum of first 50 primes") + #expect(messageContent.contains("Output") || messageContent.contains("Result"), + "Should contain execution output") + } + + @Test + func testBasicCodeExecution() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer." + }, + .user { + "What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this." + } + ] + + let model = GenerativeModel( + name: "gemini-1.5-pro-latest", + apiKey: apiKey, + tools: [Tool(codeExecution: CodeExecution())], + systemInstruction: ModelContent( + role: "system", + parts: [.text("You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer.")] + ) + ) + + let userContent = [ModelContent( + role: "user", + parts: [.text("What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this.")] + )] + + let response = try await model.generateContent(userContent) + + let hasExecutableCode = response.candidates.first?.content.parts.contains { part in + if case .executableCode = part { return true } + return false + } ?? false + #expect(hasExecutableCode, "Response should contain executable code") + + let hasExecutionResults = response.candidates.first?.content.parts.contains { part in + if case .codeExecutionResult = part { return true } + return false + } ?? false + #expect(hasExecutionResults, "Response should contain code execution results") + + if let executionResult = response.candidates.first?.content.parts.compactMap({ part in + if case .codeExecutionResult(let result) = part { return result } + return nil + }).first { + #expect(executionResult.outcome == .ok, "Code execution should complete successfully") + #expect(!executionResult.output.isEmpty, "Execution output should not be empty") + #expect(executionResult.output.contains("5117"), "Output should contain correct sum of first 50 primes (5117)") + } + } + + @Test + func testComplexCodeExecution() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer." + }, + .user { + """ + I need a Python function that: + 1. Generates 100 random numbers between 1 and 1000 + 2. Sorts them in descending order + 3. Filters out any numbers divisible by 3 + 4. Returns the sum of the remaining numbers + + Please write and execute this code. + """ + } + ] + + let model = GenerativeModel( + name: "gemini-1.5-pro-latest", + apiKey: apiKey, + tools: [Tool(codeExecution: CodeExecution())], + systemInstruction: ModelContent( + role: "system", + parts: [.text("You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer.")] + ) + ) + + let userContent = try [ModelContent( + role: "user", + parts: [.text(messages.last!.content._stripToText())] + )] + + let response = try await model.generateContent(userContent) + + dump(response) + + // Validate code structure + if let executableCode = response.candidates.first?.content.parts.compactMap({ part in + if case .executableCode(let code) = part { return code } + return nil + }).first { + #expect(executableCode.language.lowercased() == "python", "Should use Python") + #expect(executableCode.code.contains("random.randint"), "Should use random number generation") + #expect(executableCode.code.contains("sort"), "Should include sorting") + #expect(executableCode.code.contains("filter"), "Should include filtering") + } + + // Validate execution outcome + if let executionResult = response.candidates.first?.content.parts.compactMap({ part in + if case .codeExecutionResult(let result) = part { return result } + return nil + }).first { + #expect(executionResult.outcome == .ok, "Code execution should complete successfully") + #expect(!executionResult.output.isEmpty, "Execution output should not be empty") + + // Parse the output number to verify it's a valid sum + let numberPattern = #/\d+/# + if let match = executionResult.output.firstMatch(of: numberPattern) { + let sum = Int(match.output) + #expect(sum != nil && sum! > 0, "Output should contain a positive number") + } + } + } +} diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift index 55867ad..4503f8d 100644 --- a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift @@ -11,8 +11,6 @@ import Gemini import Testing @Suite struct GeminiLightingSystemTests { - let apiKey = "" - let llm = Gemini.Client(apiKey: "") @Test func testLightingSystem() async throws { @@ -71,7 +69,7 @@ import Testing //FIXME: This should ideally be returning or working with something from AbstractLLM. - let functionCalls: [FunctionCall] = try await llm._complete( + let functionCalls: [FunctionCall] = try await client._complete( messages, functions: functions, model: Gemini.Model.gemini_1_5_pro_latest, diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift index 9754cc2..4e095fb 100644 --- a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift @@ -18,8 +18,6 @@ import SwiftUIX import XCTest final class PreternaturalAI_Tests: XCTestCase { - let client: any LLMRequestHandling = Gemini.Client(apiKey: "API_KEY") - func testAvailableModels() { let models = client._availableModels From f8dc9fec5b180be019eebbcb9be79e38579c7b97 Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Fri, 13 Dec 2024 15:32:10 -0700 Subject: [PATCH 6/7] cleanup --- .../Preternatural/PreternaturalAI_Tests+FunctionCalling.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift index 4503f8d..6b1132c 100644 --- a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift @@ -23,7 +23,6 @@ import Testing } ] - // Define functions with correct structure let functions = [ AbstractLLM.ChatFunctionDefinition( name: "enable_lights", From 8ef53a14fdd696d8c0906465d9d127af5f632c1e Mon Sep 17 00:00:00 2001 From: "Jared Davidson (Archetapp)" Date: Fri, 20 Dec 2024 13:46:21 -0700 Subject: [PATCH 7/7] unimplemented. --- .../Preternatural/ModelContent+LargeLanguageModels.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift b/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift index 48da061..c6e2411 100644 --- a/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift +++ b/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift @@ -31,15 +31,15 @@ extension ModelContent { case .image(let image): let data: Data = try image.jpegData.unwrap() let mimeType: String = _MediaAssetFileType.jpeg.mimeType - + parts.append(.data(mimetype: mimeType, data)) case .base64DataURL: TODO.unimplemented } case .functionCall(let function): - print(function) + TODO.unimplemented case .resultOfFunctionCall(let result): - print(result) + TODO.unimplemented } }