Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: hold runtime ref and handle to prevent spawn after shutdown #736

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions foyer-common/src/asyncify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use tokio::runtime::Handle;
use crate::runtime::SingletonHandle;

/// Convert the block call to async call.
#[cfg(not(madsim))]
Expand All @@ -36,9 +36,9 @@ where
f()
}

/// Convert the block call to async call with given runtime.
/// Convert the block call to async call with given runtime handle.
#[cfg(not(madsim))]
pub async fn asyncify_with_runtime<F, T>(runtime: &Handle, f: F) -> T
pub async fn asyncify_with_runtime<F, T>(runtime: &SingletonHandle, f: F) -> T
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
Expand All @@ -50,7 +50,7 @@ where
/// Convert the block call to async call with given runtime.
///
/// madsim compatible mode.
pub async fn asyncify_with_runtime<F, T>(_: &Handle, f: F) -> T
pub async fn asyncify_with_runtime<F, T>(_: &SingletonHandle, f: F) -> T
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
Expand Down
84 changes: 83 additions & 1 deletion foyer-common/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

use std::{
fmt::Debug,
future::Future,
mem::ManuallyDrop,
ops::{Deref, DerefMut},
};

use tokio::runtime::Runtime;
use tokio::{
runtime::{Handle, Runtime},
task::JoinHandle,
};

/// A wrapper around [`Runtime`] that shuts down the runtime in the background when dropped.
///
Expand Down Expand Up @@ -62,3 +66,81 @@ impl From<Runtime> for BackgroundShutdownRuntime {
Self(ManuallyDrop::new(runtime))
}
}

/// A non-clonable runtime handle.
#[derive(Debug)]
pub struct SingletonHandle(Handle);

impl From<Handle> for SingletonHandle {
fn from(handle: Handle) -> Self {
Self(handle)
}
}

impl SingletonHandle {
/// Spawns a future onto the Tokio runtime.
///
/// This spawns the given future onto the runtime's executor, usually a
/// thread pool. The thread pool is then responsible for polling the future
/// until it completes.
///
/// The provided future will start running in the background immediately
/// when `spawn` is called, even if you don't await the returned
/// `JoinHandle`.
///
/// See [module level][mod] documentation for more details.
///
/// [mod]: index.html
///
/// # Examples
///
/// ```
/// use tokio::runtime::Runtime;
///
/// # fn dox() {
/// // Create the runtime
/// let rt = Runtime::new().unwrap();
/// // Get a handle from this runtime
/// let handle = rt.handle();
///
/// // Spawn a future onto the runtime using the handle
/// handle.spawn(async {
/// println!("now running on a worker thread");
/// });
/// # }
/// ```
pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.0.spawn(future)
}

/// Runs the provided function on an executor dedicated to blocking
/// operations.
///
/// # Examples
///
/// ```
/// use tokio::runtime::Runtime;
///
/// # fn dox() {
/// // Create the runtime
/// let rt = Runtime::new().unwrap();
/// // Get a handle from this runtime
/// let handle = rt.handle();
///
/// // Spawn a blocking function onto the runtime using the handle
/// handle.spawn_blocking(|| {
/// println!("now running on a worker thread");
/// });
/// # }
pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
self.0.spawn_blocking(func)
}
}
3 changes: 2 additions & 1 deletion foyer-memory/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use foyer_common::{
code::{HashBuilder, Key, Value},
event::EventListener,
future::Diversion,
runtime::SingletonHandle,
};
use futures::Future;
use pin_project::pin_project;
Expand Down Expand Up @@ -834,7 +835,7 @@ where
key: K,
context: CacheContext,
fetch: F,
runtime: &tokio::runtime::Handle,
runtime: &SingletonHandle,
) -> Fetch<K, V, ER, S>
where
F: FnOnce() -> FU,
Expand Down
13 changes: 10 additions & 3 deletions foyer-memory/src/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
future::{Diversion, DiversionFuture},
metrics::Metrics,
object_pool::ObjectPool,
runtime::SingletonHandle,
strict_assert, strict_assert_eq,
};
use hashbrown::hash_map::{Entry as HashMapEntry, HashMap};
Expand Down Expand Up @@ -739,7 +740,12 @@
FU: Future<Output = std::result::Result<V, ER>> + Send + 'static,
ER: Send + 'static + Debug,
{
self.fetch_inner(key, CacheContext::default(), fetch, &tokio::runtime::Handle::current())
self.fetch_inner(
key,
CacheContext::default(),
fetch,
&tokio::runtime::Handle::current().into(),
)
}

