Diff Tracker PR #1

Draft
wants to merge 48 commits into
base: main
48 commits
da4a468
Adding folders for openai related crates
jpalvarezl Jul 30, 2024
573cb7b
example runs
jpalvarezl Jul 30, 2024
9a55d53
Worked out how to expose structs selectively
jpalvarezl Jul 30, 2024
c683c25
Added scaffolding for chat_completions models
jpalvarezl Jul 30, 2024
302a9cd
Wired types to client request method
jpalvarezl Jul 30, 2024
9e9c866
Renamed crate folder
jpalvarezl Jul 31, 2024
210dd48
serialization works as expected
jpalvarezl Jul 31, 2024
ee4c021
Able to hit the service
jpalvarezl Jul 31, 2024
3e7d606
Added headers
jpalvarezl Jul 31, 2024
3781bee
Example for non-azure openai works
jpalvarezl Aug 1, 2024
172e51f
wip
jpalvarezl Aug 1, 2024
5e6bb6b
Removed bit about the examples in the README
jpalvarezl Aug 1, 2024
e8f9369
Added some classes to encapsulate secrets
jpalvarezl Aug 1, 2024
51de8e9
Cleaned up imports and moved build_request method up
jpalvarezl Aug 1, 2024
d4b5d40
Added code for azure version of chat completions (WIP)
jpalvarezl Aug 1, 2024
240eae1
Added code for azure example
jpalvarezl Aug 1, 2024
786e20f
Azure chat completion example works
jpalvarezl Aug 1, 2024
fea4fab
Adding support for multipart form-data in a very hacky way
jpalvarezl Aug 2, 2024
a531a0f
Added non azure example
jpalvarezl Aug 5, 2024
d710f4f
Added ctx for stream
jpalvarezl Aug 5, 2024
5351f89
WIP: streaming works, correctly chunking the deltas not implemented yet
jpalvarezl Aug 5, 2024
c830318
Streaming byte chunks successfully mapped into String
jpalvarezl Aug 6, 2024
b04c419
WIP: chunking stream events seems to work. TODO deserialize delta
jpalvarezl Aug 6, 2024
d0d05ea
WIP: chunking stream events seems to work. TODO deserialize delta
jpalvarezl Aug 6, 2024
b00eb73
Deserializing deltas, but emissions are not correctly streamed
jpalvarezl Aug 6, 2024
e6de6bd
WIP: early finish caused by source stream DONE message
jpalvarezl Aug 8, 2024
5411910
Removed unused imports and adding unit tests for stream chunks
jpalvarezl Aug 9, 2024
87cae8d
run cargo fmt
jpalvarezl Aug 9, 2024
80b8420
Setting up infra for stream events
jpalvarezl Aug 9, 2024
aa8afa3
test runs
jpalvarezl Aug 9, 2024
85e1a06
Preparing unit tests
jpalvarezl Aug 9, 2024
c6103bc
unit test pass
jpalvarezl Aug 9, 2024
1232092
Added next unit test, which fails
jpalvarezl Aug 9, 2024
5bfb6ae
wip
jpalvarezl Aug 9, 2024
0d9b9e6
Added hacky impl for PartialEq on ErrorKind::Custom for easier test a…
jpalvarezl Aug 9, 2024
30e75d2
Both tests passing
jpalvarezl Aug 9, 2024
98fe11a
More unit tests
jpalvarezl Aug 9, 2024
649d5f1
Project compiles
jpalvarezl Aug 9, 2024
1f9372f
Project compiles again
jpalvarezl Aug 9, 2024
072432b
Corrected some tests
jpalvarezl Aug 10, 2024
56baf61
Almost there
jpalvarezl Aug 10, 2024
33925cc
Tests are green
jpalvarezl Aug 11, 2024
fecb333
Enabled more chunks for real_data test
jpalvarezl Aug 11, 2024
177c2d9
Added new edge case
jpalvarezl Aug 11, 2024
fb279ca
Test name correction
jpalvarezl Aug 11, 2024
aca5f55
It sort of works
jpalvarezl Aug 11, 2024
ba74473
Merge branch 'main' into jpalvarezl/aoai_rust
jpalvarezl Aug 14, 2024
83a887f
Added stdout flush to showcase streaming better
jpalvarezl Aug 14, 2024
1 change: 1 addition & 0 deletions eng/test/mock_transport/src/mock_request.rs
@@ -138,6 +138,7 @@ impl<'a> Serialize for RequestSerializer<'a> {
Body::Bytes(bytes) => base64::encode(bytes as &[u8]),
#[cfg(not(target_arch = "wasm32"))]
Body::SeekableStream(_) => unimplemented!(),
Body::Multipart(_) => unimplemented!(),
},
)?;

2 changes: 2 additions & 0 deletions eng/test/mock_transport/src/player_policy.rs
@@ -128,12 +128,14 @@ impl Policy for MockTransportPlayerPolicy {
Body::Bytes(bytes) => bytes as &[u8],
#[cfg(not(target_arch = "wasm32"))]
Body::SeekableStream(_) => unimplemented!(),
Body::Multipart(_) => unimplemented!(),
};

let expected_body = match expected_request.body() {
Body::Bytes(bytes) => bytes as &[u8],
#[cfg(not(target_arch = "wasm32"))]
Body::SeekableStream(_) => unimplemented!(),
Body::Multipart(_) => unimplemented!(),
};

if actual_body != expected_body {
3 changes: 2 additions & 1 deletion sdk/core/Cargo.toml
@@ -23,7 +23,7 @@ http-types = { version = "2.12", default-features = false }
tracing = "0.1.40"
rand = "0.8"
reqwest = { version = "0.12.0", features = [
"stream",
"stream", "multipart"
], default-features = false, optional = true }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
@@ -57,6 +57,7 @@ default = []
enable_reqwest = ["reqwest/default-tls"]
enable_reqwest_gzip = ["reqwest/gzip"]
enable_reqwest_rustls = ["reqwest/rustls-tls"]
enable_reqwest_multipart = ["reqwest/multipart"]
hmac_rust = ["dep:sha2", "dep:hmac"]
hmac_openssl = ["dep:openssl"]
test_e2e = []
14 changes: 12 additions & 2 deletions sdk/core/src/error/mod.rs
@@ -82,8 +82,10 @@ impl Display for ErrorKind {
}
}

// TODO: being able to derive PartialEq here would simplify some tests.
// Context::Custom inner boxed type prevents this. Check if there is a path to being able to derive it.
/// An error encountered from interfacing with Azure
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct Error {
context: Context,
}
@@ -383,7 +385,7 @@ where
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
enum Context {
Simple(ErrorKind),
Message {
@@ -400,6 +402,14 @@ struct Custom {
error: Box<dyn std::error::Error + Send + Sync>,
}

// Horrendous hack so that tests can use assert_eq! on errors. Without this impl, azure_core::Error unfortunately
// cannot implement PartialEq, which is necessary to make comparison and assertion outputs more readable.
impl PartialEq for Custom {
fn eq(&self, other: &Self) -> bool {
self.kind == other.kind && self.error.to_string() == other.error.to_string()
}
}

#[cfg(test)]
mod tests {
use super::*;
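
Note on the `PartialEq` workaround for `Custom`: a `Box<dyn std::error::Error + Send + Sync>` cannot be compared directly, so the impl in this diff falls back to comparing the kind plus the rendered message. A minimal standalone sketch of the same idea follows. The two-variant `ErrorKind` and the `custom` helper are assumptions made for illustration, not the real `azure_core` definitions.

```rust
use std::error::Error as StdError;

// Simplified stand-in for azure_core's ErrorKind (assumption: two variants only).
#[derive(Debug, PartialEq)]
enum ErrorKind {
    Io,
    DataConversion,
}

#[derive(Debug)]
struct Custom {
    kind: ErrorKind,
    error: Box<dyn StdError + Send + Sync>,
}

// Trait objects cannot derive PartialEq, so equality is defined as
// "same kind and same Display output of the boxed error".
impl PartialEq for Custom {
    fn eq(&self, other: &Self) -> bool {
        self.kind == other.kind && self.error.to_string() == other.error.to_string()
    }
}

// Hypothetical helper used only to build values in this sketch.
fn custom(kind: ErrorKind, message: &str) -> Custom {
    Custom {
        kind,
        error: message.into(),
    }
}

fn main() {
    let a = custom(ErrorKind::Io, "connection reset");
    let b = custom(ErrorKind::Io, "connection reset");
    let c = custom(ErrorKind::DataConversion, "connection reset");

    assert_eq!(a, b); // same kind, same message
    assert_ne!(a, c); // same message, different kind
    println!("comparison by kind and message works");
}
```

One consequence of this approach is that two errors of different concrete types compare equal whenever their kind and message match, which is probably acceptable in tests but may be surprising in production code.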
12 changes: 7 additions & 5 deletions sdk/core/src/headers/mod.rs
@@ -166,11 +166,13 @@ impl Debug for Headers {
let redacted = HeaderValue::from_static("[redacted]");
f.debug_map()
.entries(self.0.iter().map(|(name, value)| {
if matching_ignore_ascii_case(name.as_str(), AUTHORIZATION.as_str()) {
(name, value)
} else {
(name, &redacted)
}
(name, value)
// TODO: restore
// if matching_ignore_ascii_case(name.as_str(), AUTHORIZATION.as_str()) {
// (name, value)
// } else {
// (name, &redacted)
// }
}))
.finish()?;
write!(f, ")")
1 change: 1 addition & 0 deletions sdk/core/src/http_client/mod.rs
@@ -12,6 +12,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use crate::to_reqwest_form;

/// Construct a new `HttpClient`
pub fn new_http_client() -> Arc<dyn HttpClient> {
2 changes: 2 additions & 0 deletions sdk/core/src/http_client/reqwest.rs
@@ -17,6 +17,7 @@ pub fn new_reqwest_client() -> Arc<dyn HttpClient> {
// See <https://github.com/hyperium/hyper/issues/2312> for more details.
#[cfg(not(target_arch = "wasm32"))]
let client = ::reqwest::ClientBuilder::new()
.connection_verbose(true)
.pool_max_idle_per_host(0)
.build()
.expect("failed to build `reqwest` client");
@@ -43,6 +44,7 @@ impl HttpClient for ::reqwest::Client {
let body = request.body().clone();

let reqwest_request = match body {
Body::Multipart(form) => req.multipart(super::to_reqwest_form(form)).build(),
Body::Bytes(bytes) => req.body(bytes).build(),

// We cannot currently implement `Body::SeekableStream` for WASM
69 changes: 69 additions & 0 deletions sdk/core/src/request.rs
@@ -11,6 +11,7 @@ use std::fmt::Debug;
/// An HTTP Body.
#[derive(Debug, Clone)]
pub enum Body {
Multipart(MyForm),
/// A body of a known size.
Bytes(bytes::Bytes),
/// A streaming body.
@@ -24,6 +25,7 @@ pub enum Body {
impl Body {
pub fn len(&self) -> usize {
match self {
Body::Multipart(_) => 0,
Body::Bytes(bytes) => bytes.len(),
#[cfg(not(target_arch = "wasm32"))]
Body::SeekableStream(stream) => stream.len(),
@@ -36,6 +38,7 @@ impl Body {

pub(crate) async fn reset(&mut self) -> crate::Result<()> {
match self {
Body::Multipart(_) => Ok(()),
Body::Bytes(_) => Ok(()),
#[cfg(not(target_arch = "wasm32"))]
Body::SeekableStream(stream) => stream.reset().await,
@@ -59,6 +62,12 @@ impl From<Box<dyn SeekableStream>> for Body {
}
}

impl From<MyForm> for Body {
fn from(my_form: MyForm) -> Self {
Self::Multipart(my_form)
}
}

/// A pipeline request.
///
/// A pipeline request is composed by a destination (uri), a method, a collection of headers and a
@@ -129,6 +138,10 @@ impl Request {
self.body = body.into();
}

pub fn multipart(&mut self, form: MyForm) {
self.body = Body::Multipart(form);
}

pub fn insert_header<K, V>(&mut self, key: K, value: V)
where
K: Into<crate::headers::HeaderName>,
@@ -147,3 +160,59 @@ impl Request {
self.insert_header(item.name(), item.value());
}
}

// Had to add this type because reqwest::multipart::Form does not implement Clone.
// reqwest seems to handle the calculation of the content-size, so we don't need to keep
// track of that here. In a proper implementation, we might need to handle it.
#[derive(Debug, Clone)]
pub struct MyForm {
pub(crate) parts: Vec<MyPart>,
}

impl MyForm {
pub fn new() -> Self {
Self { parts: Vec::new() }
}

pub fn text(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
self.parts.push(MyPart::Text {
name: name.into(),
value: value.into(),
});
self
}

pub fn file(mut self, name: impl Into<String>, bytes: Vec<u8>) -> Self {
self.parts.push(MyPart::File {
name: name.into(),
bytes,
});
self
}
}

#[derive(Debug, Clone)]
pub enum MyPart {
Text { name: String, value: String },
File { name: String, bytes: Vec<u8> },
}

pub(crate) fn to_reqwest_form(form: MyForm) -> reqwest::multipart::Form {
let mut reqwest_form = reqwest::multipart::Form::new();
for part in form.parts {
match part {
MyPart::Text { name, value } => {
reqwest_form = reqwest_form.text(name, value);
}
// "part name" is not the same as `file_name`. Learned the hard way...
MyPart::File { name, bytes } => {
reqwest_form = reqwest_form.part(
"file",
reqwest::multipart::Part::bytes(bytes)
.mime_str("application/octet-stream")
.unwrap()
.file_name(name),
);
}
}
}
reqwest_form
}
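
Note on the `MyForm` builder: the wrapper exists because `reqwest::multipart::Form` is not `Clone`, so the request body stores a cloneable description of the parts and converts it only when sending. The sketch below mirrors the `MyForm` and `MyPart` shapes from the diff without the `reqwest` conversion, so it runs with the standard library alone. The `parts` accessor, the `build_transcription_form` helper, and the field values are hypothetical and not taken from the PR.

```rust
// Mirrors the MyForm / MyPart shapes from the diff, minus the reqwest conversion.
#[derive(Debug, Clone, Default)]
pub struct MyForm {
    parts: Vec<MyPart>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum MyPart {
    Text { name: String, value: String },
    File { name: String, bytes: Vec<u8> },
}

impl MyForm {
    pub fn new() -> Self {
        Self::default()
    }

    // Each builder method consumes the form and returns it, so calls can be chained.
    pub fn text(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.parts.push(MyPart::Text {
            name: name.into(),
            value: value.into(),
        });
        self
    }

    pub fn file(mut self, name: impl Into<String>, bytes: Vec<u8>) -> Self {
        self.parts.push(MyPart::File {
            name: name.into(),
            bytes,
        });
        self
    }

    // Hypothetical accessor, added only so this sketch can inspect the result.
    pub fn parts(&self) -> &[MyPart] {
        &self.parts
    }
}

// Hypothetical helper: the field name and value are placeholders.
fn build_transcription_form(audio: Vec<u8>) -> MyForm {
    MyForm::new()
        .text("response_format", "json")
        .file("batman.wav", audio)
}

fn main() {
    let form = build_transcription_form(vec![0u8; 4]);

    // Cloning the body is what the wrapper makes possible.
    let retry_copy = form.clone();
    assert_eq!(form.parts(), retry_copy.parts());

    for part in form.parts() {
        match part {
            MyPart::Text { name, value } => println!("text field {name} = {value}"),
            MyPart::File { name, bytes } => println!("file {name} ({} bytes)", bytes.len()),
        }
    }
}
```

In the diff's `to_reqwest_form`, the `name` of a `File` part becomes the uploaded `file_name`, while the multipart part name is hard-coded to `"file"`, so a form built this way can carry only one meaningfully named file field.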
1 change: 1 addition & 0 deletions sdk/data_tables/src/transaction.rs
@@ -62,6 +62,7 @@ impl TransactionOperations {
}
#[cfg(not(target_arch = "wasm32"))]
azure_core::Body::SeekableStream(_) => todo!(),
azure_core::Body::Multipart(_) => todo!(),
}
}

47 changes: 47 additions & 0 deletions sdk/openai_inference/Cargo.toml
@@ -0,0 +1,47 @@
[package]
name = "azure_openai_inference"
version = "0.1.0"
description = "Rust wrappers around Microsoft Azure REST APIs - Azure OpenAI Inference"
readme = "README.md"
authors = ["Microsoft Corp."]
license = "MIT"
repository = "https://github.com/azure/azure-sdk-for-rust"
homepage = "https://github.com/azure/azure-sdk-for-rust"
# documentation = "https://docs.rs/azure_data_cosmos"
keywords = ["sdk", "azure", "rest"]
categories = ["api-bindings"]
edition = "2021"

[dependencies]
async-trait = "0.1"
azure_core = { path = "../core", version = "0.20" }
time = "0.3.10"
futures = "0.3"
tracing = "0.1.40"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
url = "2.2"
# uuid = { version = "1.0", features = ["v4"] }
thiserror = "1.0"
bytes = "1.0"
log = "0.4"
env_logger = "0.10"

[dev-dependencies]
# tracing-subscriber = "0.3"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
# clap = { version = "4.0.2", features = ["derive", "env"] }
reqwest = "0.12.0"
# stop-token = { version = "0.7.0", features = ["tokio"] }
mock_transport = { path = "../../eng/test/mock_transport" }

[features]
default = ["enable_reqwest", "hmac_rust"]
enable_reqwest = ["azure_core/enable_reqwest"]
enable_reqwest_rustls = ["azure_core/enable_reqwest_rustls"]
test_e2e = []
hmac_rust = ["azure_core/hmac_rust"]
hmac_openssl = ["azure_core/hmac_openssl"]

[package.metadata.docs.rs]
features = ["enable_reqwest", "enable_reqwest_rustls", "hmac_rust", "hmac_openssl"]
1 change: 1 addition & 0 deletions sdk/openai_inference/README.md
@@ -0,0 +1 @@
# azure_openai_inference
Binary file not shown.
Binary file added sdk/openai_inference/assets/batman.wav
Binary file not shown.
35 changes: 35 additions & 0 deletions sdk/openai_inference/examples/chat_completions_azure.rs
@@ -0,0 +1,35 @@
use azure_openai_inference::{AzureKeyCredential, AzureOpenAIClient};
use azure_openai_inference::{AzureServiceVersion, CreateChatCompletionsRequest};

#[tokio::main]
async fn main() {
let endpoint =
std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");

let openai_client = AzureOpenAIClient::new(endpoint, AzureKeyCredential::new(secret));

let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
"gpt-4-1106-preview",
"Tell me a joke about pineapples",
);

println!("{:#?}", &chat_completions_request);
println!("{:#?}", serde_json::to_string(&chat_completions_request));
let response = openai_client
.create_chat_completions(
&chat_completions_request.model,
AzureServiceVersion::V2023_12_01Preview,
&chat_completions_request,
)
.await;

match response {
Ok(chat_completions) => {
println!("{:#?}", &chat_completions);
}
Err(e) => {
println!("Error: {}", e);
}
}
}
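
Note on the Azure example: `create_chat_completions` takes a deployment name (here the model name is reused), a service version, and the request body. The client code is not part of the diff shown here, so the sketch below only illustrates how those inputs usually map onto the Azure OpenAI REST route, `{endpoint}/openai/deployments/{deployment}/chat/completions?api-version=...`. The local `AzureServiceVersion` enum, its `Display` impl, and `chat_completions_url` are hypothetical stand-ins, not the PR's implementation.

```rust
use std::fmt;

// Hypothetical stand-in for the AzureServiceVersion enum used in the example.
#[derive(Debug, Clone, Copy)]
enum AzureServiceVersion {
    V2023_12_01Preview,
}

impl fmt::Display for AzureServiceVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Rendered as the value of the `api-version` query parameter.
            AzureServiceVersion::V2023_12_01Preview => f.write_str("2023-12-01-preview"),
        }
    }
}

// Hypothetical helper: joins endpoint, deployment and version into the request URL.
fn chat_completions_url(
    endpoint: &str,
    deployment: &str,
    version: AzureServiceVersion,
) -> String {
    format!(
        "{}/openai/deployments/{}/chat/completions?api-version={}",
        endpoint.trim_end_matches('/'), // tolerate a trailing slash in the endpoint
        deployment,
        version
    )
}

fn main() {
    let url = chat_completions_url(
        "https://example.openai.azure.com/",
        "gpt-4-1106-preview",
        AzureServiceVersion::V2023_12_01Preview,
    );
    println!("{url}");
}
```

The non-Azure client needs neither parameter, which is presumably why its `create_chat_completions` takes only the request.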
30 changes: 30 additions & 0 deletions sdk/openai_inference/examples/chat_completions_non_azure.rs
@@ -0,0 +1,30 @@
use azure_openai_inference::CreateChatCompletionsRequest;
use azure_openai_inference::OpenAIClient;
use azure_openai_inference::OpenAIKeyCredential;

#[tokio::main]
async fn main() {
let secret =
std::env::var("NON_AZURE_OPENAI_KEY").expect("Set NON_AZURE_OPENAI_KEY env variable");

let openai_client = OpenAIClient::new(OpenAIKeyCredential::new(secret));

let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
"gpt-3.5-turbo-1106",
"Tell me a joke about pineapples",
);

println!("{:#?}", &chat_completions_request);
let response = openai_client
.create_chat_completions(&chat_completions_request)
.await;

match response {
Ok(chat_completions) => {
println!("{:#?}", &chat_completions);
}
Err(e) => {
println!("Error: {}", e);
}
}
}