Skip to content

Commit

Permalink
Add shared child.
Browse files Browse the repository at this point in the history
  • Loading branch information
milesj committed Jan 13, 2025
1 parent d3b4f7b commit d52ce66
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 68 deletions.
4 changes: 0 additions & 4 deletions crates/process/src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ pub struct Command {

/// Console to write output to
pub console: Option<Arc<Console>>,

/// Current ID of a running child process.
pub current_id: Option<u32>,
}

impl Command {
Expand All @@ -59,7 +56,6 @@ impl Command {
print_command: false,
shell: Some(Shell::default()),
console: None,
current_id: None,
}
}

Expand Down
122 changes: 71 additions & 51 deletions crates/process/src/exec_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ use tracing::{debug, enabled};

impl Command {
pub async fn exec_capture_output(&mut self) -> miette::Result<Output> {
let registry = ProcessRegistry::instance();
let (mut command, line) = self.create_async_command();
let output: Output;

if self.should_pass_stdin() {
let child = if self.should_pass_stdin() {
let mut child = command
.stdin(Stdio::piped())
.stdout(Stdio::piped())
Expand All @@ -34,59 +34,73 @@ impl Command {

self.write_input_to_child(&mut child, &line).await?;

self.current_id = child.id();

output = child
.wait_with_output()
.await
.map_err(|error| ProcessError::Capture {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
child
} else {
output = command
.output()
.await
command
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|error| ProcessError::Capture {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
}
})?
};

let shared_child = registry.add_child(child).await;

let result = shared_child
.wait_with_output()
.await
.map_err(|error| ProcessError::Capture {
bin: self.get_bin_name(),
error: Box::new(error),
});

registry.remove_child(shared_child).await;

let output = result?;

self.handle_nonzero_status(&output, true)?;

Ok(output)
}

pub async fn exec_stream_output(&mut self) -> miette::Result<Output> {
let registry = ProcessRegistry::instance();
let (mut command, line) = self.create_async_command();
let mut child: Child;

if self.should_pass_stdin() {
child =
command
.stdin(Stdio::piped())
.spawn()
.map_err(|error| ProcessError::Stream {
let child =
if self.should_pass_stdin() {
let mut child = command.stdin(Stdio::piped()).spawn().map_err(|error| {
ProcessError::Stream {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
}
})?;

self.write_input_to_child(&mut child, &line).await?;
} else {
child = command.spawn().map_err(|error| ProcessError::Stream {
self.write_input_to_child(&mut child, &line).await?;

child
} else {
command.spawn().map_err(|error| ProcessError::Stream {
bin: self.get_bin_name(),
error: Box::new(error),
})?
};

let shared_child = registry.add_child(child).await;

let result = shared_child
.wait()
.await
.map_err(|error| ProcessError::Stream {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
};
});

self.current_id = child.id();

let status = child.wait().await.map_err(|error| ProcessError::Stream {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
registry.remove_child(shared_child).await;

let status = result?;
let output = Output {
status,
stderr: vec![],
Expand All @@ -98,7 +112,8 @@ impl Command {
Ok(output)
}

pub async fn exec_stream_and_capture_output_old(&mut self) -> miette::Result<Output> {
pub async fn exec_stream_and_capture_output(&mut self) -> miette::Result<Output> {
let registry = ProcessRegistry::instance();
let (mut command, line) = self.create_async_command();

let mut child = command
Expand All @@ -115,20 +130,20 @@ impl Command {
error: Box::new(error),
})?;

self.current_id = child.id();

if self.should_pass_stdin() {
self.write_input_to_child(&mut child, &line).await?;
}

let shared_child = registry.add_child(child).await;

// We need to log the child process output to the parent terminal
// AND capture stdout/stderr so that we can cache it for future runs.
// This doesn't seem to be supported natively by `Stdio`, so I have
// this *real ugly* implementation to solve it. There's gotta be a
// better way to do this?
// https://stackoverflow.com/a/49063262
let stderr = BufReader::new(child.stderr.take().unwrap());
let stdout = BufReader::new(child.stdout.take().unwrap());
let stderr = BufReader::new(shared_child.take_stderr().await.unwrap());
let stdout = BufReader::new(shared_child.take_stdout().await.unwrap());
let mut handles = vec![];

let captured_stderr = Arc::new(RwLock::new(vec![]));
Expand Down Expand Up @@ -192,14 +207,17 @@ impl Command {
}

// Attempt to create the child output
let status = child
let result = shared_child
.wait()
.await
.map_err(|error| ProcessError::StreamCapture {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
});

registry.remove_child(shared_child).await;

let status = result?;
let output = Output {
status,
stdout: captured_stdout.read().unwrap().join("\n").into_bytes(),
Expand All @@ -211,7 +229,8 @@ impl Command {
Ok(output)
}

pub async fn exec_stream_and_capture_output(&mut self) -> miette::Result<Output> {
pub async fn exec_stream_and_capture_output_new(&mut self) -> miette::Result<Output> {
let registry = ProcessRegistry::instance();
let (mut command, line) = self.create_async_command();

let mut child = command
Expand All @@ -228,18 +247,18 @@ impl Command {
error: Box::new(error),
})?;

self.current_id = child.id();

if self.should_pass_stdin() {
self.write_input_to_child(&mut child, &line).await?;
}

let shared_child = registry.add_child(child).await;

// Stream and attempt to capture the output
let stderr = child.stderr.take().unwrap();
let stderr = shared_child.take_stderr().await.unwrap();
let mut stderr_buffer = Vec::new();
let mut stderr_pos = 0;

let stdout = child.stdout.take().unwrap();
let stdout = shared_child.take_stdout().await.unwrap();
let mut stdout_buffer = Vec::new();
let mut stdout_pos = 0;

Expand Down Expand Up @@ -291,14 +310,17 @@ impl Command {
})?;

// Attempt to create the child output
let status = child
let result = shared_child
.wait()
.await
.map_err(|error| ProcessError::StreamCapture {
bin: self.get_bin_name(),
error: Box::new(error),
})?;
});

registry.remove_child(shared_child).await;

let status = result?;
let output = Output {
status,
stdout: stdout_buffer,
Expand Down Expand Up @@ -339,8 +361,6 @@ impl Command {
}

fn handle_nonzero_status(&mut self, output: &Output, with_message: bool) -> miette::Result<()> {
self.current_id = None;

if self.should_error_nonzero() && !output.status.success() {
return Err(output_to_error(self.get_bin_name(), output, with_message).into());
}
Expand Down
2 changes: 2 additions & 0 deletions crates/process/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod output;
mod output_stream;
mod process_error;
mod process_registry;
mod shared_child;
mod shell;
mod signal;

Expand All @@ -14,5 +15,6 @@ pub use moon_args as args;
pub use output::*;
pub use process_error::*;
pub use process_registry::*;
pub use shared_child::*;
pub use shell::*;
pub use signal::*;
53 changes: 40 additions & 13 deletions crates/process/src/process_registry.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use crate::shared_child::*;
use crate::signal::*;
use core::time::Duration;
use rustc_hash::FxHashMap;
use std::sync::{Arc, OnceLock};
use tokio::process::Child;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::broadcast::{self, Receiver, Sender};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tracing::{debug, trace};

static INSTANCE: OnceLock<Arc<ProcessRegistry>> = OnceLock::new();

pub struct ProcessRegistry {
processes: Arc<RwLock<FxHashMap<u32, Child>>>,
processes: Arc<RwLock<FxHashMap<u32, SharedChild>>>,
signal_sender: Sender<SignalType>,
signal_receiver: Receiver<SignalType>,
signal_wait_handle: JoinHandle<()>,
signal_shutdown_handle: JoinHandle<()>,
}
Expand All @@ -29,43 +31,66 @@ impl ProcessRegistry {

let (sender, receiver) = broadcast::channel::<SignalType>(10);
let sender_bg = sender.clone();
let receiver_bg = sender.subscribe();

let signal_wait_handle = tokio::spawn(async move {
wait_for_signal(sender_bg).await;
});

let signal_shutdown_handle = tokio::spawn(async move {
shutdown_processes_with_signal(receiver_bg, processes_bg).await;
shutdown_processes_with_signal(receiver, processes_bg).await;
});

Self {
processes,
signal_sender: sender,
signal_receiver: receiver,
signal_wait_handle,
signal_shutdown_handle,
}
}

pub async fn add_child(&self, child: Child) {
pub async fn add_child(&self, child: Child) -> SharedChild {
let shared = SharedChild::new(child);

self.processes
.write()
.await
.insert(child.id().expect("Child process requires a PID!"), child);
.insert(shared.id(), shared.clone());

shared
}

pub async fn get_child_by_id(&self, id: u32) -> Option<SharedChild> {
self.processes.read().await.get(&id).cloned()
}

pub async fn remove_child(&self, id: u32) {
pub async fn remove_child(&self, child: SharedChild) {
self.remove_child_by_id(child.id()).await
}

pub async fn remove_child_by_id(&self, id: u32) {
self.processes.write().await.remove(&id);
}

pub fn receive_signal(&self) -> Receiver<SignalType> {
self.signal_sender.subscribe()
}

// pub async fn wait_to_shutdown(&self) {
// self.signal_shutdown_handle.await;
// }
pub fn terminate_children(&self) {
let _ = self.signal_sender.send(SignalType::Terminate);
}

pub async fn wait_for_children_to_shutdown(&self) {
let mut count = 0;

loop {
if self.processes.read().await.is_empty() || count >= 5000 {
break;
}

sleep(Duration::from_millis(50)).await;
count += 50;
}
}
}

impl Drop for ProcessRegistry {
Expand All @@ -77,7 +102,7 @@ impl Drop for ProcessRegistry {

async fn shutdown_processes_with_signal(
mut receiver: Receiver<SignalType>,
processes: Arc<RwLock<FxHashMap<u32, Child>>>,
processes: Arc<RwLock<FxHashMap<u32, SharedChild>>>,
) {
// TODO
loop {
Expand All @@ -102,10 +127,12 @@ async fn shutdown_processes_with_signal(
children.len()
);

for (pid, mut child) in children.drain() {
for (pid, child) in children.drain() {
trace!(pid, "Killing child process");

let _ = child.kill().await;

drop(child);
}
}

Expand Down
Loading

0 comments on commit d52ce66

Please sign in to comment.