pub fn fetch_with_context<F, FU, ER>(
Expand All @@ -753,15 +759,16 @@
FU: Future<Output = std::result::Result<V, ER>> + Send + 'static,
ER: Send + 'static + Debug,
{
self.fetch_inner(key, context, fetch, &tokio::runtime::Handle::current())
self.fetch_inner(key, context, fetch, &tokio::runtime::Handle::current().into())

Check warning on line 762 in foyer-memory/src/generic.rs

View check run for this annotation

Codecov / codecov/patch

foyer-memory/src/generic.rs#L762

Added line #L762 was not covered by tests
}

#[doc(hidden)]
pub fn fetch_inner<F, FU, ER, ID>(
self: &Arc<Self>,
key: K,
context: CacheContext,
fetch: F,
runtime: &tokio::runtime::Handle,
runtime: &SingletonHandle,
) -> GenericFetch<K, V, E, I, S, ER>
where
F: FnOnce() -> FU,
Expand Down
20 changes: 9 additions & 11 deletions foyer-storage/src/device/direct_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ use std::{
use foyer_common::{asyncify::asyncify_with_runtime, bits};
use fs4::free_space;
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;

use super::{Dev, DevExt, DevOptions, RegionId};
use crate::{
device::ALIGN,
error::{Error, Result},
IoBytes, IoBytesMut,
IoBytes, IoBytesMut, Runtime,
};

/// Options for the direct file device.
Expand All @@ -49,7 +48,7 @@ pub struct DirectFileDevice {
capacity: usize,
region_size: usize,

runtime: Handle,
runtime: Runtime,
}

impl DevOptions for DirectFileDeviceOptions {
Expand Down Expand Up @@ -90,7 +89,7 @@ impl DirectFileDevice {

let file = self.file.clone();

asyncify_with_runtime(&self.runtime, move || {
asyncify_with_runtime(self.runtime.write(), move || {
#[cfg(target_family = "windows")]
let written = {
use std::os::windows::fs::FileExt;
Expand Down Expand Up @@ -133,7 +132,7 @@ impl DirectFileDevice {

let file = self.file.clone();

let mut buffer = asyncify_with_runtime(&self.runtime, move || {
let mut buffer = asyncify_with_runtime(self.runtime.read(), move || {
#[cfg(target_family = "windows")]
let read = {
use std::os::windows::fs::FileExt;
Expand Down Expand Up @@ -172,9 +171,7 @@ impl Dev for DirectFileDevice {
}

#[fastrace::trace(name = "foyer::storage::device::direct_file::open")]
async fn open(options: Self::Options) -> Result<Self> {
let runtime = Handle::current();

async fn open(options: Self::Options, runtime: Runtime) -> Result<Self> {
options.verify()?;

let dir = options
Expand Down Expand Up @@ -253,7 +250,7 @@ impl Dev for DirectFileDevice {
#[fastrace::trace(name = "foyer::storage::device::direct_file::flush")]
async fn flush(&self, _: Option<RegionId>) -> Result<()> {
let file = self.file.clone();
asyncify_with_runtime(&self.runtime, move || file.sync_all().map_err(Error::from)).await
asyncify_with_runtime(self.runtime.write(), move || file.sync_all().map_err(Error::from)).await
}
}

Expand Down Expand Up @@ -360,6 +357,7 @@ mod tests {
#[test_log::test(tokio::test)]
async fn test_direct_file_device_io() {
let dir = tempfile::tempdir().unwrap();
let runtime = Runtime::current();

let options = DirectFileDeviceOptionsBuilder::new(dir.path().join("test-direct-file"))
.with_capacity(4 * 1024 * 1024)
Expand All @@ -368,7 +366,7 @@ mod tests {

tracing::debug!("{options:?}");

let device = DirectFileDevice::open(options.clone()).await.unwrap();
let device = DirectFileDevice::open(options.clone(), runtime.clone()).await.unwrap();

let mut buf = IoBytesMut::with_capacity(64 * 1024);
buf.extend(repeat_n(b'x', 64 * 1024 - 100));
Expand All @@ -383,7 +381,7 @@ mod tests {

drop(device);

let device = DirectFileDevice::open(options).await.unwrap();
let device = DirectFileDevice::open(options, runtime).await.unwrap();

let b = device.read(0, 4096, 64 * 1024 - 100).await.unwrap().freeze();
assert_eq!(buf, b);
Expand Down
25 changes: 12 additions & 13 deletions foyer-storage/src/device/direct_fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
use futures::future::try_join_all;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;

use super::{Dev, DevExt, DevOptions, RegionId};
use crate::{
device::ALIGN,
error::{Error, Result},
IoBytes, IoBytesMut,
IoBytes, IoBytesMut, Runtime,
};

/// Options for the direct fs device.
Expand All @@ -56,7 +55,7 @@
capacity: usize,
file_size: usize,

runtime: Handle,
runtime: Runtime,
}

impl DevOptions for DirectFsDeviceOptions {
Expand Down Expand Up @@ -106,17 +105,16 @@
}

#[fastrace::trace(name = "foyer::storage::device::direct_fs::open")]
async fn open(options: Self::Options) -> Result<Self> {
let runtime = Handle::current();

async fn open(options: Self::Options, runtime: Runtime) -> Result<Self> {
options.verify()?;

// TODO(MrCroxx): write and read options to a manifest file for pinning

let regions = options.capacity / options.file_size;

let path = options.dir.clone();
asyncify_with_runtime(&runtime, move || create_dir_all(path)).await?;
if !options.dir.exists() {
create_dir_all(&options.dir)?;

Check warning on line 116 in foyer-storage/src/device/direct_fs.rs

View check run for this annotation

Codecov / codecov/patch

foyer-storage/src/device/direct_fs.rs#L116

Added line #L116 was not covered by tests
}

let futures = (0..regions)
.map(|i| {
Expand Down Expand Up @@ -165,7 +163,7 @@

let file = self.file(region).clone();

asyncify_with_runtime(&self.inner.runtime, move || {
asyncify_with_runtime(self.inner.runtime.write(), move || {
#[cfg(target_family = "windows")]
let written = {
use std::os::windows::fs::FileExt;
Expand Down Expand Up @@ -207,7 +205,7 @@

let file = self.file(region).clone();

let mut buffer = asyncify_with_runtime(&self.inner.runtime, move || {
let mut buffer = asyncify_with_runtime(self.inner.runtime.read(), move || {
#[cfg(target_family = "unix")]
let read = {
use std::os::unix::fs::FileExt;
Expand Down Expand Up @@ -237,7 +235,7 @@
async fn flush(&self, region: Option<super::RegionId>) -> Result<()> {
let flush = |region: RegionId| {
let file = self.file(region).clone();
asyncify_with_runtime(&self.inner.runtime, move || file.sync_all().map_err(Error::from))
asyncify_with_runtime(self.inner.runtime.write(), move || file.sync_all().map_err(Error::from))
};

if let Some(region) = region {
Expand Down Expand Up @@ -352,6 +350,7 @@
#[test_log::test(tokio::test)]
async fn test_direct_fd_device_io() {
let dir = tempfile::tempdir().unwrap();
let runtime = Runtime::current();

let options = DirectFsDeviceOptionsBuilder::new(dir.path())
.with_capacity(4 * 1024 * 1024)
Expand All @@ -360,7 +359,7 @@

tracing::debug!("{options:?}");

let device = DirectFsDevice::open(options.clone()).await.unwrap();
let device = DirectFsDevice::open(options.clone(), runtime.clone()).await.unwrap();

let mut buf = IoBytesMut::with_capacity(64 * 1024);
buf.extend(repeat_n(b'x', 64 * 1024 - 100));
Expand All @@ -375,7 +374,7 @@

drop(device);

let device = DirectFsDevice::open(options).await.unwrap();
let device = DirectFsDevice::open(options, runtime).await.unwrap();

let b = device.read(0, 4096, 64 * 1024 - 100).await.unwrap().freeze();
assert_eq!(buf, b);
Expand Down
Loading
Loading