Skip to content

Commit

Permalink
Predicted outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesrochabrun committed Jan 3, 2025
1 parent 77418b0 commit 8cfa454
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
7BBE7EAB2B02E8FC0096A693 /* ChatMessageDisplayModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBE7EAA2B02E8FC0096A693 /* ChatMessageDisplayModel.swift */; };
7BBE7EDE2B03718E0096A693 /* ChatFunctionCallProvider.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBE7EDD2B03718E0096A693 /* ChatFunctionCallProvider.swift */; };
7BBE7EE02B0372550096A693 /* ChatFunctionCallDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBE7EDF2B0372550096A693 /* ChatFunctionCallDemoView.swift */; };
7BE802592D2878170080E06A /* ChatPredictedOutputDemoView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BE802582D2878170080E06A /* ChatPredictedOutputDemoView.swift */; };
7BE9A5AF2B0B33E600CE8103 /* SwiftOpenAIExampleTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BA788DE2AE23A49008825D5 /* SwiftOpenAIExampleTests.swift */; };
/* End PBXBuildFile section */

Expand Down Expand Up @@ -156,6 +157,7 @@
7BBE7EAA2B02E8FC0096A693 /* ChatMessageDisplayModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessageDisplayModel.swift; sourceTree = "<group>"; };
7BBE7EDD2B03718E0096A693 /* ChatFunctionCallProvider.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatFunctionCallProvider.swift; sourceTree = "<group>"; };
7BBE7EDF2B0372550096A693 /* ChatFunctionCallDemoView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatFunctionCallDemoView.swift; sourceTree = "<group>"; };
7BE802582D2878170080E06A /* ChatPredictedOutputDemoView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatPredictedOutputDemoView.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
Expand Down Expand Up @@ -378,6 +380,7 @@
isa = PBXGroup;
children = (
7BA788CC2AE23A48008825D5 /* SwiftOpenAIExampleApp.swift */,
7BE802572D2877D30080E06A /* PredictedOutputsDemo */,
7B50DD292C2A9D1D0070A64D /* LocalChatDemo */,
7B99C2E52C0718CD00E701B3 /* Files */,
7B7239AF2AF9FF1D00646679 /* SharedModels */,
Expand Down Expand Up @@ -480,6 +483,14 @@
path = Completion;
sourceTree = "<group>";
};
7BE802572D2877D30080E06A /* PredictedOutputsDemo */ = {
isa = PBXGroup;
children = (
7BE802582D2878170080E06A /* ChatPredictedOutputDemoView.swift */,
);
path = PredictedOutputsDemo;
sourceTree = "<group>";
};
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
Expand Down Expand Up @@ -620,6 +631,7 @@
buildActionMask = 2147483647;
files = (
7BBE7EA92B02E8E50096A693 /* ChatMessageView.swift in Sources */,
7BE802592D2878170080E06A /* ChatPredictedOutputDemoView.swift in Sources */,
7B7239AE2AF9FF0000646679 /* ChatFunctionsCallStreamProvider.swift in Sources */,
7B436BA12AE25958003CE281 /* ChatProvider.swift in Sources */,
7B436BC32AE7B027003CE281 /* ModerationDemoView.swift in Sources */,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import SwiftOpenAI
var messages: [String] = []
var errorMessage: String = ""
var message: String = ""
var usage: ChatUsage?

init(service: OpenAIService) {
self.service = service
Expand All @@ -25,10 +26,14 @@ import SwiftOpenAI
parameters: ChatCompletionParameters) async throws
{
do {
let choices = try await service.startChat(parameters: parameters).choices
let response = try await service.startChat(parameters: parameters)
let choices = response.choices
let chatUsage = response.usage
let logprobs = choices.compactMap(\.logprobs)
dump(logprobs)
self.messages = choices.compactMap(\.message.content)
dump(chatUsage)
self.usage = chatUsage
} catch APIError.responseUnsuccessful(let description, let statusCode) {
self.errorMessage = "Network error with status code: \(statusCode) and description: \(description)"
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct OptionsListView: View {
enum APIOption: String, CaseIterable, Identifiable {
case audio = "Audio"
case chat = "Chat"
case chatPredictedOutput = "Chat Predicted Output"
case localChat = "Local Chat" // Ollama
case vision = "Vision"
case embeddings = "Embeddings"
Expand Down Expand Up @@ -51,6 +52,8 @@ struct OptionsListView: View {
AudioDemoView(service: openAIService)
case .chat:
ChatDemoView(service: openAIService)
case .chatPredictedOutput:
ChatPredictedOutputDemoView(service: openAIService)
case .vision:
ChatVisionDemoView(service: openAIService)
case .embeddings:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//
// ChatPredictedOutputDemoView.swift
// SwiftOpenAIExample
//
// Created by James Rochabrun on 1/3/25.
//

import Foundation
import SwiftUI
import SwiftOpenAI

/// Demo screen for [Predicted Outputs](https://platform.openai.com/docs/guides/predicted-outputs).
/// Sends the user's prompt together with `predictedCode` (both as a user message and as the
/// `prediction` parameter) so the model can reuse the unchanged portions of the expected response.
struct ChatPredictedOutputDemoView: View {

   // Performs the (non-streamed) chat request and exposes messages / errorMessage.
   @State private var chatProvider: ChatProvider
   // Drives the ProgressView overlay while a request is in flight.
   @State private var isLoading = false
   // Text the user types; cleared as soon as the request is dispatched.
   @State private var prompt = ""

   init(service: OpenAIService) {
      chatProvider = ChatProvider(service: service)
   }

   var body: some View {
      ScrollView {
         VStack {
            textArea
            // Surfaces any error the provider captured from the request.
            Text(chatProvider.errorMessage)
               .foregroundColor(.red)
            chatCompletionResultView
         }
      }
      .overlay(
         Group {
            if isLoading {
               ProgressView()
            } else {
               EmptyView()
            }
         }
      )
   }

   /// Prompt input field plus a send button that kicks off the chat request.
   var textArea: some View {
      HStack(spacing: 4) {
         TextField("Enter prompt", text: $prompt, axis: .vertical)
            .textFieldStyle(.roundedBorder)
            .padding()
         Button {
            Task {
               isLoading = true
               defer { isLoading = false } // ensure isLoading is reset on every exit path, including a thrown error.

               let content: ChatCompletionParameters.Message.ContentType = .text(prompt)
               prompt = ""
               let parameters = ChatCompletionParameters(
                  messages: [
                     .init(role: .system, content: .text(systemMessage)),
                     .init(role: .user, content: content),
                     .init(role: .user, content: .text(predictedCode))], // Sending the predicted code as another user message.
                  model: .gpt4o,
                  prediction: .init(content: .text(predictedCode)))
               try await chatProvider.startChat(parameters: parameters)

            }
         } label: {
            Image(systemName: "paperplane")
         }
         .buttonStyle(.bordered)
      }
      .padding()
   }

   /// stream = `false`: renders each completed message returned by the provider.
   var chatCompletionResultView: some View {
      ForEach(Array(chatProvider.messages.enumerated()), id: \.offset) { idx, val in
         VStack(spacing: 0) {
            Text("\(val)")
         }
      }
   }
}

/// System prompt constraining the model to output code only — the kind of
/// response where Predicted Outputs can copy large unchanged spans verbatim.
let systemMessage = """
You are a code editor assistant. I only output code without any explanations, commentary, or additional text. I follow these rules:
1. Respond with code only, never any text or explanations
2. Use appropriate syntax highlighting/formatting
3. If the code needs to be modified/improved, output the complete updated code
4. Do not include caveats, introductions, or commentary
5. Do not ask questions or solicit feedback
6. Do not explain what changes were made
7. Assume the user knows what they want and will review the code themselves
"""

/// The anticipated model output: a prior revision of this screen's own source.
/// Sent both as a user message and as the `prediction` payload so the model can
/// reuse the unchanged lines, reducing response latency.
let predictedCode = """
struct ChatPredictedOutputDemoView: View {
@State private var chatProvider: ChatProvider
@State private var isLoading = false
@State private var prompt = ""
init(service: OpenAIService) {
chatProvider = ChatProvider(service: service)
}
var body: some View {
ScrollView {
VStack {
textArea
Text(chatProvider.errorMessage)
.foregroundColor(.red)
streamedChatResultView
}
}
.overlay(
Group {
if isLoading {
ProgressView()
} else {
EmptyView()
}
}
)
}
var textArea: some View {
HStack(spacing: 4) {
TextField("Enter prompt", text: $prompt, axis: .vertical)
.textFieldStyle(.roundedBorder)
.padding()
Button {
Task {
isLoading = true
defer { isLoading = false } // ensure isLoading is set to false when the
let content: ChatCompletionParameters.Message.ContentType = .text(prompt)
prompt = ""
let parameters = ChatCompletionParameters(
messages: [.init(
role: .user,
content: content)],
model: .gpt4o)
}
} label: {
Image(systemName: "paperplane")
}
.buttonStyle(.bordered)
}
.padding()
}
/// stream = `true`
var streamedChatResultView: some View {
VStack {
Button("Cancel stream") {
chatProvider.cancelStream()
}
Text(chatProvider.message)
}
}
}
"""
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public struct ChatCompletionParameters: Encodable {
///The gpt-4o-audio-preview model can also be used to [generate audio](https://platform.openai.com/docs/guides/audio). To request that this model generate both text and audio responses, you can use:
/// ["text", "audio"]
public var modalities: [String]?
/// Configuration for a [Predicted Output](https://platform.openai.com/docs/guides/predicted-outputs), which can greatly improve response times when large parts of the model response are known ahead of time. This is most common when you are regenerating a file with only minor changes to most of the content.
public var prediction: Prediction?
/// Parameters for audio output. Required when audio output is requested with modalities: ["audio"]. [Learn more.](https://platform.openai.com/docs/guides/audio)
public var audio: Audio?
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. Defaults to 0
Expand Down Expand Up @@ -380,6 +382,41 @@ public struct ChatCompletionParameters: Encodable {
}
}

/// Configuration for a [Predicted Output](https://platform.openai.com/docs/guides/predicted-outputs).
/// Encodes as `{"type": "content", "content": ...}`, where `content` is either a
/// single string or an array of typed content parts.
public struct Prediction: Encodable {
   /// The kind of prediction; the API currently supports only "content".
   public let type: String
   /// The predicted model output, as a bare string or as content parts.
   public let content: PredictionContent

   /// Predicted content: a plain string or an array of `ContentPart` values,
   /// encoded directly (no wrapper key) via a single-value container.
   public enum PredictionContent: Encodable {
      case text(String)
      case contentArray([ContentPart])

      public func encode(to encoder: Encoder) throws {
         var container = encoder.singleValueContainer()
         switch self {
         case .contentArray(let parts):
            try container.encode(parts)
         case .text(let string):
            try container.encode(string)
         }
      }
   }

   /// One typed chunk of predicted content, e.g. `{"type": "text", "text": "..."}`.
   public struct ContentPart: Encodable {
      public let type: String
      public let text: String

      public init(type: String, text: String) {
         self.type = type
         self.text = text
      }
   }

   public init(content: PredictionContent, type: String = "content") {
      self.content = content
      self.type = type
   }
}

public enum ReasoningEffort: String, Encodable {
case low
case medium
Expand All @@ -405,6 +442,7 @@ public struct ChatCompletionParameters: Encodable {
case maCompletionTokens = "max_completion_tokens"
case n
case modalities
case prediction
case audio
case responseFormat = "response_format"
case presencePenalty = "presence_penalty"
Expand Down Expand Up @@ -436,6 +474,7 @@ public struct ChatCompletionParameters: Encodable {
maxTokens: Int? = nil,
n: Int? = nil,
modalities: [String]? = nil,
prediction: Prediction? = nil,
audio: Audio? = nil,
responseFormat: ResponseFormat? = nil,
presencePenalty: Double? = nil,
Expand Down Expand Up @@ -464,6 +503,7 @@ public struct ChatCompletionParameters: Encodable {
self.maxTokens = maxTokens
self.n = n
self.modalities = modalities
self.prediction = prediction
self.audio = audio
self.responseFormat = responseFormat
self.presencePenalty = presencePenalty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public struct ChatCompletionChunkObject: Decodable {
public let systemFingerprint: String?
/// The object type, which is always chat.completion.chunk.
public let object: String
/// An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.
public let usage: ChatUsage?

public struct ChatChoice: Decodable {

Expand Down Expand Up @@ -114,5 +116,6 @@ public struct ChatCompletionChunkObject: Decodable {
case serviceTier = "service_tier"
case systemFingerprint = "system_fingerprint"
case object
case usage
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public struct ChatCompletionObject: Decodable {
/// The object type, which is always chat.completion.
public let object: String
/// Usage statistics for the completion request.
public let usage: ChatUsage
public let usage: ChatUsage?

public struct ChatChoice: Decodable {

Expand Down Expand Up @@ -139,20 +139,4 @@ public struct ChatCompletionObject: Decodable {
case object
case usage
}

public struct ChatUsage: Decodable {

/// Number of tokens in the generated completion.
public let completionTokens: Int
/// Number of tokens in the prompt.
public let promptTokens: Int
/// Total number of tokens used in the request (prompt + completion).
public let totalTokens: Int

enum CodingKeys: String, CodingKey {
case completionTokens = "completion_tokens"
case promptTokens = "prompt_tokens"
case totalTokens = "total_tokens"
}
}
}
Loading

0 comments on commit 8cfa454

Please sign in to comment.