Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENG-1726 - REST Interface #1

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions Sources/GoogleAI/Preternatural/AbstractLLM+Gemini.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,35 @@ extension PromptLiteral {
switch part {
case .text(let value):
self.init(stringLiteral: value)
case .data:
TODO.unimplemented
case .fileData:
TODO.unimplemented
case .functionCall:
TODO.unimplemented
case .functionResponse:
TODO.unimplemented
case .executableCode(_):
TODO.unimplemented
case .codeExecutionResult(_):
TODO.unimplemented
}
case .data(let mimetype, _):

self.init(stringLiteral: "[Data: \(mimetype)]")
case .fileData(let mimetype, let uri):

self.init(stringLiteral: "[File: \(mimetype) at \(uri)]")
case .functionCall(let functionCall):

self.init(stringLiteral: "Function Call: \(functionCall.name)")
case .functionResponse(let functionResponse):

self.init(stringLiteral: "Function Response: \(functionResponse.name)")
case .executableCode(let executableCode):

self.init(stringLiteral: """
```\(executableCode.language.lowercased())
\(executableCode.code)
```
""")
case .codeExecutionResult(let result):

let status = result.outcome == .ok ? "Success" : "Error"
self.init(stringLiteral: """
Execution Result (\(status)):
```
\(result.output)
```
""")
}
}
}

Expand Down
107 changes: 101 additions & 6 deletions Sources/GoogleAI/Preternatural/Gemini.Client+LLMRequestHandling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ extension Gemini.Client: LLMRequestHandling {

return try cast(_completion)
}

private func _complete(
prompt: AbstractLLM.TextPrompt,
parameters: AbstractLLM.TextCompletionParameters
Expand Down Expand Up @@ -80,7 +80,102 @@ extension Gemini.Client: LLMRequestHandling {
message: try AbstractLLM.ChatMessage(_from: firstCandidate.content),
stopReason: try AbstractLLM.ChatCompletion.StopReason(_from: firstCandidate.finishReason.unwrap())
)


return completion
}


// Function Calling
/// Runs a chat completion restricted to Gemini function-calling and returns the
/// function calls (if any) the model chose to emit.
///
/// - Parameters:
///   - messages: Full conversation; the first `.system` message becomes the system instruction.
///   - functions: Abstract function definitions converted to Gemini `FunctionDeclaration`s.
///   - model: The Gemini model to invoke.
///   - type: Phantom parameter selecting this overload; not otherwise used.
/// - Returns: Every `functionCall` part found in the first candidate, or `[]` if none.
/// - Throws: Any error from stripping message content to text or from the network request.
public func _complete(
    _ messages: [AbstractLLM.ChatMessage],
    functions: [AbstractLLM.ChatFunctionDefinition],
    model: Gemini.Model,
    as type: AbstractLLM.ChatFunctionCall.Type
) async throws -> [FunctionCall] {
    let service = GenerativeAIService(
        apiKey: configuration.apiKey,
        urlSession: .shared
    )
    
    // NOTE(review): every parameter is declared as `.string`; non-string JSON schema
    // types in `function.parameters` are flattened — confirm this is acceptable.
    let functionDeclarations: [FunctionDeclaration] = functions.map { function in
        let parameterSchemas = function.parameters.properties?.reduce(into: [String: Schema]()) { result, property in
            result[property.key] = Schema(
                type: .string,
                description: property.value.description
            )
        } ?? [:]
        
        return FunctionDeclaration(
            name: function.name.rawValue,
            description: function.context,
            parameters: parameterSchemas,
            requiredParameters: function.parameters.required
        )
    }
    
    // An absent system message yields an instruction with empty text (preserved behavior).
    let systemMessage = messages.first { $0.role == .system }
    let systemInstruction = ModelContent(
        role: "system",
        parts: [.text(try systemMessage?.content._stripToText() ?? "")]
    )
    
    // Fix: propagate stripping errors through the throwing `map` instead of the
    // original `try!`, which crashed the process on any non-text message content.
    // NOTE(review): every non-system message is sent with role "user", including
    // assistant turns — confirm multi-turn conversations are intended to be flattened.
    let userMessages = messages.filter { $0.role != .system }
    let userContent = try userMessages.map { message in
        ModelContent(
            role: "user",
            parts: [.text(try message.content._stripToText())]
        )
    }
    
    let request = GenerateContentRequest(
        model: "models/" + model.rawValue,
        contents: userContent,
        generationConfig: nil,
        safetySettings: nil,
        tools: [Tool(functionDeclarations: functionDeclarations)],
        toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)),
        systemInstruction: systemInstruction,
        isStreaming: false,
        options: RequestOptions()
    )
    
    let response = try await service.loadRequest(request: request)
    
    // Only the first candidate is inspected; missing candidates yield [].
    let functionCalls = response.candidates.first?.content.parts.compactMap { part -> FunctionCall? in
        if case .functionCall(let functionCall) = part {
            return functionCall
        }
        return nil
    } ?? []
    
    return functionCalls
}

