diff --git a/crates/nightward-core/src/providers.rs b/crates/nightward-core/src/providers.rs index fb18d37..b543670 100644 --- a/crates/nightward-core/src/providers.rs +++ b/crates/nightward-core/src/providers.rs @@ -8,6 +8,7 @@ use std::env; use std::io::Read; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; +use std::sync::mpsc::{self, Receiver, RecvTimeoutError}; use std::thread; use std::time::Duration; use wait_timeout::ChildExt; @@ -15,6 +16,7 @@ use wait_timeout::ChildExt; const DEFAULT_STDOUT_CAP: usize = 2 * 1024 * 1024; const DEFAULT_STDERR_CAP: usize = 64 * 1024; const DEFAULT_PROVIDER_TIMEOUT: Duration = Duration::from_secs(20); +const STREAM_COLLECT_TIMEOUT: Duration = Duration::from_secs(1); #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Provider { @@ -162,26 +164,25 @@ pub fn run_provider(name: &str, root: &Path) -> Result> { .stderr(Stdio::piped()) .spawn() .with_context(|| format!("spawn provider {name}"))?; - let stdout_handle = child + let stdout_reader = child .stdout .take() - .map(|stream| thread::spawn(move || read_stream_capped(stream, stdout_cap))); - let stderr_handle = child + .map(|stream| spawn_stream_reader(stream, stdout_cap)); + let stderr_reader = child .stderr .take() - .map(|stream| thread::spawn(move || read_stream_capped(stream, stderr_cap))); + .map(|stream| spawn_stream_reader(stream, stderr_cap)); let status = match child.wait_timeout(timeout)? { Some(status) => status, None => { let _ = child.kill(); let _ = child.wait(); - let _ = join_stream(stdout_handle); - let _ = join_stream(stderr_handle); return Err(anyhow!("provider timed out after {:?}", timeout)); } }; - let (stdout, stdout_truncated) = join_stream(stdout_handle); - let (stderr, _) = join_stream(stderr_handle); + let (stdout, stdout_truncated) = + collect_stream(stdout_reader, "stdout", STREAM_COLLECT_TIMEOUT)?; + let (stderr, _) = collect_stream(stderr_reader, "stderr", STREAM_COLLECT_TIMEOUT)?; if stdout_truncated { return Err(anyhow!("provider stdout exceeded {stdout_cap} byte cap")); } @@ -191,6 +192,18 @@ pub fn run_provider(name: &str, root: &Path) -> Result> { parse_provider_output(name, root, &stdout) } +struct StreamReader { + receiver: Receiver<(String, bool)>, +} + +fn spawn_stream_reader(stream: impl Read + Send + 'static, cap: usize) -> StreamReader { + let (sender, receiver) = mpsc::channel(); + thread::spawn(move || { + let _ = sender.send(read_stream_capped(stream, cap)); + }); + StreamReader { receiver } +} + fn read_stream_capped(mut stream: impl Read, cap: usize) -> (String, bool) { let mut out = Vec::with_capacity(cap.min(64 * 1024)); let mut truncated = false; @@ -213,10 +226,23 @@ fn read_stream_capped(mut stream: impl Read, cap: usize) -> (String, bool) { (redact_text(&String::from_utf8_lossy(&out)), truncated) } -fn join_stream(handle: Option>) -> (String, bool) { - handle - .and_then(|handle| handle.join().ok()) - .unwrap_or_default() +fn collect_stream( + reader: Option, + label: &str, + timeout: Duration, +) -> Result<(String, bool)> { + let Some(reader) = reader else { + return Ok((String::new(), false)); + }; + match reader.receiver.recv_timeout(timeout) { + Ok(result) => Ok(result), + Err(RecvTimeoutError::Timeout) => { + Err(anyhow!("provider {label} did not close after process exit")) + } + Err(RecvTimeoutError::Disconnected) => { + Err(anyhow!("provider {label} reader ended without output")) + } + } } pub fn parse_provider_output( diff --git a/crates/nightward-core/tests/provider_contracts.rs b/crates/nightward-core/tests/provider_contracts.rs index 6639706..6d08768 100644 --- a/crates/nightward-core/tests/provider_contracts.rs +++ b/crates/nightward-core/tests/provider_contracts.rs @@ -1,6 +1,8 @@ use nightward_core::analysis::SignalCategory; use nightward_core::providers::{parse_provider_output, run_provider, statuses}; use std::path::{Path, PathBuf}; +#[cfg(unix)] +use std::time::{Duration, Instant}; fn fixture(name: &str) -> String { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) @@ -106,6 +108,35 @@ fn provider_timeout_returns_stable_warning_error() { assert!(error.to_string().contains("provider timed out after")); } +#[cfg(unix)] +#[test] +fn provider_timeout_does_not_wait_for_inherited_output_pipes() { + let _guard = EnvRestore::set(&[ + ("PATH", None), + ("NIGHTWARD_PROVIDER_TIMEOUT_MS", Some("50")), + ("NIGHTWARD_PROVIDER_STDOUT_CAP", None), + ]); + let dir = tempfile::tempdir().expect("temp dir"); + write_executable( + dir.path().join("gitleaks"), + "#!/bin/sh\n(/bin/sleep 2) &\n/bin/sleep 1\n", + ); + std::env::set_var("PATH", dir.path()); + + let started = Instant::now(); + let error = run_provider("gitleaks", dir.path()).expect_err("timeout"); + let elapsed = started.elapsed(); + + assert!( + error.to_string().contains("provider timed out after"), + "actual error: {error}" + ); + assert!( + elapsed < Duration::from_secs(1), + "provider timeout waited for inherited pipe holder: {elapsed:?}" + ); +} + #[cfg(unix)] #[test] fn provider_stdout_cap_fails_closed_before_parsing() {