Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync with dev #150

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion casper-server/src/context.rs
Original file line number Diff line number Diff line change
@@ -141,7 +141,7 @@ impl AppContextInner {

// Start task scheduler
let max_background_tasks = self.config.main.max_background_tasks;
lua::tasks::start_task_scheduler(&lua, max_background_tasks);
lua::tasks::start_task_scheduler(lua, max_background_tasks);

// Enable sandboxing before loading user code
lua.sandbox(true)?;
41 changes: 30 additions & 11 deletions casper-server/src/lua/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::rc::Rc;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};

use mlua::{
AnyUserData, ExternalError, ExternalResult, Function, Lua, Result, Table, UserData, Value,
};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::sync::{oneshot, watch};
use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};
use tracing::warn;
@@ -30,8 +31,12 @@ struct TaskHandle {
join_handle_rx: Option<oneshot::Receiver<TaskJoinHandle>>,
}

#[derive(Clone, Copy)]
struct MaxBackgroundTasks(Option<u64>);

#[derive(Clone)]
struct ShutdownNotifier(watch::Sender<bool>);

// Global task identifier
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);

@@ -71,7 +76,7 @@ impl UserData for TaskHandle {
}

fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
let max_background_tasks = lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
let max_background_tasks = *lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
let current_tasks = tasks_counter_get!();

if let Some(max_tasks) = max_background_tasks.0 {
@@ -128,27 +133,38 @@ fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
}

pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
let lua = lua.clone();
let lua = Rc::new(lua.clone());
let mut task_rx = lua
.remove_app_data::<UnboundedReceiver<Task>>()
.expect("Failed to get task receiver");

lua.set_app_data(MaxBackgroundTasks(max_background_tasks));

let (shutdown_tx, shutdown_rx) = watch::channel(false);
lua.set_app_data(ShutdownNotifier(shutdown_tx));

tokio::task::spawn_local(async move {
while let Some(task) = task_rx.recv().await {
let lua = lua.clone();
let mut shutdown = shutdown_rx.clone();
let join_handle = tokio::task::spawn_local(async move {
let start = Instant::now();
let _task_count_guard = tasks_counter_inc!();
// Keep Lua instance alive while task is running
let _lua_guard = lua;
let task_future = task.handler.call_async::<Value>(());

let result = match task.timeout {
Some(timeout) => ntex::time::timeout(timeout, task_future).await,
None => Ok(task_future.await),
Some(timeout) => tokio::select! {
_ = shutdown.wait_for(|&x| x) => return Err("task scheduler shutdown".into_lua_err()),
result = ntex::time::timeout(timeout, task_future) =>
result.unwrap_or_else(|_| Err("task exceeded timeout".into_lua_err())),
},
None => tokio::select! {
_ = shutdown.wait_for(|&x| x) => return Err("task scheduler shutdown".into_lua_err()),
result = task_future => result,
},
};
// Outer Result errors will always be timeouts
let result = result
.unwrap_or_else(|_| Err("task exceeded timeout".to_string()).into_lua_err());

// Record task metrics
match task.name {
@@ -178,7 +194,9 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {

pub fn stop_task_scheduler(lua: &Lua) {
lua.remove_app_data::<UnboundedSender<Task>>();
lua.remove_app_data::<UnboundedReceiver<Task>>();

// Notify all tasks to stop
_ = lua.app_data_ref::<ShutdownNotifier>().unwrap().0.send(true);
}

pub fn create_module(lua: &Lua) -> Result<Table> {
@@ -192,14 +210,13 @@ pub fn create_module(lua: &Lua) -> Result<Table> {

#[cfg(test)]
mod tests {
use std::rc::Rc;
use std::time::Duration;

use mlua::{chunk, Lua, Result};

#[ntex::test]
async fn test_tasks() -> Result<()> {
let lua = Rc::new(Lua::new());
let lua = Lua::new();

lua.globals().set("tasks", super::create_module(&lua)?)?;
lua.globals().set(
@@ -331,6 +348,8 @@ mod tests {
.await
.unwrap();

super::stop_task_scheduler(&lua);

Ok(())
}
}