-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #107 from dongri/batch-api
Add batch api
- Loading branch information
Showing
9 changed files
with
216 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,3 +23,6 @@ features = ["derive"] | |
|
||
[dependencies.serde_json] | ||
version = "1" | ||
|
||
[dependencies.bytes] | ||
version = "1.7.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
use openai_api_rs::v1::api::OpenAIClient; | ||
use openai_api_rs::v1::batch::CreateBatchRequest; | ||
use openai_api_rs::v1::file::FileUploadRequest; | ||
use serde_json::{from_str, to_string_pretty, Value}; | ||
use std::env; | ||
use std::fs::File; | ||
use std::io::Write; | ||
use std::str; | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); | ||
|
||
let req = FileUploadRequest::new( | ||
"examples/data/batch_request.json".to_string(), | ||
"batch".to_string(), | ||
); | ||
|
||
let result = client.upload_file(req).await?; | ||
println!("File id: {:?}", result.id); | ||
|
||
let input_file_id = result.id; | ||
let req = CreateBatchRequest::new( | ||
input_file_id.clone(), | ||
"/v1/chat/completions".to_string(), | ||
"24h".to_string(), | ||
); | ||
|
||
let result = client.create_batch(req).await?; | ||
println!("Batch id: {:?}", result.id); | ||
|
||
let batch_id = result.id; | ||
let result = client.retrieve_batch(batch_id.to_string()).await?; | ||
println!("Batch status: {:?}", result.status); | ||
|
||
// sleep 30 seconds | ||
println!("Sleeping for 30 seconds..."); | ||
tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; | ||
|
||
let result = client.retrieve_batch(batch_id.to_string()).await?; | ||
|
||
let file_id = result.output_file_id.unwrap(); | ||
let result = client.retrieve_file_content(file_id).await?; | ||
let s = match str::from_utf8(&result) { | ||
Ok(v) => v.to_string(), | ||
Err(e) => panic!("Invalid UTF-8 sequence: {}", e), | ||
}; | ||
let json_value: Value = from_str(&s)?; | ||
let result_json = to_string_pretty(&json_value)?; | ||
|
||
let output_file_path = "examples/data/batch_result.json"; | ||
let mut file = File::create(output_file_path)?; | ||
file.write_all(result_json.as_bytes())?; | ||
|
||
println!("File writed to {:?}", output_file_path); | ||
|
||
Ok(()) | ||
} | ||
|
||
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
{ | ||
"custom_id": "request-1", | ||
"error": null, | ||
"id": "batch_req_403hYy7nMxrxXFWXiwvoLG1q", | ||
"response": { | ||
"body": { | ||
"choices": [ | ||
{ | ||
"finish_reason": "stop", | ||
"index": 0, | ||
"logprobs": null, | ||
"message": { | ||
"content": "2 + 2 equals 4.", | ||
"refusal": null, | ||
"role": "assistant" | ||
} | ||
} | ||
], | ||
"created": 1724858089, | ||
"id": "chatcmpl-A1Efhv97EZNQeHKSLPnTmZex20gf2", | ||
"model": "gpt-4o-mini-2024-07-18", | ||
"object": "chat.completion", | ||
"system_fingerprint": "fp_f33667828e", | ||
"usage": { | ||
"completion_tokens": 8, | ||
"prompt_tokens": 24, | ||
"total_tokens": 32 | ||
} | ||
}, | ||
"request_id": "af0bac0d82530234e09bd6b5d9fbf5cf", | ||
"status_code": 200 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct CreateBatchRequest { | ||
pub input_file_id: String, | ||
pub endpoint: String, | ||
pub completion_window: String, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub metadata: Option<Metadata>, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct Metadata { | ||
pub customer_id: String, | ||
pub batch_description: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct RequestCounts { | ||
pub total: u32, | ||
pub completed: u32, | ||
pub failed: u32, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct BatchResponse { | ||
pub id: String, | ||
pub object: String, | ||
pub endpoint: String, | ||
pub errors: Option<Vec<String>>, | ||
pub input_file_id: String, | ||
pub completion_window: String, | ||
pub status: String, | ||
pub output_file_id: Option<String>, | ||
pub error_file_id: Option<String>, | ||
pub created_at: u64, | ||
pub in_progress_at: Option<u64>, | ||
pub expires_at: Option<u64>, | ||
pub finalizing_at: Option<u64>, | ||
pub completed_at: Option<u64>, | ||
pub failed_at: Option<u64>, | ||
pub expired_at: Option<u64>, | ||
pub cancelling_at: Option<u64>, | ||
pub cancelled_at: Option<u64>, | ||
pub request_counts: RequestCounts, | ||
pub metadata: Option<Metadata>, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct ListBatchResponse { | ||
pub object: String, | ||
pub data: Vec<BatchResponse>, | ||
pub first_id: String, | ||
pub last_id: String, | ||
pub has_more: bool, | ||
} | ||
|
||
impl CreateBatchRequest { | ||
pub fn new(input_file_id: String, endpoint: String, completion_window: String) -> Self { | ||
Self { | ||
input_file_id, | ||
endpoint, | ||
completion_window, | ||
metadata: None, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters