Skip to content

Commit 4e06968

Browse files
LarryOstermanheathsCopilot
authored
Implement EventHubs token refresh (#2476)
* Added preliminary support to refresh tokens * Convert AMQP to use async-trait package; Working token refresh logic * Moved credentials to connection manager since they are inherently per-connection; removed credentials from authorize_path since they're not needed at that point * Working connection manager token refresh tests Co-authored-by: Heath Stewart <[email protected]> Co-authored-by: Copilot <[email protected]> * Final PR feedback --------- Co-authored-by: Heath Stewart <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 0b37c0d commit 4e06968

File tree

26 files changed

+1308
-112
lines changed

26 files changed

+1308
-112
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdk/core/azure_core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ pub mod fs;
1515
pub mod hmac;
1616
pub mod http;
1717
pub mod process;
18+
pub mod task;
19+
1820
#[cfg(feature = "test")]
1921
pub mod test;
2022

sdk/core/azure_core/src/process/standard.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ impl Executor for StdExecutor {
2121
let output = cmd.output();
2222
tx.send(output)
2323
});
24-
let output = rx
25-
.await
26-
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))??;
24+
let output = rx.await.map_err(io::Error::other)??;
2725
Ok(output)
2826
}
2927
}

sdk/core/azure_core/src/task/mod.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
//! Asynchronous task execution utilities.
5+
//!
6+
//! This module provides a mechanism to spawn tasks asynchronously and wait for their completion.
7+
//!
8+
//! It abstracts away the underlying implementation details, allowing for different task execution strategies based on the target architecture and features enabled.
9+
//!
10+
//!
11+
//! Example usage:
12+
//!
13+
//! ```
14+
//! use azure_core::task::{new_task_spawner, TaskSpawner};
15+
//! use futures::FutureExt;
16+
//!
17+
//! #[tokio::main]
18+
//! async fn main() {
19+
//! let spawner = new_task_spawner();
20+
//! let handle = spawner.spawn(async {
21+
//! // Simulate some work
22+
//! std::thread::sleep(std::time::Duration::from_secs(1));
23+
//! }.boxed());
24+
//!
25+
//! handle.await.expect("Task should complete successfully");
26+
//!
27+
//! println!("Task completed");
28+
//! }
29+
//! ```
30+
//!
31+
//!
32+
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};
33+
34+
mod standard_spawn;
35+
36+
#[cfg(feature = "tokio")]
37+
mod tokio_spawn;
38+
39+
#[cfg(test)]
40+
mod tests;
41+
42+
#[cfg(not(target_arch = "wasm32"))]
43+
pub(crate) type TaskFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
44+
45+
// WASM32 does not support `Send` futures, so we use a non-Send future type.
46+
#[cfg(target_arch = "wasm32")]
47+
pub(crate) type TaskFuture = Pin<Box<dyn Future<Output = ()> + 'static>>;
48+
49+
/// A `SpawnedTask` is a future that represents a running task.
50+
/// It can be awaited to block until the task has completed.
51+
#[cfg(not(target_arch = "wasm32"))]
52+
pub type SpawnedTask = Pin<
53+
Box<
54+
dyn Future<Output = std::result::Result<(), Box<dyn std::error::Error + Send>>>
55+
+ Send
56+
+ 'static,
57+
>,
58+
>;
59+
60+
#[cfg(target_arch = "wasm32")]
61+
pub type SpawnedTask =
62+
Pin<Box<dyn Future<Output = std::result::Result<(), Box<dyn std::error::Error>>> + 'static>>;
63+
64+
/// An async command runner.
65+
///
66+
pub trait TaskSpawner: Send + Sync + Debug {
67+
/// Spawn a task that executes a given future and returns the output.
68+
///
69+
/// # Arguments
70+
///
71+
/// * `f` - A future representing the task to be spawned. This future cannot capture any variables
72+
/// from its environment by reference, as it will be executed in a different thread or context.
73+
///
74+
/// # Returns
75+
/// A future which can be awaited to block until the task has completed.
76+
///
77+
/// # Example
78+
/// ```
79+
/// use azure_core::task::{new_task_spawner, TaskSpawner};
80+
/// use futures::FutureExt;
81+
///
82+
/// #[tokio::main]
83+
/// async fn main() {
84+
/// let spawner = new_task_spawner();
85+
/// let future = spawner.spawn(async {
86+
/// // Simulate some work
87+
/// std::thread::sleep(std::time::Duration::from_secs(1));
88+
/// }.boxed());
89+
/// future.await.expect("Task should complete successfully");
90+
/// }
91+
/// ```
92+
///
93+
/// # Note
94+
///
95+
/// This trait intentionally does not use the *`async_trait`* macro because when the
96+
/// `async_trait` attribute is applied to a trait implementation, the rewritten
97+
/// method cannot directly return a future, instead they wrap the return value
98+
/// in a future, and we want the `spawn` method to directly return a future
99+
/// that can be awaited.
100+
///
101+
fn spawn(&self, f: TaskFuture) -> SpawnedTask;
102+
}
103+
104+
/// Creates a new [`TaskSpawner`] to enable running tasks asynchronously.
105+
///
106+
///
107+
/// The implementation depends on the target architecture and the features enabled:
108+
/// - If the `tokio` feature is enabled, it uses a tokio based spawner.
109+
/// - If the `tokio` feature is not enabled and the target architecture is not `wasm32`, it uses a std::thread based spawner.
110+
///
111+
/// # Returns
112+
/// A new instance of a [`TaskSpawner`] which can be used to spawn background tasks.
113+
///
114+
/// # Example
115+
///
116+
/// ```
117+
/// use azure_core::task::{new_task_spawner, TaskSpawner};
118+
/// use futures::FutureExt;
119+
///
120+
/// #[tokio::main]
121+
/// async fn main() {
122+
/// let spawner = new_task_spawner();
123+
/// let handle = spawner.spawn(async {
124+
/// // Simulate some work
125+
/// std::thread::sleep(std::time::Duration::from_secs(1));
126+
/// }.boxed());
127+
/// }
128+
/// ```
129+
///
130+
pub fn new_task_spawner() -> Arc<dyn TaskSpawner> {
131+
#[cfg(not(feature = "tokio"))]
132+
{
133+
Arc::new(standard_spawn::StdSpawner)
134+
}
135+
#[cfg(feature = "tokio")]
136+
{
137+
Arc::new(tokio_spawn::TokioSpawner) as Arc<dyn TaskSpawner>
138+
}
139+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
use super::{SpawnedTask, TaskFuture, TaskSpawner};
5+
#[cfg(not(target_arch = "wasm32"))]
6+
use futures::{executor::LocalPool, task::SpawnExt};
7+
#[cfg(not(target_arch = "wasm32"))]
8+
use std::{
9+
future,
10+
future::Future,
11+
pin::Pin,
12+
sync::{Arc, Mutex},
13+
task::Waker,
14+
task::{Context, Poll},
15+
thread,
16+
};
17+
#[cfg(not(target_arch = "wasm32"))]
18+
use tracing::debug;
19+
20+
/// A future that completes when a thread join handle completes.
21+
#[cfg(not(target_arch = "wasm32"))]
22+
struct ThreadJoinFuture {
23+
join_state: Arc<Mutex<ThreadJoinState>>,
24+
}
25+
26+
#[cfg(not(target_arch = "wasm32"))]
27+
#[derive(Default)]
28+
struct ThreadJoinState {
29+
join_handle:
30+
Option<thread::JoinHandle<std::result::Result<(), Box<dyn std::error::Error + Send>>>>,
31+
waker: Option<Waker>,
32+
thread_finished: bool,
33+
}
34+
35+
#[cfg(not(target_arch = "wasm32"))]
36+
impl Future for ThreadJoinFuture {
37+
type Output = std::result::Result<(), Box<dyn std::error::Error + Send>>;
38+
39+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
40+
let mut join_state = self.join_state.lock().map_err(|e| {
41+
debug!("Failed to lock join state: {}", e);
42+
Box::new(crate::Error::message(
43+
crate::error::ErrorKind::Other,
44+
format!("Thread panicked: {:?}", e),
45+
)) as Box<dyn std::error::Error + Send>
46+
})?;
47+
48+
// Join handle is present, so we can check if the thread has finished
49+
// and take the handle if it has.
50+
// This is safe because we are holding the lock on the join state.
51+
// We can safely take the handle and join it without blocking.
52+
// This allows us to retrieve the terminal state of the thread.
53+
if join_state.thread_finished {
54+
// Thread is finished, so we can safely take the handle
55+
let Some(join_handle) = join_state.join_handle.take() else {
56+
// The join handle was already removed from the state, we know we're done.
57+
return Poll::Ready(Ok(()));
58+
};
59+
60+
// Since we know the thread is finished, we can safely take the handle
61+
// and join it. This allows us to retrieve the terminal state of the thread.
62+
//
63+
// Technically this might block (because the `thread_finished` flag
64+
// is set before the thread *actually* finishes), but it should be negligible.
65+
match join_handle.join() {
66+
Ok(_) => Poll::Ready(Ok(())),
67+
Err(e) => Poll::Ready(Err(Box::new(crate::Error::message(
68+
crate::error::ErrorKind::Other,
69+
format!("Thread panicked: {:?}", e),
70+
)) as Box<dyn std::error::Error + Send>)),
71+
}
72+
} else {
73+
// Thread is still running, so we need to register the waker
74+
// for when it completes.
75+
join_state.waker = Some(cx.waker().clone());
76+
Poll::Pending
77+
}
78+
}
79+
}
80+
81+
/// A [`TaskSpawner`] using [`std::thread::spawn`].
82+
#[derive(Debug)]
83+
pub struct StdSpawner;
84+
85+
impl TaskSpawner for StdSpawner {
86+
#[cfg_attr(target_arch = "wasm32", allow(unused_variables))]
87+
fn spawn(&self, f: TaskFuture) -> SpawnedTask {
88+
#[cfg(target_arch = "wasm32")]
89+
{
90+
panic!("std::thread::spawn is not supported on wasm32")
91+
}
92+
#[cfg(not(target_arch = "wasm32"))]
93+
{
94+
let join_state = Arc::new(Mutex::new(ThreadJoinState::default()));
95+
{
96+
let Ok(mut js) = join_state.lock() else {
97+
return Box::pin(future::ready(Err(Box::new(crate::Error::message(
98+
crate::error::ErrorKind::Other,
99+
"Thread panicked.",
100+
))
101+
as Box<dyn std::error::Error + Send>)));
102+
};
103+
104+
// Clone the join state so it can be moved into the thread
105+
// and used to notify the waker when the thread finishes.
106+
let join_state_clone = join_state.clone();
107+
108+
js.join_handle = Some(thread::spawn(move || {
109+
// Create a local executor
110+
let mut local_pool = LocalPool::new();
111+
let spawner = local_pool.spawner();
112+
113+
// Spawn the future on the local executor
114+
let Ok(future_handle) = spawner.spawn_with_handle(f) else {
115+
return Err(Box::new(crate::Error::message(
116+
crate::error::ErrorKind::Other,
117+
"Failed to spawn future.",
118+
))
119+
as Box<dyn std::error::Error + Send>);
120+
};
121+
// Drive the executor until the future completes
122+
local_pool.run_until(future_handle);
123+
124+
let Ok(mut join_state) = join_state_clone.lock() else {
125+
return Err(Box::new(crate::Error::message(
126+
crate::error::ErrorKind::Other,
127+
"Failed to lock join state",
128+
))
129+
as Box<dyn std::error::Error + Send>);
130+
};
131+
132+
// The thread has finished, so we can take the waker
133+
// and notify it.
134+
join_state.thread_finished = true;
135+
if let Some(waker) = join_state.waker.take() {
136+
waker.wake();
137+
}
138+
Ok(())
139+
}));
140+
}
141+
// Create a future that will complete when the thread joins
142+
let join_future = ThreadJoinFuture { join_state };
143+
Box::pin(join_future)
144+
}
145+
}
146+
}

0 commit comments

Comments
 (0)