Skip to content

Commit

Permalink
Merge pull request #37 from night-cruise/feat/add-chain-call
Browse files Browse the repository at this point in the history
feat: add chain call to create Request instance
  • Loading branch information
Dongri Jin authored Oct 17, 2023
2 parents f1f1fa7 + a9be9ef commit 93e9a32
Show file tree
Hide file tree
Showing 16 changed files with 402 additions and 135 deletions.
40 changes: 8 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,13 @@ let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
### Create request
```rust
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
let req = ChatCompletionRequest {
model: chat_completion::GPT4.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("Hello OpenAI!"),
}],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);
```

### Send request
Expand All @@ -68,27 +56,15 @@ use std::env;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = ChatCompletionRequest {
model: chat_completion::GPT4.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is Bitcoin?"),
name: None,
function_call: None,
}],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);
let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);
Ok(())
Expand Down
23 changes: 7 additions & 16 deletions examples/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,20 @@ use std::env;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = ChatCompletionRequest {
model: chat_completion::GPT4.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {

let req = ChatCompletionRequest::new(
chat_completion::GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is Bitcoin?"),
name: None,
function_call: None,
}],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);

let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);

Ok(())
}

Expand Down
30 changes: 12 additions & 18 deletions examples/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@ use std::env;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = CompletionRequest {
model: completion::GPT3_TEXT_DAVINCI_003.to_string(),
prompt: String::from("What is Bitcoin?"),
suffix: None,
max_tokens: Some(3000),
temperature: Some(0.9),
top_p: Some(1.0),
n: None,
stream: None,
logprobs: None,
echo: None,
stop: Some(vec![String::from(" Human:"), String::from(" AI:")]),
presence_penalty: Some(0.6),
frequency_penalty: Some(0.0),
best_of: None,
logit_bias: None,
user: None,
};

let req = CompletionRequest::new(
completion::GPT3_TEXT_DAVINCI_003.to_string(),
String::from("What is Bitcoin?"),
)
.max_tokens(3000)
.temperature(0.9)
.top_p(1.0)
.stop(vec![String::from(" Human:"), String::from(" AI:")])
.presence_penalty(0.6)
.frequency_penalty(0.0);

let result = client.completion(req)?;
println!("{:}", result.choices[0].text);

Expand Down
11 changes: 6 additions & 5 deletions examples/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use std::env;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = EmbeddingRequest {
model: "text-embedding-ada-002".to_string(),
input: "story time".to_string(),
user: Option::None,
};

let req = EmbeddingRequest::new(
"text-embedding-ada-002".to_string(),
"story time".to_string(),
);

let result = client.embedding(req)?;
println!("{:?}", result.data);

Expand Down
38 changes: 14 additions & 24 deletions examples/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,25 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}),
);

let req = ChatCompletionRequest {
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT3_5_TURBO_0613.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
name: None,
function_call: None,
}],
functions: Some(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}]),
function_call: Some(FunctionCallType::Auto), // Some(FunctionCallType::Function { name: "test".to_string() }),
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
)
.functions(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}])
.function_call(FunctionCallType::Auto);

// debug reuqest json
// let serialized = serde_json::to_string(&req).unwrap();
Expand Down
58 changes: 18 additions & 40 deletions examples/function_call_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,24 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}),
);

let req = ChatCompletionRequest {
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT3_5_TURBO_0613.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
name: None,
function_call: None,
}],
functions: Some(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}]),
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
)
.functions(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}]);

let result = client.chat_completion(req)?;

Expand All @@ -80,9 +69,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let c: Currency = serde_json::from_str(&arguments)?;
let coin = c.coin;

let req = ChatCompletionRequest {
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
messages: vec![
let req = ChatCompletionRequest::new(
chat_completion::GPT3_5_TURBO_0613.to_string(),
vec![
chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
Expand All @@ -99,19 +88,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
function_call: None,
},
],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);

let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);
}
Expand Down
42 changes: 42 additions & 0 deletions src/v1/audio.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use crate::impl_builder_methods;

pub const WHISPER_1: &str = "whisper-1";

#[derive(Debug, Serialize)]
Expand All @@ -16,6 +18,27 @@ pub struct AudioTranscriptionRequest {
pub language: Option<String>,
}

impl AudioTranscriptionRequest {
pub fn new(file: String, model: String) -> Self {
Self {
file,
model,
prompt: None,
response_format: None,
temperature: None,
language: None,
}
}
}

impl_builder_methods!(
AudioTranscriptionRequest,
prompt: String,
response_format: String,
temperature: f32,
language: String
);

#[derive(Debug, Deserialize)]
pub struct AudioTranscriptionResponse {
pub text: String,
Expand All @@ -33,6 +56,25 @@ pub struct AudioTranslationRequest {
pub temperature: Option<f32>,
}

impl AudioTranslationRequest {
pub fn new(file: String, model: String) -> Self {
Self {
file,
model,
prompt: None,
response_format: None,
temperature: None,
}
}
}

impl_builder_methods!(
AudioTranslationRequest,
prompt: String,
response_format: String,
temperature: f32
);

#[derive(Debug, Deserialize)]
pub struct AudioTranslationResponse {
pub text: String,
Expand Down
Loading

0 comments on commit 93e9a32

Please sign in to comment.