diff --git a/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift b/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift index f59bce3..5f4288b 100644 --- a/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift +++ b/Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift @@ -25,19 +25,35 @@ extension PromptLiteral { switch part { case .text(let value): self.init(stringLiteral: value) - case .data: - TODO.unimplemented - case .fileData: - TODO.unimplemented - case .functionCall: - TODO.unimplemented - case .functionResponse: - TODO.unimplemented - case .executableCode(_): - TODO.unimplemented - case .codeExecutionResult(_): - TODO.unimplemented - } + case .data(let mimetype, _): + + self.init(stringLiteral: "[Data: \(mimetype)]") + case .fileData(let mimetype, let uri): + + self.init(stringLiteral: "[File: \(mimetype) at \(uri)]") + case .functionCall(let functionCall): + + self.init(stringLiteral: "Function Call: \(functionCall.name)") + case .functionResponse(let functionResponse): + + self.init(stringLiteral: "Function Response: \(functionResponse.name)") + case .executableCode(let executableCode): + + self.init(stringLiteral: """ + ```\(executableCode.language.lowercased()) + \(executableCode.code) + ``` + """) + case .codeExecutionResult(let result): + + let status = result.outcome == .ok ? 
"Success" : "Error" + self.init(stringLiteral: """ + Execution Result (\(status)): + ``` + \(result.output) + ``` + """) + } } } diff --git a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift index dfa5184..eaf8f7f 100644 --- a/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift +++ b/Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift @@ -36,7 +36,7 @@ extension Gemini.Client: LLMRequestHandling { return try cast(_completion) } - + private func _complete( prompt: AbstractLLM.TextPrompt, parameters: AbstractLLM.TextCompletionParameters @@ -80,7 +80,102 @@ extension Gemini.Client: LLMRequestHandling { message: try AbstractLLM.ChatMessage(_from: firstCandidate.content), stopReason: try AbstractLLM.ChatCompletion.StopReason(_from: firstCandidate.finishReason.unwrap()) ) - + + return completion + } + + + // Function Calling + public func _complete( + _ messages: [AbstractLLM.ChatMessage], + functions: [AbstractLLM.ChatFunctionDefinition], + model: Gemini.Model, + as type: AbstractLLM.ChatFunctionCall.Type + ) async throws -> [FunctionCall] { + let service = GenerativeAIService( + apiKey: configuration.apiKey, + urlSession: .shared + ) + + let functionDeclarations: [FunctionDeclaration] = functions.map { function in + let parameterSchemas = function.parameters.properties?.reduce(into: [String: Schema]()) { result, property in + result[property.key] = Schema( + type: .string, + description: property.value.description + ) + } ?? [:] + + return FunctionDeclaration( + name: function.name.rawValue, + description: function.context, + parameters: parameterSchemas, + requiredParameters: function.parameters.required + ) + } + + let systemMessage = messages.first { $0.role == .system } + let systemInstruction = ModelContent( + role: "system", + parts: [.text(try systemMessage?.content._stripToText() ?? 
"")] + ) + + let userMessages = messages.filter { $0.role != .system } + let userContent = userMessages.map { message in + ModelContent( + role: "user", + parts: [.text(try! message.content._stripToText())] + ) + } + + let request = GenerateContentRequest( + model: "models/" + model.rawValue, + contents: userContent, + generationConfig: nil, + safetySettings: nil, + tools: [Tool(functionDeclarations: functionDeclarations)], + toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)), + systemInstruction: systemInstruction, + isStreaming: false, + options: RequestOptions() + ) + + let response = try await service.loadRequest(request: request) + + let functionCalls = response.candidates.first?.content.parts.compactMap { part -> FunctionCall? in + if case .functionCall(let functionCall) = part { + return functionCall + } + return nil + } ?? [] + + return functionCalls + } + + // Code Execution + public func _complete( + _ messages: [AbstractLLM.ChatMessage], + codeExecution: Bool, + model: Gemini.Model + ) async throws -> AbstractLLM.ChatCompletion { + let (systemInstruction, modelContent) = try await _makeSystemInstructionAndModelContent(messages: messages) + + let generativeModel = GenerativeModel( + name: model.rawValue, + apiKey: configuration.apiKey, + generationConfig: nil, + tools: codeExecution ? 
[Tool(codeExecution: CodeExecution())] : nil, + systemInstruction: systemInstruction + ) + + let response: GenerateContentResponse = try await generativeModel.generateContent(modelContent) + let firstCandidate: CandidateResponse = try response.candidates.toCollectionOfOne().value + + let completion = AbstractLLM.ChatCompletion( + prompt: AbstractLLM.ChatPrompt(messages: messages), + message: try AbstractLLM.ChatMessage(_from: firstCandidate.content), + stopReason: try AbstractLLM.ChatCompletion.StopReason(_from: firstCandidate.finishReason.unwrap()) + ) + return completion } } @@ -101,7 +196,7 @@ extension Gemini.Client { stopSequences: parameters.stops ) } - + private func _makeSystemInstructionAndModelContent( messages: [AbstractLLM.ChatMessage] ) async throws -> (systemInstruction: ModelContent?, content: [ModelContent]) { @@ -117,19 +212,19 @@ extension Gemini.Client { var content: [ModelContent] = [] for message in messages { - try _tryAssert(message.role != .system) + try _tryAssert(message.role != .system) content.append(try await ModelContent(_from: message)) } return (systemInstruction, content) } - + private func _modelContent( from prompt: AbstractLLM.TextPrompt ) throws -> [ModelContent] { let promptText = try prompt.prefix.promptLiteral._stripToText() - + return [ModelContent(role: "user", parts: promptText)] } diff --git a/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift b/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift index 9325224..c6e2411 100644 --- a/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift +++ b/Sources/GoogleAI/Preternatural/ModelContent+LargeLanguageModels.swift @@ -31,14 +31,14 @@ extension ModelContent { case .image(let image): let data: Data = try image.jpegData.unwrap() let mimeType: String = _MediaAssetFileType.jpeg.mimeType - + parts.append(.data(mimetype: mimeType, data)) case .base64DataURL: TODO.unimplemented } - case .functionCall(_): + case .functionCall(let 
function): TODO.unimplemented - case .resultOfFunctionCall(_): + case .resultOfFunctionCall(let result): TODO.unimplemented } } diff --git a/Tests/GoogleAITests/Preternatural/Config.swift b/Tests/GoogleAITests/Preternatural/Config.swift new file mode 100644 index 0000000..85ecbe1 --- /dev/null +++ b/Tests/GoogleAITests/Preternatural/Config.swift @@ -0,0 +1,12 @@ +// +// Config.swift +// Gemini +// +// Created by Jared Davidson on 12/13/24. +// + +import Gemini + +let apiKey = "" + +let client = Gemini.Client(apiKey: apiKey) diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift new file mode 100644 index 0000000..c68be76 --- /dev/null +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+CodeExecution.swift @@ -0,0 +1,157 @@ +// +// PreternaturalAI_Tests+CodeExecution.swift +// Gemini +// +// Created by Jared Davidson on 12/13/24. +// + +import AI +import CorePersistence +import Gemini +import Testing + +@Suite struct GeminiCodeExecutionTests { + + @Test + func testAbstractLLMCodeExecution() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer." + }, + .user { + "What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this."
+ } + ] + + let completion = try await client._complete( + messages, + codeExecution: true, + model: .gemini_1_5_pro_latest + ) + + let messageContent = try completion.message.content._stripToText() + + #expect(!messageContent.isEmpty, "Response should not be empty") + #expect(messageContent.contains("```python"), "Should contain Python code block") + #expect(messageContent.contains("```"), "Should have code block formatting") + #expect(messageContent.contains("5117"), "Should contain correct sum of first 50 primes") + #expect(messageContent.contains("Output") || messageContent.contains("Result"), + "Should contain execution output") + } + + @Test + func testBasicCodeExecution() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer." + }, + .user { + "What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this." + } + ] + + let model = GenerativeModel( + name: "gemini-1.5-pro-latest", + apiKey: apiKey, + tools: [Tool(codeExecution: CodeExecution())], + systemInstruction: ModelContent( + role: "system", + parts: [.text("You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer.")] + ) + ) + + let userContent = [ModelContent( + role: "user", + parts: [.text("What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this.")] + )] + + let response = try await model.generateContent(userContent) + + let hasExecutableCode = response.candidates.first?.content.parts.contains { part in + if case .executableCode = part { return true } + return false + } ?? false + #expect(hasExecutableCode, "Response should contain executable code") + + let hasExecutionResults = response.candidates.first?.content.parts.contains { part in + if case .codeExecutionResult = part { return true } + return false + } ?? 
false + #expect(hasExecutionResults, "Response should contain code execution results") + + if let executionResult = response.candidates.first?.content.parts.compactMap({ part in + if case .codeExecutionResult(let result) = part { return result } + return nil + }).first { + #expect(executionResult.outcome == .ok, "Code execution should complete successfully") + #expect(!executionResult.output.isEmpty, "Execution output should not be empty") + #expect(executionResult.output.contains("5117"), "Output should contain correct sum of first 50 primes (5117)") + } + } + + @Test + func testComplexCodeExecution() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer." + }, + .user { + """ + I need a Python function that: + 1. Generates 100 random numbers between 1 and 1000 + 2. Sorts them in descending order + 3. Filters out any numbers divisible by 3 + 4. Returns the sum of the remaining numbers + + Please write and execute this code. + """ + } + ] + + let model = GenerativeModel( + name: "gemini-1.5-pro-latest", + apiKey: apiKey, + tools: [Tool(codeExecution: CodeExecution())], + systemInstruction: ModelContent( + role: "system", + parts: [.text("You are a helpful coding assistant. 
When asked to solve problems, you should write and execute code to find the answer.")] + ) + ) + + let userContent = try [ModelContent( + role: "user", + parts: [.text(messages.last!.content._stripToText())] + )] + + let response = try await model.generateContent(userContent) + + dump(response) + + // Validate code structure + if let executableCode = response.candidates.first?.content.parts.compactMap({ part in + if case .executableCode(let code) = part { return code } + return nil + }).first { + #expect(executableCode.language.lowercased() == "python", "Should use Python") + #expect(executableCode.code.contains("random.randint"), "Should use random number generation") + #expect(executableCode.code.contains("sort"), "Should include sorting") + #expect(executableCode.code.contains("filter"), "Should include filtering") + } + + // Validate execution outcome + if let executionResult = response.candidates.first?.content.parts.compactMap({ part in + if case .codeExecutionResult(let result) = part { return result } + return nil + }).first { + #expect(executionResult.outcome == .ok, "Code execution should complete successfully") + #expect(!executionResult.output.isEmpty, "Execution output should not be empty") + + // Parse the output number to verify it's a valid sum + let numberPattern = #/\d+/# + if let match = executionResult.output.firstMatch(of: numberPattern) { + let sum = Int(match.output) + #expect(sum != nil && sum! > 0, "Output should contain a positive number") + } + } + } +} diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift new file mode 100644 index 0000000..6b1132c --- /dev/null +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests+FunctionCalling.swift @@ -0,0 +1,168 @@ +// +// PreternaturalAI_Tests+FunctionCalling.swift +// Gemini +// +// Created by Jared Davidson on 12/12/24. 
+// + +import AI +import CorePersistence +import Gemini +import Testing + +@Suite struct GeminiLightingSystemTests { + + @Test + func testLightingSystem() async throws { + let messages: [AbstractLLM.ChatMessage] = [ + .system { + "You are a helpful lighting system bot. You can turn lights on and off, and you can set the color. Do not perform any other tasks." + }, + .user { + "Turn on the lights and set them to red." + } + ] + + let functions = [ + AbstractLLM.ChatFunctionDefinition( + name: "enable_lights", + context: "Turn on the lighting system.", + parameters: JSONSchema( + type: .object, + properties: [ + "dummy": JSONSchema( + type: .string, + description: "Placeholder parameter" + ) + ] + ) + ), + AbstractLLM.ChatFunctionDefinition( + name: "set_light_color", + context: "Set the light color. Lights must be enabled for this to work.", + parameters: JSONSchema( + type: .object, + properties: [ + "rgb_hex": JSONSchema( + type: .string, + description: "The light color as a 6-digit hex string, e.g. ff0000 for red." + ) + ], + required: ["rgb_hex"] + ) + ), + AbstractLLM.ChatFunctionDefinition( + name: "stop_lights", + context: "Turn off the lighting system.", + parameters: JSONSchema( + type: .object, + properties: [ + "dummy": JSONSchema( + type: .string, + description: "Placeholder parameter" + ) + ] + ) + ) + ] + + //FIXME: This should ideally be returning or working with something from AbstractLLM. 
+ + let functionCalls: [FunctionCall] = try await client._complete( + messages, + functions: functions, + model: Gemini.Model.gemini_1_5_pro_latest, + as: AbstractLLM.ChatFunctionCall.self + ) + + guard let lastFunctionCall = functionCalls.last?.args else { return } + + let result = try lastFunctionCall.toJSONData().decode(LightingCommandParameters.self) + + #expect(result.rgb_hex != nil, "Light color parameter should not be nil") + } + + @Test + func testDirectFunctionCalling() async throws { + + let functionDeclarations = [ + FunctionDeclaration( + name: "enable_lights", + description: "Turn on the lighting system.", + parameters: [ + "dummy": Schema( + type: .string, + description: "Placeholder parameter" + ) + ] + ), + FunctionDeclaration( + name: "set_light_color", + description: "Set the light color. Lights must be enabled for this to work.", + parameters: [ + "rgb_hex": Schema( + type: .string, + description: "The light color as a 6-digit hex string, e.g. ff0000 for red." + ) + ], + requiredParameters: ["rgb_hex"] + ), + FunctionDeclaration( + name: "stop_lights", + description: "Turn off the lighting system.", + parameters: [ + "dummy": Schema( + type: .string, + description: "Placeholder parameter" + ) + ] + ) + ] + + let tools = Tool(functionDeclarations: functionDeclarations) + let toolConfig = ToolConfig( + functionCallingConfig: FunctionCallingConfig(mode: .auto) + ) + + let systemInstruction = ModelContent( + role: "system", + parts: [.text("You are a helpful lighting system bot. You can turn lights on and off, and you can set the color. 
Do not perform any other tasks.")] + ) + + let userContent = [ModelContent( + role: "user", + parts: [.text("Turn on the lights and set them to red.")] + )] + + let model = GenerativeModel( + name: "gemini-1.5-pro-latest", + apiKey: apiKey, + tools: [tools], + toolConfig: toolConfig, + systemInstruction: systemInstruction + ) + + let response = try await model.generateContent(userContent) + + dump(response) + + #expect(!response.functionCalls.isEmpty, "Should have function calls") + + for functionCall in response.functionCalls { + #expect( + ["enable_lights", "set_light_color", "stop_lights"].contains(functionCall.name), + "Function call should be one of the defined functions" + ) + + if functionCall.name == "set_light_color", + let args = functionCall.args["rgb_hex"] as? String { + #expect(args == "ff0000", "Light color should be red (ff0000)") + } + } + } + + + struct LightingCommandParameters: Codable, Hashable, Sendable { + let rgb_hex: String? + } +} diff --git a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift index 9754cc2..4e095fb 100644 --- a/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift +++ b/Tests/GoogleAITests/Preternatural/PreternaturalAI_Tests.swift @@ -18,8 +18,6 @@ import SwiftUIX import XCTest final class PreternaturalAI_Tests: XCTestCase { - let client: any LLMRequestHandling = Gemini.Client(apiKey: "API_KEY") - func testAvailableModels() { let models = client._availableModels