// Code Execution
/// Runs a chat completion, optionally granting the model the code-execution tool.
///
/// - Parameters:
///   - messages: Full conversation; system messages are split out by the helper.
///   - codeExecution: When `true`, attaches Gemini's `CodeExecution` tool.
///   - model: The Gemini model to invoke.
/// - Returns: A chat completion built from the single returned candidate.
/// - Throws: If the response does not contain exactly one candidate, or if the
///   finish reason is absent.
public func _complete(
    _ messages: [AbstractLLM.ChatMessage],
    codeExecution: Bool,
    model: Gemini.Model
) async throws -> AbstractLLM.ChatCompletion {
    // Separate the system instruction from the conversational content.
    let (systemInstruction, modelContent) = try await _makeSystemInstructionAndModelContent(messages: messages)
    
    let tools: [Tool]? = codeExecution ? [Tool(codeExecution: CodeExecution())] : nil
    
    let generativeModel = GenerativeModel(
        name: model.rawValue,
        apiKey: configuration.apiKey,
        generationConfig: nil,
        tools: tools,
        systemInstruction: systemInstruction
    )
    
    let response: GenerateContentResponse = try await generativeModel.generateContent(modelContent)
    
    // Exactly one candidate is expected; anything else throws.
    let candidate: CandidateResponse = try response.candidates.toCollectionOfOne().value
    
    return AbstractLLM.ChatCompletion(
        prompt: AbstractLLM.ChatPrompt(messages: messages),
        message: try AbstractLLM.ChatMessage(_from: candidate.content),
        stopReason: try AbstractLLM.ChatCompletion.StopReason(_from: candidate.finishReason.unwrap())
    )
}
}
Expand All @@ -101,7 +196,7 @@ extension Gemini.Client {
stopSequences: parameters.stops
)
}

private func _makeSystemInstructionAndModelContent(
messages: [AbstractLLM.ChatMessage]
) async throws -> (systemInstruction: ModelContent?, content: [ModelContent]) {
Expand All @@ -117,19 +212,19 @@ extension Gemini.Client {
var content: [ModelContent] = []

for message in messages {
try _tryAssert(message.role != .system)
try _tryAssert(message.role != .system)

content.append(try await ModelContent(_from: message))
}

return (systemInstruction, content)
}

/// Converts a plain text prompt into a single user-role `ModelContent`.
///
/// - Parameter prompt: The text prompt whose prefix literal is stripped to text.
/// - Returns: A one-element array wrapping the prompt text as a "user" message.
/// - Throws: If the prompt literal cannot be stripped to plain text.
private func _modelContent(
    from prompt: AbstractLLM.TextPrompt
) throws -> [ModelContent] {
    let text: String = try prompt.prefix.promptLiteral._stripToText()
    return [ModelContent(role: "user", parts: text)]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ extension ModelContent {
case .image(let image):
let data: Data = try image.jpegData.unwrap()
let mimeType: String = _MediaAssetFileType.jpeg.mimeType

parts.append(.data(mimetype: mimeType, data))
case .base64DataURL:
TODO.unimplemented
}
case .functionCall(_):
case .functionCall(let function):
TODO.unimplemented
case .resultOfFunctionCall(_):
case .resultOfFunctionCall(let result):
TODO.unimplemented
}
}
Expand Down
12 changes: 12 additions & 0 deletions Tests/GoogleAITests/Preternatural/Config.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//
// Config.swift
// Gemini
//
// Created by Jared Davidson on 12/13/24.
//

import Foundation
import Gemini

// Shared configuration for the Preternatural test targets.
//
// Fix: read the key from the environment instead of requiring a credential to be
// pasted into (and potentially committed with) this file. Set GEMINI_API_KEY in
// the test scheme or CI environment; falls back to "" exactly as before.
let apiKey: String = ProcessInfo.processInfo.environment["GEMINI_API_KEY"] ?? ""

// Shared client used by the test suites; live tests fail fast if the key is empty.
let client = Gemini.Client(apiKey: apiKey)
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
//
// GeminiCodeExecutionTests.swift
// Gemini
//
// Created by Jared Davidson on 12/13/24.
//

import AI
import CorePersistence
import Gemini
import Testing

/// Live integration tests for Gemini's code-execution tool.
///
/// NOTE(review): these tests hit the real Gemini API and require a valid `apiKey`
/// (see Config.swift); they are non-deterministic by nature.
@Suite struct GeminiCodeExecutionTests {
    
