Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix audio transcription #144

Merged
merged 1 commit into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions examples/audio_transcriptions.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
use std::env;
use std::fs::File;
use std::io::Read;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = AudioTranscriptionRequest::new(
"examples/data/problem.mp3".to_string(),
WHISPER_1.to_string(),
);
let file_path = "examples/data/problem.mp3";

// Test with file
let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string());

let req_json = req.clone().response_format("json".to_string());

Expand All @@ -22,7 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let result = client.audio_transcription_raw(req_raw).await?;
println!("{:?}", result);

// Test with bytes
let mut file = File::open(file_path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;

let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string());

let req_json = req.clone().response_format("json".to_string());

let result = client.audio_transcription(req_json).await?;
println!("{:?}", result);

Ok(())
}

// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions
58 changes: 54 additions & 4 deletions src/v1/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,31 +310,49 @@ impl OpenAIClient {
&self,
req: AudioTranscriptionRequest,
) -> Result<AudioTranscriptionResponse, APIError> {
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
if let Some(response_format) = &req.response_format {
if response_format != "json" && response_format != "verbose_json" {
return Err(APIError::CustomError {
message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
});
}
}
let form = Self::create_form(&req, "file")?;
let form: Form;
if req.clone().file.is_some() {
form = Self::create_form(&req, "file")?;
} else if let Some(bytes) = req.clone().bytes {
form = Self::create_form_from_bytes(&req, bytes)?;
} else {
return Err(APIError::CustomError {
message: "Either file or bytes must be provided".to_string(),
});
}
self.post_form("audio/transcriptions", form).await
}

pub async fn audio_transcription_raw(
&self,
req: AudioTranscriptionRequest,
) -> Result<Bytes, APIError> {
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
if let Some(response_format) = &req.response_format {
if response_format != "text" && response_format != "srt" && response_format != "vtt" {
return Err(APIError::CustomError {
message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
});
}
}
let form = Self::create_form(&req, "file")?;
let form: Form;
if req.clone().file.is_some() {
form = Self::create_form(&req, "file")?;
} else if let Some(bytes) = req.clone().bytes {
form = Self::create_form_from_bytes(&req, bytes)?;
} else {
return Err(APIError::CustomError {
message: "Either file or bytes must be provided".to_string(),
});
}
self.post_form_raw("audio/transcriptions", form).await
}

Expand Down Expand Up @@ -823,4 +841,36 @@ impl OpenAIClient {

Ok(form)
}

fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
where
T: Serialize,
{
let json = match serde_json::to_value(req) {
Ok(json) => json,
Err(e) => {
return Err(APIError::CustomError {
message: e.to_string(),
})
}
};

let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));

if let Value::Object(map) = json {
for (key, value) in map.into_iter() {
match value {
Value::String(s) => {
form = form.text(key, s);
}
Value::Number(n) => {
form = form.text(key, n.to_string());
}
_ => {}
}
}
}

Ok(form)
}
}
19 changes: 17 additions & 2 deletions src/v1/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ pub const WHISPER_1: &str = "whisper-1";

#[derive(Debug, Serialize, Clone)]
pub struct AudioTranscriptionRequest {
pub file: String,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub file: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bytes: Option<Vec<u8>>,
pub prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<String>,
Expand All @@ -23,8 +25,21 @@ pub struct AudioTranscriptionRequest {
impl AudioTranscriptionRequest {
pub fn new(file: String, model: String) -> Self {
Self {
file,
model,
file: Some(file),
bytes: None,
prompt: None,
response_format: None,
temperature: None,
language: None,
}
}

pub fn new_bytes(bytes: Vec<u8>, model: String) -> Self {
Self {
model,
file: None,
bytes: Some(bytes),
prompt: None,
response_format: None,
temperature: None,
Expand Down
Loading