    /// Exercises code execution through the AbstractLLM-level `_complete` overload.
    @Test
    func testAbstractLLMCodeExecution() async throws {
        let messages: [AbstractLLM.ChatMessage] = [
            .system {
                "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer."
            },
            .user {
                "What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this."
            }
        ]
        
        let completion = try await client._complete(
            messages,
            codeExecution: true,
            model: .gemini_1_5_pro_latest
        )
        
        let messageContent = try completion.message.content._stripToText()
        
        #expect(!messageContent.isEmpty, "Response should not be empty")
        #expect(messageContent.contains("```python"), "Should contain Python code block")
        #expect(messageContent.contains("```"), "Should have code block formatting")
        #expect(messageContent.contains("5117"), "Should contain correct sum of first 50 primes")
        #expect(messageContent.contains("Output") || messageContent.contains("Result"),
                "Should contain execution output")
    }
    
    /// Exercises code execution directly against `GenerativeModel`, asserting that
    /// both executable-code and execution-result parts come back.
    @Test
    func testBasicCodeExecution() async throws {
        // Fix: the message list was previously declared but unused; derive the
        // request content from it so the prompt is stated in exactly one place.
        let messages: [AbstractLLM.ChatMessage] = [
            .system {
                "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer."
            },
            .user {
                "What is the sum of the first 50 prime numbers? Write and execute Python code to calculate this."
            }
        ]
        
        let systemMessage = try #require(messages.first)
        let userMessage = try #require(messages.last)
        
        let model = GenerativeModel(
            name: "gemini-1.5-pro-latest",
            apiKey: apiKey,
            tools: [Tool(codeExecution: CodeExecution())],
            systemInstruction: ModelContent(
                role: "system",
                parts: [.text(try systemMessage.content._stripToText())]
            )
        )
        
        let userContent = [ModelContent(
            role: "user",
            parts: [.text(try userMessage.content._stripToText())]
        )]
        
        let response = try await model.generateContent(userContent)
        
        let hasExecutableCode = response.candidates.first?.content.parts.contains { part in
            if case .executableCode = part { return true }
            return false
        } ?? false
        #expect(hasExecutableCode, "Response should contain executable code")
        
        let hasExecutionResults = response.candidates.first?.content.parts.contains { part in
            if case .codeExecutionResult = part { return true }
            return false
        } ?? false
        #expect(hasExecutionResults, "Response should contain code execution results")
        
        if let executionResult = response.candidates.first?.content.parts.compactMap({ part in
            if case .codeExecutionResult(let result) = part { return result }
            return nil
        }).first {
            #expect(executionResult.outcome == .ok, "Code execution should complete successfully")
            #expect(!executionResult.output.isEmpty, "Execution output should not be empty")
            #expect(executionResult.output.contains("5117"), "Output should contain correct sum of first 50 primes (5117)")
        }
    }
    
    /// Exercises a multi-step code-generation prompt and validates the structure
    /// of both the generated code and its execution result.
    @Test
    func testComplexCodeExecution() async throws {
        let messages: [AbstractLLM.ChatMessage] = [
            .system {
                "You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer."
            },
            .user {
                """
                I need a Python function that:
                1. Generates 100 random numbers between 1 and 1000
                2. Sorts them in descending order
                3. Filters out any numbers divisible by 3
                4. Returns the sum of the remaining numbers
                
                Please write and execute this code.
                """
            }
        ]
        
        let model = GenerativeModel(
            name: "gemini-1.5-pro-latest",
            apiKey: apiKey,
            tools: [Tool(codeExecution: CodeExecution())],
            systemInstruction: ModelContent(
                role: "system",
                parts: [.text("You are a helpful coding assistant. When asked to solve problems, you should write and execute code to find the answer.")]
            )
        )
        
        // Fix: replace the force-unwrap `messages.last!` with a failable #require.
        let userMessage = try #require(messages.last)
        let userContent = [ModelContent(
            role: "user",
            parts: [.text(try userMessage.content._stripToText())]
        )]
        
        let response = try await model.generateContent(userContent)
        // Fix: removed leftover `dump(response)` debug output.
        
        // Validate code structure
        if let executableCode = response.candidates.first?.content.parts.compactMap({ part in
            if case .executableCode(let code) = part { return code }
            return nil
        }).first {
            #expect(executableCode.language.lowercased() == "python", "Should use Python")
            #expect(executableCode.code.contains("random.randint"), "Should use random number generation")
            #expect(executableCode.code.contains("sort"), "Should include sorting")
            #expect(executableCode.code.contains("filter"), "Should include filtering")
        }
        
        // Validate execution outcome
        if let executionResult = response.candidates.first?.content.parts.compactMap({ part in
            if case .codeExecutionResult(let result) = part { return result }
            return nil
        }).first {
            #expect(executionResult.outcome == .ok, "Code execution should complete successfully")
            #expect(!executionResult.output.isEmpty, "Execution output should not be empty")
            
            // Fix: parse the output number without force-unwrapping the conversion.
            let numberPattern = #/\d+/#
            if let match = executionResult.output.firstMatch(of: numberPattern) {
                let sum = try #require(Int(match.output), "Matched digits should parse as Int")
                #expect(sum > 0, "Output should contain a positive number")
            }
        }
    }
}
Loading
Loading