diff --git a/Cargo.lock b/Cargo.lock index fb5aa5cc1..857f15890 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4007,6 +4007,7 @@ dependencies = [ "async-trait", "axum", "base64 0.22.1", + "bytes", "cbc", "chrono", "dashmap", @@ -4028,6 +4029,7 @@ dependencies = [ "serde_json", "sha1", "sha2", + "tempfile", "tokio", "tokio-stream", "tokio-test", diff --git a/crates/openfang-api/src/openai_compat.rs b/crates/openfang-api/src/openai_compat.rs index d5d1cebd8..b66a05b22 100644 --- a/crates/openfang-api/src/openai_compat.rs +++ b/crates/openfang-api/src/openai_compat.rs @@ -216,7 +216,11 @@ fn convert_messages(oai_messages: &[OaiMessage]) -> Vec { .unwrap_or(parts[0]) .to_string(); let data = parts[1].to_string(); - Some(ContentBlock::Image { media_type, data }) + Some(ContentBlock::Image { + media_type, + data, + source_url: None, + }) } else { None } diff --git a/crates/openfang-api/src/routes.rs b/crates/openfang-api/src/routes.rs index 233677777..7b07686c8 100644 --- a/crates/openfang-api/src/routes.rs +++ b/crates/openfang-api/src/routes.rs @@ -280,6 +280,7 @@ pub fn resolve_attachments( blocks.push(openfang_types::message::ContentBlock::Image { media_type: content_type, data: b64, + source_url: None, }); } Err(e) => { @@ -512,6 +513,7 @@ pub async fn get_agent_session( openfang_types::message::ContentBlock::Image { media_type, data, + .. 
} => { texts.push("[Image]".to_string()); // Persist image to upload dir so it can be diff --git a/crates/openfang-channels/Cargo.toml b/crates/openfang-channels/Cargo.toml index e0ef71d55..71b74ce02 100644 --- a/crates/openfang-channels/Cargo.toml +++ b/crates/openfang-channels/Cargo.toml @@ -29,6 +29,7 @@ sha1 = { workspace = true } aes = "0.8" cbc = "0.1" base64 = { workspace = true } +bytes = { workspace = true } hex = { workspace = true } html-escape = { workspace = true } regex-lite = "0.1" @@ -42,3 +43,4 @@ rumqttc = { workspace = true } [dev-dependencies] tokio-test = { workspace = true } +tempfile = { workspace = true } diff --git a/crates/openfang-channels/src/bridge.rs b/crates/openfang-channels/src/bridge.rs index a1aecae62..93f5ab38d 100644 --- a/crates/openfang-channels/src/bridge.rs +++ b/crates/openfang-channels/src/bridge.rs @@ -601,12 +601,38 @@ async fn send_response( thread_id: Option<&str>, output_format: OutputFormat, ) { + // Parse `` markers BEFORE formatting — channel + // formatters (telegram HTML, slack mrkdwn) escape `<` and would break + // marker detection downstream. 
+ let (text_to_format, attachment_blocks): (String, Vec) = + match crate::outbound_attach::parse(&text, None).await { + crate::outbound_attach::Parsed::NoMarkers => (text, Vec::new()), + crate::outbound_attach::Parsed::WithAttachments { + stripped_text, + files, + } => (stripped_text, files), + }; + let formatted = if adapter.name() == "wecom" { - formatter::format_for_wecom(&text, output_format) + formatter::format_for_wecom(&text_to_format, output_format) } else { - formatter::format_for_channel(&text, output_format) + formatter::format_for_channel(&text_to_format, output_format) + }; + + let content = if attachment_blocks.is_empty() { + ChannelContent::Text(formatted) + } else { + let mut blocks = Vec::with_capacity(attachment_blocks.len() + 1); + if !formatted.trim().is_empty() { + blocks.push(ChannelContent::Text(formatted)); + } + blocks.extend(attachment_blocks); + if blocks.len() == 1 { + blocks.remove(0) + } else { + ChannelContent::Multipart(blocks) + } }; - let content = ChannelContent::Text(formatted); let result = if let Some(tid) = thread_id { adapter.send_in_thread(user, content, tid).await @@ -858,6 +884,93 @@ async fn dispatch_message( return; } + // Multipart: flatten children into LLM content blocks. If any image + // succeeds, dispatch as multimodal; otherwise fall through to the text + // path (Multipart arm in the match below builds the combined descriptor). + if let ChannelContent::Multipart(parts) = &message.content { + let mut blocks: Vec = Vec::new(); + for part in parts { + debug_assert!( + !matches!(part, ChannelContent::Multipart(_)), + "nested Multipart in ChannelContent — adapters should produce flat lists" + ); + match part { + ChannelContent::Text(t) => blocks.push(ContentBlock::Text { + text: t.clone(), + provider_metadata: None, + }), + ChannelContent::Image { url, caption } => { + let mut img = download_image_to_blocks(url, caption.as_deref()).await; + blocks.append(&mut img); + } + ChannelContent::File { url, filename, .. 
} => { + blocks.push(ContentBlock::Text { + text: format!("[User sent a file ({filename}): {url}]"), + provider_metadata: None, + }); + } + ChannelContent::Voice { + url, + duration_seconds, + } => { + blocks.push(ContentBlock::Text { + text: format!("[User sent a voice message ({duration_seconds}s): {url}]"), + provider_metadata: None, + }); + } + ChannelContent::Location { lat, lon } => { + blocks.push(ContentBlock::Text { + text: format!("[User shared location: {lat}, {lon}]"), + provider_metadata: None, + }); + } + ChannelContent::FileData { filename, .. } => { + blocks.push(ContentBlock::Text { + text: format!("[User sent a local file: {filename}]"), + provider_metadata: None, + }); + } + // Commands aren't expected inside Multipart, but render as + // text rather than drop the message if one slips through. + ChannelContent::Command { name, args } => { + blocks.push(ContentBlock::Text { + text: format!("/{name} {}", args.join(" ")), + provider_metadata: None, + }); + } + // Defensive: debug_assert above catches this in dev; ignore + // gracefully in release. + ChannelContent::Multipart(_) => {} + } + } + + if blocks + .iter() + .any(|b| matches!(b, ContentBlock::Image { .. })) + { + let prefix_style = overrides + .as_ref() + .map(|o| o.prefix_agent_name) + .unwrap_or(PrefixStyle::Off); + dispatch_with_blocks( + blocks, + message, + handle, + router, + adapter, + adapter_arc, + ct_str, + thread_id, + output_format, + lifecycle_reactions, + prefix_style, + ) + .await; + return; + } + // No image blocks — fall through to text path below. + } + // For images: download, base64 encode, and send as multimodal content blocks if let ChannelContent::Image { ref url, @@ -909,6 +1022,7 @@ async fn dispatch_message( ChannelContent::File { ref url, ref filename, + .. } => { format!("[User sent a file ({filename}): {url}]") } @@ -924,6 +1038,37 @@ async fn dispatch_message( ChannelContent::FileData { ref filename, .. 
} => { format!("[User sent a local file: {filename}]") } + ChannelContent::Multipart(parts) => parts + .iter() + .map(|p| match p { + ChannelContent::Text(t) => t.clone(), + ChannelContent::Image { url, caption } => match caption { + Some(c) => format!("[User sent a photo: {url}]\nCaption: {c}"), + None => format!("[User sent a photo: {url}]"), + }, + ChannelContent::File { url, filename, .. } => { + format!("[User sent a file ({filename}): {url}]") + } + ChannelContent::Voice { + url, + duration_seconds, + } => format!("[User sent a voice message ({duration_seconds}s): {url}]"), + ChannelContent::Location { lat, lon } => { + format!("[User shared location: {lat}, {lon}]") + } + ChannelContent::FileData { filename, .. } => { + format!("[User sent a local file: {filename}]") + } + ChannelContent::Command { name, args } => { + format!("/{name} {}", args.join(" ")) + } + // Nesting is rejected by adapters; emit empty so the join + // doesn't insert spurious separators. + ChannelContent::Multipart(_) => String::new(), + }) + .filter(|s| !s.is_empty()) + .collect::>() + .join("\n"), }; // Check if it's a slash command embedded in text (e.g. "/agents") @@ -1372,6 +1517,10 @@ fn media_type_from_url(url: &str) -> String { /// Download an image from a URL and build content blocks for multimodal LLM input. /// +/// Accepts both `http(s)://` URLs (fetched via reqwest) and `file://` URLs +/// (read from local disk — used by the channel inbox materialization path so +/// agents see a stable local path even after a Discord CDN URL has expired). +/// /// Returns a `Vec` containing an image block (base64-encoded) and /// optionally a text block for the caption. If the download fails, returns a /// text-only block describing the failure. 
@@ -1381,38 +1530,79 @@ async fn download_image_to_blocks(url: &str, caption: Option<&str>) -> Vec r, - Err(e) => { - warn!("Failed to download image from channel: {e}"); - return vec![ContentBlock::Text { - text: format!("[Image download failed: {e}]"), - provider_metadata: None, - }]; - } - }; + // Branch on URL scheme: file:// reads from local disk, everything else + // goes through HTTP. We unify both paths into (bytes, header_type) before + // the size/magic-byte logic below. + let (bytes, header_type): (Vec, Option) = + if let Some(path) = url.strip_prefix("file://") { + // file:// — local read. No content-type header to honor; magic-byte + // sniffing and URL extension fallback do all the work. We don't + // percent-decode: the inbox writer controls filenames and avoids + // characters that would need encoding. + match tokio::fs::read(path).await { + Ok(b) => (b, None), + Err(e) => { + warn!("Failed to read image from local path {path}: {e}"); + return vec![ContentBlock::Text { + text: format!("[Image read failed: {e}]"), + provider_metadata: None, + }]; + } + } + } else { + // Build the client with transparent decompression DISABLED. Discord's + // CDN edges occasionally advertise `content-encoding: gzip` (or br) + // on PNG/JPEG passthroughs while the body is the raw, uncompressed + // image bytes. With the default reqwest client (gzip/deflate/brotli + // features enabled at the workspace level), this causes the + // decompression layer to choke on the image header and reqwest + // returns "error decoding response body" only on `bytes().await`, + // not on `send()`. Forcing identity encoding sidesteps the whole + // class of CDN content-encoding-flapping bugs. We also set a UA + // (some CDNs 403 clients without one) and a 30s timeout aligned + // with the upstream 5 MB cap. 
+ let client = reqwest::Client::builder() + .no_gzip() + .no_deflate() + .no_brotli() + .user_agent("openfang/0.1 (+https://openfang.ai)") + .timeout(std::time::Duration::from_secs(30)) + .build() + .unwrap_or_else(|_| reqwest::Client::new()); + let resp = match client.get(url).send().await { + Ok(r) => r, + Err(e) => { + warn!("Failed to download image from channel: {e}"); + return vec![ContentBlock::Text { + text: format!("[Image download failed: {e}]"), + provider_metadata: None, + }]; + } + }; - // Detect media type from Content-Type header — but only trust it if it's - // actually an image/* type. Many APIs (Telegram, S3 pre-signed URLs) return - // `application/octet-stream` for all files, which breaks vision. - let header_type = resp - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .map(|ct| ct.split(';').next().unwrap_or(ct).trim().to_string()) - .filter(|ct| ct.starts_with("image/")); - - let bytes = match resp.bytes().await { - Ok(b) => b, - Err(e) => { - warn!("Failed to read image bytes: {e}"); - return vec![ContentBlock::Text { - text: format!("[Image read failed: {e}]"), - provider_metadata: None, - }]; - } - }; + // Detect media type from Content-Type header — but only trust it if + // it's actually an image/* type. Many APIs (Telegram, S3 pre-signed + // URLs) return `application/octet-stream` for all files, which + // breaks vision. + let header_type = resp + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .map(|ct| ct.split(';').next().unwrap_or(ct).trim().to_string()) + .filter(|ct| ct.starts_with("image/")); + + let bytes = match resp.bytes().await { + Ok(b) => b, + Err(e) => { + warn!("Failed to read image bytes: {e}"); + return vec![ContentBlock::Text { + text: format!("[Image read failed: {e}]"), + provider_metadata: None, + }]; + } + }; + (bytes.to_vec(), header_type) + }; // Three-tier media type detection: // 1. 
Trusted Content-Type header (only if image/*) @@ -1453,7 +1643,14 @@ async fn download_image_to_blocks(url: &str, caption: Option<&str>) -> Vec assert_eq!(text, "hello"), + other => panic!("expected text caption block, got {other:?}"), + } + match &blocks[1] { + ContentBlock::Image { + source_url, + media_type, + .. + } => { + assert_eq!( + source_url.as_deref(), + Some(url.as_str()), + "source_url must round-trip the fetched URL" + ); + assert_eq!(media_type, "image/png"); + } + other => panic!("expected image block, got {other:?}"), + } + } } diff --git a/crates/openfang-channels/src/discord.rs b/crates/openfang-channels/src/discord.rs index 7d43e53f4..8501743b4 100644 --- a/crates/openfang-channels/src/discord.rs +++ b/crates/openfang-channels/src/discord.rs @@ -9,6 +9,7 @@ use crate::types::{ use async_trait::async_trait; use futures::{SinkExt, Stream, StreamExt}; use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -16,12 +17,54 @@ use std::time::Duration; use tokio::sync::{mpsc, watch, Mutex, RwLock}; use tokio::task::JoinHandle; use tracing::{debug, error, info, warn}; +use url::Url; use zeroize::Zeroizing; const DISCORD_API_BASE: &str = "https://discord.com/api/v10"; const MAX_BACKOFF: Duration = Duration::from_secs(60); const INITIAL_BACKOFF: Duration = Duration::from_secs(1); const DISCORD_MSG_LIMIT: usize = 2000; +/// Multipart field name Discord requires for the first attachment payload. +/// Test-only fixture: the production call site uses `format!("files[{i}]")` +/// directly; this constant pins the wire format in +/// `test_attachment_field_name_pinned` so a `file[0]` typo or future +/// refactor can't slip past review. +#[cfg(test)] +const ATTACHMENT_FIELD_NAME: &str = "files[0]"; +/// Floor on the rate-limit retry delay. Discord occasionally returns +/// `retry_after: 0` (or a missing header), which would busy-loop the retry. 
+const RETRY_AFTER_FLOOR_SECS: f64 = 0.05; +/// Cap on the rate-limit retry delay so a misbehaving response can't park us +/// for a long time on a one-shot retry. +const RETRY_AFTER_CEIL_SECS: f64 = 30.0; +/// Per-request timeout for outbound URL fetches (File/Image arms). Matches the +/// adapter's existing 15s budget for other REST calls so a slow remote can't +/// stall the send pipeline. +const URL_FETCH_TIMEOUT: Duration = Duration::from_secs(15); +/// Hard cap on the size of a fetched URL body, both via Content-Length pre-flight +/// and via streamed accumulation. Discord itself caps non-Nitro uploads at +/// 25 MiB; matching that here means we reject before paying for bytes Discord +/// would refuse anyway. +const URL_FETCH_MAX_BYTES: usize = 25 * 1024 * 1024; +/// Maximum number of attachments per multipart POST. Discord's REST API caps +/// `files[N]` at 10 per request; the multipart helper relies on the caller +/// having pre-chunked. +const ATTACHMENTS_PER_CHUNK: usize = 10; +/// Aggregate byte cap on a single multipart POST's attachment payload. +/// Discord caps non-Nitro requests at 25 MiB total (multipart envelope + +/// payload_json + every `files[i]`); 24 MiB leaves ~1 MiB of headroom for the +/// envelope so an over-budget attempt can't silently 413. Files larger than +/// this cap end up in their own single-attachment chunk; Discord still +/// rejects them, but the caller sees the same error they did before this +/// cap landed. +const CHUNK_TOTAL_CAP_BYTES: usize = 24 * 1024 * 1024; +/// Cap on the number of HTTP redirects we'll follow on URL fetches. Each hop +/// is independently SSRF-rechecked at the literal-IP level. +const URL_FETCH_MAX_REDIRECTS: usize = 3; +/// User-Agent we identify as on outbound URL fetches. Pinned so a future test +/// can assert on it; remote operators looking at access logs see a single +/// stable identifier instead of reqwest's default. 
+const URL_FETCH_USER_AGENT: &str = concat!("openfang-channels-discord/", env!("CARGO_PKG_VERSION")); /// Discord Gateway opcodes. mod opcode { @@ -47,6 +90,236 @@ fn build_heartbeat_payload(last_sequence: Option) -> serde_json::Value { }) } +/// Format a URL for log/error messages with the query string and fragment +/// stripped. Discord CDN URLs carry HMAC-style query params (`ex`, `is`, `hm`, +/// `__cf_bm`) that grant time-limited access; logging them at warn/error level +/// would leak credential-equivalent material into operator log aggregators. +fn redact_url(u: &Url) -> String { + format!( + "{}://{}{}", + u.scheme(), + u.host_str().unwrap_or(""), + u.path() + ) +} + +/// Returns true if the IP address is one we refuse to fetch from to prevent +/// SSRF: loopback, RFC1918 / link-local / unique-local, multicast, unspecified, +/// or the literal cloud-metadata IP. IPv4-mapped IPv6 addresses are unwrapped +/// to their underlying v4 and re-checked. +fn is_blocked_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => is_blocked_v4(v4), + IpAddr::V6(v6) => { + // Strip IPv4-mapped wrappers (::ffff:a.b.c.d) before re-checking. + // `Ipv6Addr::to_ipv4_mapped` is stable as of 1.63 but we use the + // older `to_ipv4` which also covers IPv4-compatible ::a.b.c.d. + if let Some(v4) = v6.to_ipv4() { + if is_blocked_v4(v4) { + return true; + } + } + if v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() { + return true; + } + // Link-local fe80::/10. `Ipv6Addr::is_unicast_link_local` is + // unstable; check the prefix manually. + let seg0 = v6.segments()[0]; + if (seg0 & 0xffc0) == 0xfe80 { + return true; + } + // Unique local fc00::/7. 
+            if (seg0 & 0xfe00) == 0xfc00 {
+                return true;
+            }
+            false
+        }
+    }
+}
+
+/// Returns true if `v4` is an IPv4 address we refuse to fetch from.
+///
+/// Covers everything the stdlib predicates classify as loopback, private
+/// (RFC 1918), link-local, unspecified, multicast or broadcast, plus the
+/// special-purpose ranges the stdlib's stable predicates do not cover:
+///
+/// - `0.0.0.0/8` ("this network"; only `0.0.0.0` itself is `is_unspecified`)
+/// - `100.64.0.0/10` (carrier-grade NAT)
+/// - `192.0.0.0/24` (IETF protocol assignments)
+/// - `198.18.0.0/15` (benchmarking)
+/// - `240.0.0.0/4` (reserved; only `255.255.255.255` is `is_broadcast`)
+///
+/// Also used by [`is_blocked_ip`] for IPv4-mapped / IPv4-compatible IPv6
+/// addresses after they are unwrapped to their underlying v4.
+fn is_blocked_v4(v4: Ipv4Addr) -> bool {
+    if v4.is_loopback()
+        || v4.is_private()
+        || v4.is_link_local()
+        || v4.is_unspecified()
+        || v4.is_multicast()
+        || v4.is_broadcast()
+    {
+        return true;
+    }
+    // Destructure once instead of calling `octets()` per range test.
+    let [a, b, c, d] = v4.octets();
+    // 169.254.169.254 is technically link-local (covered above) but make the
+    // intent explicit — a future stdlib change to `is_link_local` shouldn't
+    // silently re-open cloud-metadata exfiltration.
+    if [a, b, c, d] == [169, 254, 169, 254] {
+        return true;
+    }
+    // "This network" 0.0.0.0/8. `is_unspecified` only matches 0.0.0.0 itself,
+    // but the rest of the block is equally non-routable.
+    if a == 0 {
+        return true;
+    }
+    // Carrier-grade NAT 100.64.0.0/10. Not in `Ipv4Addr::is_private` but is
+    // commonly internal.
+    if a == 100 && (b & 0xc0) == 0x40 {
+        return true;
+    }
+    // IETF protocol assignments 192.0.0.0/24.
+    if a == 192 && b == 0 && c == 0 {
+        return true;
+    }
+    // Benchmarking 198.18.0.0/15 (198.18.x.x and 198.19.x.x).
+    if a == 198 && (b & 0xfe) == 18 {
+        return true;
+    }
+    // Reserved 240.0.0.0/4. `is_broadcast` above only matches
+    // 255.255.255.255; the remainder of the block is never a legitimate
+    // public fetch target.
+    a >= 240
+}
+
+/// Synchronous SSRF check on a parsed URL: scheme allowlist + literal-IP host
+/// range check. Hostname (DNS) resolution is the caller's responsibility (see
+/// [`resolve_and_check_host`]); this function intentionally avoids DNS so it
+/// can run inside the sync `redirect::Policy::custom` callback on every hop.
+///
+/// Threat model note: the redirect callback can only do this literal-IP
+/// recheck, not DNS, because reqwest's redirect policy is sync. A malicious
+/// DNS server that returns a public IP at first lookup and a private IP on a
+/// second lookup is *not* in the threat model — the threat is a malicious URL
+/// the agent was tricked into emitting. Literal-IP redirects are still
+/// blocked at every hop, which closes the most obvious bypass.
+fn check_url_scheme_and_literal_ip(u: &Url) -> Result<(), String> {
+    // Scheme allowlist: anything other than plain HTTP(S) is refused.
+    match u.scheme() {
+        "http" | "https" => {}
+        other => {
+            return Err(format!(
+                "URL fetch refused: scheme {other:?} not allowed (need http/https) for {}",
+                redact_url(u)
+            ));
+        }
+    }
+    if let Some(host) = u.host() {
+        match host {
+            url::Host::Ipv4(v4) => {
+                if is_blocked_v4(v4) {
+                    return Err(format!(
+                        "URL fetch refused: blocked IPv4 host for {}",
+                        redact_url(u)
+                    ));
+                }
+            }
+            url::Host::Ipv6(v6) => {
+                if is_blocked_ip(IpAddr::V6(v6)) {
+                    return Err(format!(
+                        "URL fetch refused: blocked IPv6 host for {}",
+                        redact_url(u)
+                    ));
+                }
+            }
+            url::Host::Domain(name) => {
+                // No DNS is available here, but `localhost` and `*.localhost`
+                // are loopback by definition, so they can be refused without
+                // a lookup. This matters on redirect hops, where this
+                // function is the only check that runs: without it a public
+                // server could 302 to `http://localhost:<port>/`. The message
+                // is the same one the DNS preflight produces for a name that
+                // resolves to a blocked address.
+                let name = name.trim_end_matches('.').to_ascii_lowercase();
+                if name == "localhost" || name.ends_with(".localhost") {
+                    return Err(format!(
+                        "URL fetch refused: host resolved to blocked address for {}",
+                        redact_url(u)
+                    ));
+                }
+            }
+        }
+    } else {
+        return Err(format!(
+            "URL fetch refused: missing host for {}",
+            redact_url(u)
+        ));
+    }
+    Ok(())
+}
+
+/// Typed intermediate produced by the single classification pass over a
+/// `ChannelContent::Multipart`'s blocks. Carrying enough information per
+/// variant lets the subsequent resolve step operate on this enum alone
+/// without a second walk over the original `Vec<ChannelContent>`.
+enum AttachmentSource {
+    /// Already fully-resolved attachment (came from a `FileData` block).
+    Resolved {
+        bytes: bytes::Bytes,
+        filename: String,
+        mime: String,
+    },
+    /// URL-backed image; `Fetcher` resolves the bytes, then
+    /// `resolve_image_mime` / `resolve_image_filename` derive the metadata.
+    UrlImage { url: String },
+    /// URL-backed file with caller-supplied filename/mime hints; `Fetcher`
+    /// resolves the bytes, then `resolve_file_mime` / `resolve_file_filename`
+    /// reconcile against the response Content-Type.
+    UrlFile {
+        url: String,
+        filename: String,
+        mime: Option<String>,
+    },
+}
+
+/// Abstraction over "fetch a URL into memory" so production and tests share
+/// the same wire-level HTTP code while differing only in whether SSRF
+/// validation runs first.
+///
+/// - [`ProductionFetcher`] performs scheme + DNS-resolved IP checks via
+///   [`resolve_and_check_host`] before issuing the request.
+/// - [`PermissiveFetcher`] (test-only) skips the SSRF preflight so tests +/// can hit `127.0.0.1` fixture servers via the same wire path. +/// +/// Returns the body as `Bytes` plus the response's `Content-Type` with any +/// MIME parameters (e.g. `; charset=utf-8`) stripped. +#[async_trait] +trait Fetcher: Send + Sync { + async fn fetch( + &self, + url: &str, + ) -> Result<(bytes::Bytes, Option), Box>; +} + +/// Production fetcher: parses the URL, runs the SSRF preflight, then performs +/// the HTTP fetch via [`do_http_fetch`]. +struct ProductionFetcher; + +#[async_trait] +impl Fetcher for ProductionFetcher { + async fn fetch( + &self, + url: &str, + ) -> Result<(bytes::Bytes, Option), Box> { + let parsed = Url::parse(url).map_err(|e| format!("URL fetch refused: parse error: {e}"))?; + resolve_and_check_host(&parsed).await?; + do_http_fetch(&parsed).await + } +} + +/// Test-only fetcher that performs the same wire fetch but skips the SSRF +/// preflight, so tests can point `Image{url}` / `File{url}` blocks at local +/// stub servers without bypassing the production code path. +#[cfg(test)] +struct PermissiveFetcher; + +#[cfg(test)] +#[async_trait] +impl Fetcher for PermissiveFetcher { + async fn fetch( + &self, + url: &str, + ) -> Result<(bytes::Bytes, Option), Box> { + let parsed = Url::parse(url).map_err(|e| format!("URL fetch refused: parse error: {e}"))?; + do_http_fetch(&parsed).await + } +} + +/// Resolve the URL's host (DNS if hostname; identity if IP literal) and reject +/// if any resolved address fails the SSRF check. Performs both the scheme +/// check and the per-IP check. +async fn resolve_and_check_host(u: &Url) -> Result<(), String> { + check_url_scheme_and_literal_ip(u)?; + let host = match u.host() { + Some(url::Host::Domain(d)) => d.to_string(), + // IP literals already passed the literal-IP check above; no DNS needed. 
+ Some(_) => return Ok(()), + None => { + return Err(format!( + "URL fetch refused: missing host for {}", + redact_url(u) + )) + } + }; + let port = u.port_or_known_default().unwrap_or(0); + let hostport = format!("{host}:{port}"); + let addrs = tokio::net::lookup_host(hostport.as_str()) + .await + .map_err(|_| format!("URL fetch refused: DNS lookup failed for {}", redact_url(u)))?; + for sa in addrs { + if is_blocked_ip(sa.ip()) { + return Err(format!( + "URL fetch refused: host resolved to blocked address for {}", + redact_url(u) + )); + } + } + Ok(()) +} + /// Discord Gateway adapter using WebSocket. pub struct DiscordAdapter { /// SECURITY: Bot token is zeroized on drop to prevent memory disclosure. @@ -64,6 +337,15 @@ pub struct DiscordAdapter { session_id: Arc>>, /// Resume gateway URL. resume_gateway_url: Arc>>, + /// Override for the Discord REST API base URL. `None` in production (uses + /// `DISCORD_API_BASE`). Set by tests that spin up a local stub server. + #[cfg(test)] + api_base_override: Option, + /// Resolver for outbound URL fetches (`Image{url}` / `File{url}`). In + /// production this is [`ProductionFetcher`] which runs the SSRF preflight; + /// tests can swap in [`PermissiveFetcher`] to point at local stubs without + /// bypassing the wire path. + fetcher: Arc, } impl DiscordAdapter { @@ -87,9 +369,27 @@ impl DiscordAdapter { bot_user_id: Arc::new(RwLock::new(None)), session_id: Arc::new(RwLock::new(None)), resume_gateway_url: Arc::new(RwLock::new(None)), + #[cfg(test)] + api_base_override: None, + fetcher: Arc::new(ProductionFetcher), } } + /// Returns the Discord REST API base URL, honouring the test override when + /// present. In production this is always `DISCORD_API_BASE`. 
+ #[cfg(test)] + fn api_base(&self) -> &str { + self.api_base_override + .as_deref() + .unwrap_or(DISCORD_API_BASE) + } + + #[cfg(not(test))] + #[inline(always)] + fn api_base(&self) -> &str { + DISCORD_API_BASE + } + /// Get the WebSocket gateway URL from the Discord API. async fn get_gateway_url(&self) -> Result> { let url = format!("{DISCORD_API_BASE}/gateway/bot"); @@ -115,7 +415,7 @@ impl DiscordAdapter { channel_id: &str, text: &str, ) -> Result<(), Box> { - let url = format!("{DISCORD_API_BASE}/channels/{channel_id}/messages"); + let url = format!("{}/channels/{channel_id}/messages", self.api_base()); let chunks = split_message(text, DISCORD_MSG_LIMIT); for chunk in chunks { @@ -136,9 +436,178 @@ impl DiscordAdapter { Ok(()) } + /// Send a file attachment to a Discord channel via REST multipart upload. + /// + /// Thin wrapper around `api_send_attachments` for the common single-file + /// case. `Bytes::clone` is a refcount bump so passing through is free. + async fn api_send_attachment( + &self, + channel_id: &str, + data: impl Into, + filename: &str, + mime_type: &str, + caption: Option<&str>, + ) -> Result<(), Box> { + self.api_send_attachments( + channel_id, + vec![(data.into(), filename.to_string(), mime_type.to_string())], + caption, + ) + .await + } + + /// Send one or more file attachments in a single multipart POST. + /// + /// Builds a `multipart/form-data` request with `payload_json` plus + /// `files[0]`…`files[N-1]` parts (N ≤ 10, per Discord's limit). The + /// caller is responsible for chunking larger batches. + /// + /// On HTTP 429 we honor `Retry-After` once before giving up. Higher-tier + /// rate-limit handling can land later if needed. 
+ async fn api_send_attachments( + &self, + channel_id: &str, + attachments: Vec<(bytes::Bytes, String, String)>, + caption: Option<&str>, + ) -> Result<(), Box> { + let url = format!("{}/channels/{channel_id}/messages", self.api_base()); + + // Discord caps message content at DISCORD_MSG_LIMIT chars; truncate + // explicitly so a long caption doesn't silently 400. + let payload_json = build_attachment_payload_json(caption); + + // Pre-compute lengths. `Bytes::clone` is a refcount bump so the + // retry-path form rebuild is allocation-free for the file data. + let parts_meta: Vec<(bytes::Bytes, u64, String, String)> = attachments + .into_iter() + .map(|(b, name, mime)| { + let len = b.len() as u64; + (b, len, name, mime) + }) + .collect(); + + let build_form = || -> Result> { + let mut form = + reqwest::multipart::Form::new().text("payload_json", payload_json.clone()); + for (i, (bytes, body_len, filename, mime_type)) in parts_meta.iter().enumerate() { + let body = reqwest::Body::from(bytes.clone()); + let file_part = reqwest::multipart::Part::stream_with_length(body, *body_len) + .file_name(filename.clone()) + .mime_str(mime_type)?; + // Discord requires field names `files[0]`, `files[1]`, etc. + // The wire format is pinned by `test_attachment_field_name_pinned` + // (asserts `format!("files[{}]", 0) == ATTACHMENT_FIELD_NAME`). + let field_name = format!("files[{i}]"); + form = form.part(field_name, file_part); + } + Ok(form) + }; + + let mut attempts = 0u8; + loop { + attempts += 1; + let form = build_form()?; + let resp = self + .client + .post(&url) + .header("Authorization", format!("Bot {}", self.token.as_str())) + .multipart(form) + .send() + .await?; + + let status = resp.status(); + if status.is_success() { + return Ok(()); + } + + // Honor Retry-After once on 429. Discord puts the canonical + // `retry_after` in the JSON body; the HTTP header is a fallback + // (and is sometimes absent on per-route limits). 
+ if status == reqwest::StatusCode::TOO_MANY_REQUESTS && attempts == 1 { + let header_secs = resp + .headers() + .get(reqwest::header::RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let body_text = resp.text().await.unwrap_or_default(); + let body_secs = serde_json::from_str::(&body_text) + .ok() + .and_then(|v| v.get("retry_after").and_then(|r| r.as_f64())); + let retry_after_secs = body_secs + .or(header_secs) + .unwrap_or(1.0) + .clamp(RETRY_AFTER_FLOOR_SECS, RETRY_AFTER_CEIL_SECS); + warn!( + "Discord sendAttachments rate-limited; retrying after {retry_after_secs:.2}s" + ); + tokio::time::sleep(Duration::from_millis((retry_after_secs * 1000.0) as u64)).await; + continue; + } + + let body_text = resp.text().await.unwrap_or_default(); + warn!("Discord sendAttachments failed ({status}): {body_text}"); + return Err(format!("Discord sendAttachments failed ({status}): {body_text}").into()); + } + } + + /// Resolve a single [`AttachmentSource`] into the + /// `(bytes, filename, mime)` tuple consumed by `api_send_attachments`. + /// + /// `Resolved` returns immediately; URL variants delegate to `Fetcher` and + /// then run their respective resolver chains. Errors are wrapped with the + /// `"Multipart fetch failed for {url}: …"` prefix the existing tests pin. + /// + /// The error type is `Box` (not the looser + /// `Box` used elsewhere) so the resulting future is `Send` — + /// required by `try_join_all` in the multipart resolve step. The + /// conversion to `Box` happens implicitly at the call site + /// via `?`. + async fn resolve_attachment_source( + &self, + source: AttachmentSource, + ) -> Result<(bytes::Bytes, String, String), Box> { + match source { + AttachmentSource::Resolved { + bytes, + filename, + mime, + } => Ok((bytes, filename, mime)), + AttachmentSource::UrlImage { url } => { + // `Fetcher::fetch` returns `Box` (no Send); + // stringify so the error becomes `Send + Sync` for `?`. 
+ let (bytes, response_ct) = self + .fetcher + .fetch(&url) + .await + .map_err(|e| format!("Multipart fetch failed for {url}: {e}"))?; + let resolved_mime = resolve_image_mime(response_ct.as_deref(), &url); + let resolved_filename = resolve_image_filename(&url, &resolved_mime); + Ok((bytes, resolved_filename, resolved_mime)) + } + AttachmentSource::UrlFile { + url, + filename, + mime, + } => { + let (bytes, response_ct) = self + .fetcher + .fetch(&url) + .await + .map_err(|e| format!("Multipart fetch failed for {url}: {e}"))?; + let resolved_filename = resolve_file_filename(Some(filename.as_str()), &url); + let resolved_mime = resolve_file_mime( + mime.as_deref(), + response_ct.as_deref(), + &resolved_filename, + ); + Ok((bytes, resolved_filename, resolved_mime)) + } + } + } + /// Send typing indicator to a Discord channel. async fn api_send_typing(&self, channel_id: &str) -> Result<(), Box> { - let url = format!("{DISCORD_API_BASE}/channels/{channel_id}/typing"); + let url = format!("{}/channels/{channel_id}/typing", self.api_base()); let _ = self .client .post(&url) @@ -149,6 +618,122 @@ impl DiscordAdapter { } } +/// Wire-level HTTP fetch shared by [`ProductionFetcher`] and +/// [`PermissiveFetcher`]. Assumes the caller has already done any SSRF +/// preflight on `parsed`. Performs: +/// +/// 1. Per-request reqwest client with a redirect policy that caps at +/// [`URL_FETCH_MAX_REDIRECTS`] hops and re-applies the literal-IP SSRF +/// check on every hop's URL. +/// 2. Two-stage size enforcement: Content-Length pre-flight, then streaming +/// chunk accumulation that aborts mid-stream on overrun. +/// +/// Errors are scrubbed via [`redact_url`] so Discord CDN HMAC params don't +/// land in operator logs. +async fn do_http_fetch( + parsed: &Url, +) -> Result<(bytes::Bytes, Option), Box> { + // Per-request client with a custom redirect policy. We cannot reuse + // a shared client because its redirect policy is fixed at build time. 
+ let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { + if attempt.previous().len() >= URL_FETCH_MAX_REDIRECTS { + return attempt.error(format!("redirect cap ({URL_FETCH_MAX_REDIRECTS}) exceeded")); + } + // Sync context: we can only do the literal-IP recheck here; DNS + // requires async. The original hostname was DNS-checked before + // the request started, so the only new bypass to close at this + // layer is a redirect to a literal private IP. + if let Err(e) = check_url_scheme_and_literal_ip(attempt.url()) { + return attempt.error(e); + } + attempt.follow() + }); + let client = reqwest::Client::builder() + .redirect(redirect_policy) + .user_agent(URL_FETCH_USER_AGENT) + .timeout(URL_FETCH_TIMEOUT) + .build()?; + + let resp = client.get(parsed.as_str()).send().await.map_err(|e| { + // reqwest's Display impl for Error includes the URL it was + // fetching (with query string). Replace it with the redacted + // form to keep CDN HMAC params out of error logs. + // + // For redirect-policy errors, reqwest's outer Display is the + // generic "error following redirect"; the actual cause (e.g. + // "URL fetch refused: blocked IPv4 host for ...") lives in the + // source chain. Walk it so the operator sees *why* we refused. + let stripped = e.without_url(); + let mut msg = stripped.to_string(); + let mut src: Option<&dyn std::error::Error> = std::error::Error::source(&stripped); + while let Some(s) = src { + use std::fmt::Write as _; + let _ = write!(msg, ": {s}"); + src = s.source(); + } + format!("URL fetch failed for {}: {msg}", redact_url(parsed)) + })?; + + let status = resp.status(); + if !status.is_success() { + // Read up to 512B of the body for diagnostics; ignore errors. 
+ let snippet: String = resp + .text() + .await + .unwrap_or_default() + .chars() + .take(512) + .collect(); + return Err(format!( + "URL fetch failed ({status}) for {}: {snippet}", + redact_url(parsed) + ) + .into()); + } + + // Pre-flight: trust Content-Length when present so we can fail fast + // without buffering 26 MiB before erroring. + let content_length = resp.content_length(); + if let Some(len) = content_length { + if len as usize > URL_FETCH_MAX_BYTES { + return Err(format!( + "URL fetch refused: Content-Length {len} exceeds cap {URL_FETCH_MAX_BYTES} for {}", + redact_url(parsed) + ) + .into()); + } + } + + let content_type = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .map(strip_mime_params) + .filter(|s| !s.is_empty()); + + // Pre-size the buffer: if we have a trustworthy Content-Length, use it + // (clamped to the cap); otherwise start at 64 KiB so the happy path + // doesn't pay ~24 doublings on a 25 MiB body. + let initial_cap = std::cmp::min( + content_length.unwrap_or(64 * 1024) as usize, + URL_FETCH_MAX_BYTES, + ); + let mut buf = bytes::BytesMut::with_capacity(initial_cap); + let mut resp = resp; + while let Some(chunk) = resp.chunk().await? 
{ + if buf.len() + chunk.len() > URL_FETCH_MAX_BYTES { + return Err(format!( + "URL fetch refused: streamed body exceeds cap {URL_FETCH_MAX_BYTES} for {}", + redact_url(parsed) + ) + .into()); + } + buf.extend_from_slice(&chunk); + } + + Ok((buf.freeze(), content_type)) +} + #[async_trait] impl ChannelAdapter for DiscordAdapter { fn name(&self) -> &str { @@ -520,6 +1105,177 @@ impl ChannelAdapter for DiscordAdapter { ChannelContent::Text(text) => { self.api_send_message(channel_id, &text).await?; } + ChannelContent::FileData { + data, + filename, + mime_type, + } => { + self.api_send_attachment(channel_id, data, &filename, &mime_type, None) + .await?; + } + ChannelContent::File { + url, + filename, + mime, + size: _, + } => { + // Fetch then route through the existing multipart helper. + // `Fetcher::fetch` enforces SSRF + 15s timeout + 25 MiB cap. + let (bytes, response_ct) = self.fetcher.fetch(&url).await?; + let resolved_filename = resolve_file_filename(Some(filename.as_str()), &url); + let resolved_mime = + resolve_file_mime(mime.as_deref(), response_ct.as_deref(), &resolved_filename); + // No caption on a bare File; captions travel via Multipart([Text, File]). + self.api_send_attachment( + channel_id, + bytes, + &resolved_filename, + &resolved_mime, + None, + ) + .await?; + } + ChannelContent::Image { url, caption } => { + let (bytes, response_ct) = self.fetcher.fetch(&url).await?; + let resolved_mime = resolve_image_mime(response_ct.as_deref(), &url); + let resolved_filename = resolve_image_filename(&url, &resolved_mime); + let caption_ref = caption.as_deref().filter(|s| !s.is_empty()); + self.api_send_attachment( + channel_id, + bytes, + &resolved_filename, + &resolved_mime, + caption_ref, + ) + .await?; + } + ChannelContent::Multipart(parts) => { + // Single pass over `parts`: bucket each block into caption + // pieces, an `AttachmentSource` for later resolution, or a + // logged-unknown name. 
The two-pass classify/resolve split is + // collapsed by carrying enough info on `AttachmentSource` for + // the resolve step to operate on the typed intermediate alone. + let mut caption_pieces: Vec = Vec::new(); + let mut sources: Vec = Vec::with_capacity(parts.len()); + let mut unknown_names: Vec<&str> = Vec::new(); + + for part in parts { + match part { + ChannelContent::Text(t) => caption_pieces.push(t), + ChannelContent::FileData { + data, + filename, + mime_type, + } => sources.push(AttachmentSource::Resolved { + bytes: bytes::Bytes::from(data), + filename, + mime: mime_type, + }), + // Per-Image inner captions are ignored inside Multipart; + // the outer caption_pieces form the single caption. + ChannelContent::Image { url, caption: _ } => { + sources.push(AttachmentSource::UrlImage { url }) + } + ChannelContent::File { + url, + filename, + mime, + size: _, + } => sources.push(AttachmentSource::UrlFile { + url, + filename, + mime, + }), + ChannelContent::Voice { .. } => unknown_names.push("Voice"), + ChannelContent::Location { .. } => unknown_names.push("Location"), + ChannelContent::Command { .. } => unknown_names.push("Command"), + ChannelContent::Multipart(_) => unknown_names.push("Multipart"), + } + } + + if !unknown_names.is_empty() { + warn!( + "Discord Multipart: skipping unknown/unsupported nested variant(s): {:?}", + unknown_names + ); + } + + // Build the single caption string from all Text blocks. + let caption_str = caption_pieces.join("\n\n"); + let caption_str = caption_str.trim(); + let caption_opt: Option<&str> = if caption_str.is_empty() { + None + } else { + Some(caption_str) + }; + + // Resolve sources concurrently. `try_join_all` preserves + // input order in the output Vec (so `files[i]` still lines + // up with the source's original position) and fails fast on + // the first error, cancelling the rest — matching the spec + // and the previous serial behavior. 
For an N-URL Multipart + // this drops latency from sum-of-RTTs to max-of-RTT. + let attachments_resolved: Vec<(bytes::Bytes, String, String)> = + futures::future::try_join_all( + sources + .into_iter() + .map(|s| self.resolve_attachment_source(s)), + ) + .await + // Widen `Box` (needed so the + // resolve future is `Send` for `try_join_all`) to the + // looser `Box` returned by `send`. Unsizing + // a trait object by removing auto traits is allowed but + // not exposed via `From`, so we coerce explicitly. + .map_err(|e| -> Box { e })?; + + if attachments_resolved.is_empty() { + // Caption-only Multipart (all blocks were Text/unknown). + if let Some(cap) = caption_opt { + self.api_send_message(channel_id, cap).await?; + } else { + warn!("Discord Multipart: all blocks empty or unknown, nothing to send"); + } + return Ok(()); + } + + // Chunk by both count (≤ ATTACHMENTS_PER_CHUNK) and aggregate + // bytes (≤ CHUNK_TOTAL_CAP_BYTES) so a 10×3 MiB Multipart + // doesn't 413 on Discord's per-request size limit. Order is + // preserved; oversized single attachments get their own + // chunk (they'll still be rejected by Discord, but with the + // same error path as before this cap existed). + let chunks: Vec> = + chunk_attachments(attachments_resolved); + let total_chunks = chunks.len(); + for (i, chunk) in chunks.into_iter().enumerate() { + let chunk_caption = if i == 0 { caption_opt } else { None }; + if let Err(e) = self + .api_send_attachments(channel_id, chunk, chunk_caption) + .await + { + if i > 0 { + // Standalone WARN with structured fields so an + // operator grepping for "why are some files + // showing and some not?" can find this in one + // search instead of parsing prose. The failed + // chunk index is recoverable as `chunks_sent` + // (the count of chunks that succeeded before + // this one). 
+ warn!( + event = "discord_multipart_partial_send", + chunks_sent = i, + chunks_total = total_chunks, + "discord multipart partial send: chunk {}/{} failed after {} chunk(s) already on the wire", + i + 1, + total_chunks, + i + ); + } + return Err(e); + } + } + } _ => { self.api_send_message(channel_id, "(Unsupported content type)") .await?; @@ -538,6 +1294,244 @@ impl ChannelAdapter for DiscordAdapter { } } +/// Maximum byte size for an attachment to be classified as a vision-eligible +/// image. Anthropic's image content blocks are capped at 5 MB; oversize images +/// fall through to `File` so the bridge passes the URL as text instead of +/// attempting an inline image block. +const VISION_IMAGE_MAX_BYTES: u64 = 5 * 1024 * 1024; + +/// Build the `payload_json` body for an outbound attachment request. +/// +/// Discord's `POST /channels/{id}/messages` multipart endpoint expects a +/// `payload_json` part containing the same JSON the JSON-only variant would +/// take. Captions longer than `DISCORD_MSG_LIMIT` chars must be truncated +/// explicitly; otherwise the API responds 400 and silently drops the upload. +/// Greedy-pack attachments into chunks subject to two caps: +/// +/// 1. At most [`ATTACHMENTS_PER_CHUNK`] entries per chunk (Discord's +/// `files[N]` limit). +/// 2. At most [`CHUNK_TOTAL_CAP_BYTES`] aggregate bytes per chunk (Discord's +/// ~25 MiB request size limit, with headroom for multipart overhead). +/// +/// Order of inputs is preserved across the output. If a single attachment +/// alone exceeds the byte cap, it lands in its own chunk and is forwarded +/// untouched — Discord will reject the request, but that mirrors the +/// pre-existing behavior where the per-file cap was the only gate. 
+fn chunk_attachments( + attachments: Vec<(bytes::Bytes, String, String)>, +) -> Vec> { + let mut chunks: Vec> = Vec::new(); + let mut current: Vec<(bytes::Bytes, String, String)> = Vec::new(); + let mut current_bytes: usize = 0; + + for item in attachments { + let item_len = item.0.len(); + // Start a new chunk when adding this item would push us over the + // count or byte cap — but only if the current chunk isn't empty. + // (Empty + oversized item: keep going so we always make progress.) + let would_exceed_count = current.len() >= ATTACHMENTS_PER_CHUNK; + let would_exceed_bytes = + current_bytes.saturating_add(item_len) > CHUNK_TOTAL_CAP_BYTES; + if !current.is_empty() && (would_exceed_count || would_exceed_bytes) { + chunks.push(std::mem::take(&mut current)); + current_bytes = 0; + } + current_bytes = current_bytes.saturating_add(item_len); + current.push(item); + } + if !current.is_empty() { + chunks.push(current); + } + chunks +} + +fn build_attachment_payload_json(caption: Option<&str>) -> String { + match caption { + Some(c) if !c.is_empty() => { + let truncated: String = c.chars().take(DISCORD_MSG_LIMIT).collect(); + serde_json::json!({ "content": truncated }).to_string() + } + _ => serde_json::json!({}).to_string(), + } +} + +/// Strip MIME parameters (e.g. `; charset=utf-8`) so downstream comparisons +/// against canonical types like `image/png` work. Lower-cases and trims so +/// `IMAGE/PNG ; charset=utf-8` and `image/png` both normalize to `image/png`. +fn strip_mime_params(raw: &str) -> String { + raw.split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase() +} + +/// Derive a filename from a URL path: take the segment after the last `/`, +/// drop any query/fragment, percent-decode best-effort. Returns None if the +/// URL has no useful path segment (e.g. `https://host/`). +fn derive_filename_from_url(url: &str) -> Option { + // Strip scheme://host. 
We only care about the path-ish suffix; doing + // this without a real URL parser keeps the helper dep-free and total + // (a malformed URL still gets a best-effort answer). + let after_scheme = url.split_once("://").map(|(_, r)| r).unwrap_or(url); + let path = after_scheme.split_once('/').map(|(_, r)| r).unwrap_or(""); + // Drop query and fragment. + let path = path.split(['?', '#']).next().unwrap_or(""); + let last = path.rsplit('/').next().unwrap_or(""); + if last.is_empty() { + return None; + } + // Percent-decode best-effort; fall back to the raw segment on failure. + let decoded = percent_decode_lossy(last); + if decoded.is_empty() { + None + } else { + Some(decoded) + } +} + +/// Tiny percent-decoder. We don't pull in `percent-encoding` for this — the +/// adapter already avoids new deps and we only need it to prettify Discord +/// CDN paths like `photo%20final.png` → `photo final.png`. +fn percent_decode_lossy(s: &str) -> String { + let bytes = s.as_bytes(); + let mut out = Vec::with_capacity(bytes.len()); + let mut i = 0; + while i < bytes.len() { + if bytes[i] == b'%' && i + 2 < bytes.len() { + let hi = (bytes[i + 1] as char).to_digit(16); + let lo = (bytes[i + 2] as char).to_digit(16); + if let (Some(h), Some(l)) = (hi, lo) { + out.push((h * 16 + l) as u8); + i += 3; + continue; + } + } + out.push(bytes[i]); + i += 1; + } + String::from_utf8_lossy(&out).into_owned() +} + +/// Pick a filename for an outbound `File` arm. Preference order: explicit +/// `filename` field → URL path tail → `"file"`. +fn resolve_file_filename(field: Option<&str>, url: &str) -> String { + field + .filter(|s| !s.is_empty()) + .map(str::to_string) + .or_else(|| derive_filename_from_url(url)) + .unwrap_or_else(|| "file".to_string()) +} + +/// Pick a MIME for an outbound `File` arm. Preference order: explicit `mime` +/// field → response Content-Type → extension lookup from filename → +/// `application/octet-stream`. 
+fn resolve_file_mime(field: Option<&str>, response_ct: Option<&str>, filename: &str) -> String { + field + .filter(|s| !s.is_empty()) + .map(str::to_string) + .or_else(|| response_ct.map(str::to_string)) + .or_else(|| mime_from_extension(filename).map(str::to_string)) + .unwrap_or_else(|| "application/octet-stream".to_string()) +} + +/// Pick a filename for an outbound `Image` arm. Preference order: URL path +/// tail → `"image" + extension` inferred from the resolved MIME (default +/// `.png`). +fn resolve_image_filename(url: &str, resolved_mime: &str) -> String { + if let Some(name) = derive_filename_from_url(url) { + return name; + } + let ext = match resolved_mime { + "image/jpeg" => ".jpg", + "image/gif" => ".gif", + "image/webp" => ".webp", + "image/heic" => ".heic", + "image/heif" => ".heif", + _ => ".png", + }; + format!("image{ext}") +} + +/// Pick a MIME for an outbound `Image` arm. Preference order: response +/// Content-Type → extension lookup from URL tail → `image/png`. +fn resolve_image_mime(response_ct: Option<&str>, url: &str) -> String { + if let Some(ct) = response_ct.filter(|s| !s.is_empty()) { + return ct.to_string(); + } + if let Some(tail) = derive_filename_from_url(url) { + if let Some(ext) = mime_from_extension(&tail) { + return ext.to_string(); + } + } + "image/png".to_string() +} + +/// Best-effort MIME inference from a filename extension. Used as a fallback +/// when Discord's `content_type` field is missing or empty (we've observed +/// this on some bot-relayed attachments). 
+fn mime_from_extension(filename: &str) -> Option<&'static str> { + let ext = filename.rsplit('.').next()?.to_ascii_lowercase(); + match ext.as_str() { + "jpg" | "jpeg" => Some("image/jpeg"), + "png" => Some("image/png"), + "gif" => Some("image/gif"), + "webp" => Some("image/webp"), + "heic" => Some("image/heic"), + "heif" => Some("image/heif"), + "pdf" => Some("application/pdf"), + "txt" => Some("text/plain"), + "md" => Some("text/markdown"), + "json" => Some("application/json"), + "mp4" => Some("video/mp4"), + "mov" => Some("video/quicktime"), + "mp3" => Some("audio/mpeg"), + "wav" => Some("audio/wav"), + "ogg" => Some("audio/ogg"), + _ => None, + } +} + +/// Classify a single Discord attachment JSON object into a `ChannelContent` +/// block. Vision-eligible image MIME types (jpeg/png/gif/webp) under +/// `VISION_IMAGE_MAX_BYTES` become `Image`; everything else becomes `File` +/// (URL-pass-through; the bridge will surface it as a text descriptor in v1). +/// +/// MIME resolution chain: `attachments[].content_type` (if non-empty) → +/// extension lookup → `application/octet-stream`. +fn classify_discord_attachment(att: &serde_json::Value) -> ChannelContent { + let url = att["url"].as_str().unwrap_or("").to_string(); + let filename = att["filename"].as_str().unwrap_or("file").to_string(); + let size = att["size"].as_u64(); + + let resolved_mime: String = att["content_type"] + .as_str() + .filter(|s| !s.is_empty()) + .map(str::to_string) + .or_else(|| mime_from_extension(&filename).map(str::to_string)) + .unwrap_or_else(|| "application/octet-stream".to_string()); + + let is_vision_mime = matches!( + resolved_mime.as_str(), + "image/jpeg" | "image/png" | "image/gif" | "image/webp" + ); + // If size is unknown, optimistically allow the image — the bridge will + // surface a 4xx if Anthropic rejects it, which is better than silently + // demoting to a text URL. 
+ let within_vision_limit = size.map(|s| s <= VISION_IMAGE_MAX_BYTES).unwrap_or(true); + + if is_vision_mime && within_vision_limit { + ChannelContent::Image { url, caption: None } + } else { + ChannelContent::File { + url, + filename, + mime: Some(resolved_mime), + size, + } + } +} + /// Parse a Discord MESSAGE_CREATE or MESSAGE_UPDATE payload into a `ChannelMessage`. async fn parse_discord_message( d: &serde_json::Value, @@ -546,6 +1540,11 @@ async fn parse_discord_message( allowed_users: &[String], ignore_bots: bool, ) -> Option { + // Diagnostic: dump the raw Discord payload so we can ground attachment + // parsing in real JSON. Gated by RUST_LOG; silent at default `info` level. + // Enable with: RUST_LOG=openfang_channels::discord=debug + debug!(target: "openfang_channels::discord", payload = %d, "discord raw message payload"); + let author = d.get("author")?; let author_id = author["id"].as_str()?; @@ -577,10 +1576,6 @@ async fn parse_discord_message( } let content_text = d["content"].as_str().unwrap_or(""); - if content_text.is_empty() { - return None; - } - let channel_id = d["channel_id"].as_str()?; let message_id = d["id"].as_str().unwrap_or("0"); let username = author["username"].as_str().unwrap_or("Unknown"); @@ -597,7 +1592,8 @@ async fn parse_discord_message( .map(|dt| dt.with_timezone(&chrono::Utc)) .unwrap_or_else(chrono::Utc::now); - // Parse commands (messages starting with /) + // Parse commands (messages starting with /). Commands do not carry + // attachments in v1; attachment processing only runs in the non-command path. 
let content = if content_text.starts_with('/') { let parts: Vec<&str> = content_text.splitn(2, ' ').collect(); let cmd_name = &parts[0][1..]; @@ -611,7 +1607,50 @@ async fn parse_discord_message( args, } } else { - ChannelContent::Text(content_text.to_string()) + let attachment_blocks: Vec = d["attachments"] + .as_array() + .map(|arr| arr.iter().map(classify_discord_attachment).collect()) + .unwrap_or_default(); + + match (content_text.is_empty(), attachment_blocks.len()) { + // No text, no attachments → nothing to ingest. + (true, 0) => return None, + // Text only. + (false, 0) => ChannelContent::Text(content_text.to_string()), + // Single attachment, no caption. + (true, 1) => attachment_blocks.into_iter().next().unwrap(), + // Single attachment + caption: emit Multipart with the caption as + // a sibling Text block. This keeps the caption visible to providers + // that flatten content to text only (e.g. claude-code/*, which + // currently drops Image blocks) — the user gets a coherent + // text-only response instead of a hallucination. Vision-capable + // providers see the same blocks and dispatch multimodally. + (false, 1) => { + let block = attachment_blocks.into_iter().next().unwrap(); + let normalized = match block { + // Drop any caption that classify_discord_attachment may have + // attached; the sibling Text block is now the caption. + ChannelContent::Image { url, caption: _ } => { + ChannelContent::Image { url, caption: None } + } + other => other, + }; + ChannelContent::Multipart(vec![ + ChannelContent::Text(content_text.to_string()), + normalized, + ]) + } + // Multiple attachments, no caption. + (true, _) => ChannelContent::Multipart(attachment_blocks), + // Multiple attachments + caption: text first, then attachments + // (matches Discord's visual ordering: text above attachments). 
+ (false, _) => { + let mut blocks = Vec::with_capacity(attachment_blocks.len() + 1); + blocks.push(ChannelContent::Text(content_text.to_string())); + blocks.extend(attachment_blocks); + ChannelContent::Multipart(blocks) + } + } }; // Determine if this is a group message (guild_id present = server channel) @@ -661,6 +1700,81 @@ async fn parse_discord_message( mod tests { use super::*; + #[test] + fn test_attachment_payload_no_caption() { + // No caption → empty JSON object so Discord doesn't reject it. + assert_eq!(build_attachment_payload_json(None), "{}"); + assert_eq!(build_attachment_payload_json(Some("")), "{}"); + } + + #[test] + fn test_attachment_payload_short_caption() { + let json = build_attachment_payload_json(Some("hello")); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(v["content"], "hello"); + } + + #[test] + fn test_attachment_payload_truncates_long_caption() { + // 3000 chars → must truncate to DISCORD_MSG_LIMIT (2000) so Discord + // accepts the request instead of 400-ing on a too-long content field. + let long = "a".repeat(3000); + let json = build_attachment_payload_json(Some(&long)); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!( + v["content"].as_str().unwrap().chars().count(), + DISCORD_MSG_LIMIT + ); + } + + #[test] + fn test_attachment_payload_truncation_is_char_safe() { + // Multibyte chars must not be split mid-codepoint. + let s: String = "héllo ".repeat(500); // 6 chars per chunk → 3000 chars total + let json = build_attachment_payload_json(Some(&s)); + let v: serde_json::Value = serde_json::from_str(&json).unwrap(); + // Round-trip through serde guarantees we didn't produce invalid UTF-8. 
+ assert_eq!( + v["content"].as_str().unwrap().chars().count(), + DISCORD_MSG_LIMIT + ); + } + + #[test] + fn test_attachment_field_name_pinned() { + // Discord rejects the upload silently if the multipart field isn't + // exactly `files[0]` (a `file[0]` typo would fail at runtime, per + // attachment, with no useful error). Pin the wire format here so a + // typo at the call site is impossible without also changing this test. + // Both invariants matter: the constant's literal value AND the + // `format!("files[{i}]", i=0)` we now use at the call site must agree. + assert_eq!(ATTACHMENT_FIELD_NAME, "files[0]"); + assert_eq!(format!("files[{}]", 0), ATTACHMENT_FIELD_NAME); + } + + #[test] + fn test_multipart_part_accepts_common_mimes() { + // Validate that mime_str() doesn't reject the MIME types we map from + // tool_runner.rs::channel_send. If any of these started failing we'd + // surface as a runtime upload error per file. + for mime in [ + "image/png", + "image/jpeg", + "image/gif", + "image/webp", + "application/pdf", + "text/plain", + "application/json", + "application/octet-stream", + "video/mp4", + ] { + let part = reqwest::multipart::Part::bytes(b"x".to_vec()) + .file_name("f.bin") + .mime_str(mime); + assert!(part.is_ok(), "mime_str rejected {mime}"); + } + } + #[tokio::test] async fn test_parse_discord_message_basic() { let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); @@ -1032,4 +2146,1423 @@ mod tests { assert_eq!(adapter.name(), "discord"); assert_eq!(adapter.channel_type(), ChannelType::Discord); } + + // -- Multipart / attachment parsing tests (commit 4) ---------------------- + + fn att(filename: &str, content_type: Option<&str>, size: u64) -> serde_json::Value { + let mut obj = serde_json::json!({ + "url": format!("https://cdn.discordapp.com/attachments/1/2/{filename}"), + "filename": filename, + "size": size, + }); + if let Some(ct) = content_type { + obj["content_type"] = serde_json::Value::String(ct.to_string()); + } + obj + } + 
+ fn payload_with(content: &str, attachments: Vec) -> serde_json::Value { + serde_json::json!({ + "id": "msg1", + "channel_id": "ch1", + "content": content, + "author": { + "id": "user456", + "username": "alice", + "discriminator": "0", + "bot": false + }, + "timestamp": "2024-01-01T00:00:00+00:00", + "attachments": attachments, + }) + } + + #[tokio::test] + async fn test_parse_image_only_no_caption() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with("", vec![att("photo.png", Some("image/png"), 100_000)]); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::Image { caption, url } => { + assert!(caption.is_none()); + assert!(url.contains("photo.png")); + } + other => panic!("expected Image, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_image_with_caption() { + // Single image + caption is emitted as Multipart([Text, Image]) so the + // caption survives providers that flatten content blocks to text only + // (e.g. claude-code/*). The Image carries no caption of its own; the + // sibling Text block IS the caption. 
+ let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with( + "look at this", + vec![att("photo.jpg", Some("image/jpeg"), 50_000)], + ); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::Multipart(parts) => { + assert_eq!(parts.len(), 2); + assert!(matches!(&parts[0], ChannelContent::Text(t) if t == "look at this")); + match &parts[1] { + ChannelContent::Image { caption, url } => { + assert!( + caption.is_none(), + "image caption should be None; the sibling Text block is the caption" + ); + assert!(url.contains("photo.jpg")); + } + other => panic!("expected Image as second part, got {other:?}"), + } + } + other => panic!("expected Multipart, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_multi_image_no_caption() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with( + "", + vec![ + att("a.png", Some("image/png"), 10_000), + att("b.png", Some("image/png"), 20_000), + ], + ); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::Multipart(parts) => { + assert_eq!(parts.len(), 2); + assert!(parts + .iter() + .all(|p| matches!(p, ChannelContent::Image { .. }))); + } + other => panic!("expected Multipart, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_multi_image_with_caption() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with( + "two pics", + vec![ + att("a.png", Some("image/png"), 10_000), + att("b.png", Some("image/png"), 20_000), + ], + ); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::Multipart(parts) => { + assert_eq!(parts.len(), 3); + // Text first, then images. + assert!(matches!(&parts[0], ChannelContent::Text(t) if t == "two pics")); + assert!(matches!(&parts[1], ChannelContent::Image { .. 
})); + assert!(matches!(&parts[2], ChannelContent::Image { .. })); + } + other => panic!("expected Multipart, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_heic_falls_to_file() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with("", vec![att("photo.heic", Some("image/heic"), 100_000)]); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::File { mime, filename, .. } => { + assert_eq!(filename, "photo.heic"); + assert_eq!(mime.as_deref(), Some("image/heic")); + } + other => panic!("expected File, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_oversize_image_falls_to_file() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + // 6 MB exceeds VISION_IMAGE_MAX_BYTES (5 MB). + let d = payload_with( + "", + vec![att("huge.png", Some("image/png"), 6 * 1024 * 1024)], + ); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::File { + filename, + mime, + size, + .. + } => { + assert_eq!(filename, "huge.png"); + assert_eq!(mime.as_deref(), Some("image/png")); + assert_eq!(size, Some(6 * 1024 * 1024)); + } + other => panic!("expected File, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_file_with_caption_yields_multipart() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with( + "see attached", + vec![att("doc.pdf", Some("application/pdf"), 200_000)], + ); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + match msg.content { + ChannelContent::Multipart(parts) => { + assert_eq!(parts.len(), 2); + assert!(matches!(&parts[0], ChannelContent::Text(t) if t == "see attached")); + assert!(matches!(&parts[1], ChannelContent::File { .. 
})); + } + other => panic!("expected Multipart, got {other:?}"), + } + } + + #[tokio::test] + async fn test_parse_extension_fallback_when_content_type_missing() { + // Discord occasionally omits content_type on bot-relayed attachments; + // we should fall back to the filename extension. + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with("", vec![att("pic.png", None, 50_000)]); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true) + .await + .unwrap(); + assert!(matches!(msg.content, ChannelContent::Image { .. })); + } + + // -- Pure helper tests for File/Image arm fallback chains -------------- + + #[test] + fn test_strip_mime_params_basic() { + assert_eq!(strip_mime_params("image/png"), "image/png"); + assert_eq!(strip_mime_params("image/png; charset=utf-8"), "image/png"); + assert_eq!( + strip_mime_params(" IMAGE/PNG ; charset=utf-8 "), + "image/png" + ); + assert_eq!(strip_mime_params(""), ""); + } + + #[test] + fn test_derive_filename_from_url() { + assert_eq!( + derive_filename_from_url("https://cdn.example.com/a/b/photo.png"), + Some("photo.png".to_string()) + ); + assert_eq!( + derive_filename_from_url("https://cdn.example.com/a/b/photo.png?ex=1&hm=2"), + Some("photo.png".to_string()) + ); + assert_eq!( + derive_filename_from_url("https://cdn.example.com/a/photo%20final.png"), + Some("photo final.png".to_string()) + ); + // Trailing slash → no filename derivable. + assert_eq!(derive_filename_from_url("https://cdn.example.com/"), None); + assert_eq!(derive_filename_from_url("https://cdn.example.com"), None); + } + + #[test] + fn test_resolve_file_filename_chain() { + // Field wins over URL. + assert_eq!( + resolve_file_filename(Some("explicit.bin"), "https://x/y/url.dat"), + "explicit.bin" + ); + // Empty field → URL fallback. + assert_eq!( + resolve_file_filename(Some(""), "https://x/y/url.dat"), + "url.dat" + ); + // None → URL fallback. 
+ assert_eq!( + resolve_file_filename(None, "https://x/y/url.dat"), + "url.dat" + ); + // No URL tail → "file". + assert_eq!(resolve_file_filename(None, "https://x/"), "file"); + } + + #[test] + fn test_resolve_file_mime_chain() { + // Field wins. + assert_eq!( + resolve_file_mime(Some("application/pdf"), Some("text/plain"), "f.txt"), + "application/pdf" + ); + // No field → response Content-Type. + assert_eq!( + resolve_file_mime(None, Some("text/plain"), "f.txt"), + "text/plain" + ); + // No field, no CT → extension lookup. + assert_eq!(resolve_file_mime(None, None, "f.pdf"), "application/pdf"); + // Nothing → default. + assert_eq!( + resolve_file_mime(None, None, "no-ext"), + "application/octet-stream" + ); + } + + #[test] + fn test_resolve_image_filename_chain() { + // URL tail wins. + assert_eq!( + resolve_image_filename("https://x/y/picture.jpg", "image/jpeg"), + "picture.jpg" + ); + // No URL tail → image + ext from MIME. + assert_eq!( + resolve_image_filename("https://x/", "image/jpeg"), + "image.jpg" + ); + assert_eq!( + resolve_image_filename("https://x/", "image/png"), + "image.png" + ); + assert_eq!( + resolve_image_filename("https://x/", "image/webp"), + "image.webp" + ); + // Unknown MIME → .png default. + assert_eq!( + resolve_image_filename("https://x/", "application/octet-stream"), + "image.png" + ); + } + + #[test] + fn test_resolve_image_mime_chain() { + // Response CT wins. + assert_eq!( + resolve_image_mime(Some("image/jpeg"), "https://x/y/foo.png"), + "image/jpeg" + ); + // No CT → URL extension. + assert_eq!(resolve_image_mime(None, "https://x/y/foo.png"), "image/png"); + // No CT, no extension → default. + assert_eq!(resolve_image_mime(None, "https://x/y/blob"), "image/png"); + } + + // -- Fetcher::fetch / do_http_fetch size-cap tests ---------------------- + + /// Spawn a hand-rolled HTTP server that replies with a fixed status, + /// optional Content-Length header (lying or omitted), and a body produced + /// by the supplied closure. 
This intentionally does NOT use axum's + /// `Body::from(Vec)` (which always sends the real Content-Length); we + /// need control over the header to exercise both pre-flight and streaming + /// rejection paths. + async fn spawn_raw_http_server( + status_line: &'static str, + content_length: Option<&'static str>, + body: bytes::Bytes, + ) -> String { + use tokio::io::AsyncWriteExt; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + // Accept exactly one connection — sufficient for a single + // do_http_fetch test invocation. + let (mut sock, _) = match listener.accept().await { + Ok(p) => p, + Err(_) => return, + }; + // Drain the request line + headers (best-effort; we don't parse). + let mut buf = [0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut sock, &mut buf).await; + let mut header = format!("HTTP/1.1 {status_line}\r\n"); + header.push_str("Content-Type: application/octet-stream\r\n"); + if let Some(cl) = content_length { + header.push_str(&format!("Content-Length: {cl}\r\n")); + } + header.push_str("Connection: close\r\n\r\n"); + let _ = sock.write_all(header.as_bytes()).await; + let _ = sock.write_all(&body).await; + let _ = sock.shutdown().await; + }); + format!("http://{addr}/file") + } + + /// Like `spawn_raw_http_server` but with a configurable Content-Type + /// header and a configurable path suffix in the returned URL. Used by + /// mixed-type Multipart tests where different URLs must carry different + /// content-types (e.g. `image/png` for the Image block, `application/pdf` + /// for the File block). 
+ async fn spawn_fixture_server( + content_type: &'static str, + path: &'static str, + body: bytes::Bytes, + ) -> String { + use tokio::io::AsyncWriteExt; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let (mut sock, _) = match listener.accept().await { + Ok(p) => p, + Err(_) => return, + }; + let mut buf = [0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut sock, &mut buf).await; + let body_len = body.len(); + let header = format!( + "HTTP/1.1 200 OK\r\nContent-Type: {content_type}\r\nContent-Length: {body_len}\r\nConnection: close\r\n\r\n" + ); + let _ = sock.write_all(header.as_bytes()).await; + let _ = sock.write_all(&body).await; + let _ = sock.shutdown().await; + }); + format!("http://{addr}/{path}") + } + + /// Spawn a server that omits Content-Length (so the pre-flight CL check + /// can't short-circuit) and streams `actual_len` bytes in 1 MiB chunks + /// using HTTP/1.1 Connection: close framing. After each successful chunk + /// write, increments `bytes_sent`. The test asserts on the counter to + /// prove the client aborted *mid-stream* rather than buffering the + /// entire body and complaining at the end. + /// + /// Only a single connection is accepted. Returns the URL to fetch, + /// `http://{addr}/file`, where `addr` is the ephemeral 127.0.0.1 port. 
+ async fn spawn_chunked_streaming_server( + actual_len: usize, + bytes_sent: Arc, + ) -> String { + use std::sync::atomic::Ordering; + use tokio::io::AsyncWriteExt; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let (mut sock, _) = match listener.accept().await { + Ok(p) => p, + Err(_) => return, + }; + let mut hbuf = [0u8; 4096]; + let _ = tokio::io::AsyncReadExt::read(&mut sock, &mut hbuf).await; + let header = "HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nConnection: close\r\n\r\n"; + if sock.write_all(header.as_bytes()).await.is_err() { + return; + } + let chunk = vec![0u8; 1024 * 1024]; + let mut written = 0usize; + while written < actual_len { + let n = std::cmp::min(chunk.len(), actual_len - written); + if sock.write_all(&chunk[..n]).await.is_err() { + // Client aborted (which is exactly what we expect once it + // hits the cap). Stop writing further chunks. + break; + } + // Best-effort flush before the chunk is counted, so the counter + // tracks what was handed to the socket one chunk at a time. + // NOTE(review): this helper never configures SO_SNDBUF, and a + // flush on a tokio TcpStream is presumably a no-op, so kernel + // socket buffers can still absorb chunks the client never read. + // Treat the counter as an upper bound on client-consumed + // bytes, not an exact figure — confirm against tokio docs. + let _ = sock.flush().await; + written += n; + bytes_sent.fetch_add(n, Ordering::SeqCst); + } + let _ = sock.shutdown().await; + }); + format!("http://{addr}/file") + } + + fn test_adapter() -> DiscordAdapter { + DiscordAdapter::new("test-token".into(), vec![], vec![], true, 0) + } + + /// Test-only helper: drive the wire-level fetch directly so tests against + /// 127.0.0.1 fixture servers don't trip the SSRF preflight. Production + /// callers always go through `Fetcher::fetch` and inherit the guard. + /// + /// The adapter argument is accepted but unused (`_adapter`); the URL is + /// parsed and handed straight to `do_http_fetch`. Panics if `url` does + /// not parse. 
+ async fn test_fetch( + _adapter: &DiscordAdapter, + url: &str, + ) -> Result<(bytes::Bytes, Option), Box> { + let parsed = Url::parse(url).unwrap(); + do_http_fetch(&parsed).await + } + + #[tokio::test] + async fn test_download_size_cap_via_content_length() { + // Server advertises an oversized Content-Length and sends a tiny body. + // The adapter must reject before reading anything significant. + let oversized = (URL_FETCH_MAX_BYTES + 1).to_string(); + // Leak the string so we can hand &'static str to the spawn helper. + let cl: &'static str = Box::leak(oversized.into_boxed_str()); + let url = spawn_raw_http_server("200 OK", Some(cl), bytes::Bytes::from_static(b"x")).await; + + let adapter = test_adapter(); + let res = test_fetch(&adapter, &url).await; + assert!(res.is_err(), "expected Err on oversized Content-Length"); + let err = res.unwrap_err().to_string(); + assert!( + err.contains("Content-Length"), + "err should mention CL: {err}" + ); + } + + #[tokio::test] + async fn test_download_size_cap_via_streaming_aborts_midstream() { + // The strengthened version: the fixture server sends no Content-Length + // at all, so the pre-flight check has nothing to reject, then streams + // chunks past the cap. We assert via a side-channel counter that the + // server stopped writing well before the full body went out — proving + // the client aborted mid-stream rather than buffering everything and + // erroring at the end. 
+ use std::sync::atomic::Ordering; + let actual = URL_FETCH_MAX_BYTES * 4; // ~100 MiB worth of chunks queued + let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let url = spawn_chunked_streaming_server(actual, counter.clone()).await; + + let adapter = test_adapter(); + let res = test_fetch(&adapter, &url).await; + assert!(res.is_err(), "expected Err on oversized streamed body"); + let err = res.unwrap_err().to_string(); + assert!( + err.contains("streamed body exceeds cap"), + "err should mention streaming cap: {err}" + ); + // Give the server task a beat to observe the closed socket. + tokio::time::sleep(Duration::from_millis(250)).await; + let sent = counter.load(Ordering::SeqCst); + // Allow generous slack for kernel/userland buffering on top of the + // 25 MiB cap. The regression we're guarding against is "client buffers + // the entire 100 MiB then errors" — that would show ~100 MiB sent. + // Allowing up to 2*cap covers reasonable in-flight buffering without + // letting the regression slip through. + let allowed = URL_FETCH_MAX_BYTES * 2; + assert!( + sent <= allowed, + "server pushed {sent} bytes (allowed {allowed}); client did not abort mid-stream" + ); + assert!( + sent < actual, + "server pushed full payload ({sent} of {actual}); client did not abort mid-stream" + ); + } + + #[tokio::test] + async fn test_download_under_cap_succeeds_and_returns_bytes_and_ct() { + // Sanity check: a normal small payload returns the bytes and the + // stripped Content-Type. Uses axum so Content-Length is set correctly. 
+ use axum::{routing::get, Router}; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let app = Router::new().route( + "/f", + get(|| async { + ( + [(axum::http::header::CONTENT_TYPE, "image/png; charset=utf-8")], + bytes::Bytes::from_static(b"PNGDATA"), + ) + }), + ); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + let url = format!("http://{addr}/f"); + + let adapter = test_adapter(); + let (bytes, ct) = test_fetch(&adapter, &url).await.unwrap(); + assert_eq!(&bytes[..], b"PNGDATA"); + // Content-Type parameters must be stripped. + assert_eq!(ct.as_deref(), Some("image/png")); + } + + // -- SSRF guard tests --------------------------------------------------- + + async fn assert_ssrf_blocked(url: &str) { + let adapter = test_adapter(); + let res = adapter.fetcher.fetch(url).await; + let err = res + .err() + .unwrap_or_else(|| panic!("expected SSRF block for {url}")) + .to_string(); + assert!( + err.contains("refused") || err.contains("not allowed") || err.contains("blocked"), + "expected SSRF refusal for {url}, got: {err}" + ); + // The query string must not appear in the error (log-scrubbing); + // this is approximated by rejecting any `?` in the message. + assert!( + !err.contains("?"), + "SSRF error must not leak query string for {url}: {err}" + ); + } + + #[tokio::test] + async fn test_ssrf_blocks_loopback() { + assert_ssrf_blocked("http://127.0.0.1/secret?token=abc").await; + } + + #[tokio::test] + async fn test_ssrf_blocks_private_10() { + assert_ssrf_blocked("http://10.0.0.1/admin?key=v").await; + } + + #[tokio::test] + async fn test_ssrf_blocks_private_192() { + assert_ssrf_blocked("http://192.168.1.1/router?op=reboot").await; + } + + #[tokio::test] + async fn test_ssrf_blocks_link_local() { + // Cloud metadata: explicit canary URL from the spec. 
+ assert_ssrf_blocked( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/role", + ) + .await; + } + + #[tokio::test] + async fn test_ssrf_blocks_non_http_scheme() { + for u in [ + "file:///etc/passwd", + "gopher://127.0.0.1:25/_HELO", + "ftp://example.com/x", + "data:text/plain,hello", + ] { + let adapter = test_adapter(); + let res = adapter.fetcher.fetch(u).await; + let err = res + .err() + .unwrap_or_else(|| panic!("expected scheme refusal for {u}")) + .to_string(); + assert!( + err.contains("scheme") || err.contains("refused"), + "expected scheme refusal for {u}: {err}" + ); + } + } + + #[tokio::test] + async fn test_ssrf_allows_public_ip_literal_check() { + // Validate the *check*, not the network round-trip: a public IP literal + // must pass `resolve_and_check_host` so we know the guard isn't + // accidentally over-blocking. + let u = Url::parse("http://1.1.1.1/").unwrap(); + resolve_and_check_host(&u) + .await + .expect("public IP must pass SSRF check"); + } + + #[tokio::test] + async fn test_ssrf_blocks_ipv6_loopback_and_metadata_mapped() { + // Bracketed IPv6 loopback. + assert_ssrf_blocked("http://[::1]/secret?x=1").await; + // IPv4-mapped IPv6 of the cloud metadata IP must also be blocked. + assert_ssrf_blocked("http://[::ffff:169.254.169.254]/latest?creds=1").await; + } + + #[test] + fn test_redact_url_strips_query() { + let u = Url::parse( + "https://cdn.discordapp.com/attachments/1/2/file.png?ex=abc&is=def&hm=secret#frag", + ) + .unwrap(); + let r = redact_url(&u); + assert_eq!(r, "https://cdn.discordapp.com/attachments/1/2/file.png"); + assert!(!r.contains("ex=")); + assert!(!r.contains("hm=")); + assert!(!r.contains("frag")); + } + + #[test] + fn test_is_blocked_v4_canary() { + // Explicit assertion that the cloud-metadata IP is rejected even if + // some future stdlib change widens or narrows `is_link_local`. 
+ assert!(is_blocked_v4(Ipv4Addr::new(169, 254, 169, 254))); + assert!(is_blocked_v4(Ipv4Addr::new(127, 0, 0, 1))); + assert!(is_blocked_v4(Ipv4Addr::new(10, 0, 0, 1))); + assert!(is_blocked_v4(Ipv4Addr::new(192, 168, 1, 1))); + assert!(is_blocked_v4(Ipv4Addr::new(172, 16, 0, 1))); + assert!(is_blocked_v4(Ipv4Addr::new(0, 0, 0, 0))); + assert!(is_blocked_v4(Ipv4Addr::new(100, 64, 0, 1))); // CGNAT + // Public addresses must pass. + assert!(!is_blocked_v4(Ipv4Addr::new(1, 1, 1, 1))); + assert!(!is_blocked_v4(Ipv4Addr::new(8, 8, 8, 8))); + } + + #[tokio::test] + async fn test_parse_empty_message_with_no_attachments_returns_none() { + let bot_id = Arc::new(RwLock::new(Some("bot123".to_string()))); + let d = payload_with("", vec![]); + let msg = parse_discord_message(&d, &bot_id, &[], &[], true).await; + assert!(msg.is_none()); + } + + // ========================================================================== + // Outbound Multipart send() tests + // ========================================================================== + // + // Test helpers: spin up an axum stub that accepts multipart POSTs to + // `/channels/:id/messages`, captures the `payload_json` field text and + // the number of file parts, and stores them in a shared Arc>. + // We point the adapter at the stub via `api_base_override`. + + use tokio::sync::Mutex as TokioMutex; + + #[derive(Debug, Default, Clone)] + struct CapturedFile { + field_name: String, + filename: Option, + content_type: Option, + } + + #[derive(Debug, Default, Clone)] + struct CapturedPost { + payload_json: String, + /// Bare field names (legacy, kept so existing assertions continue to compile). + file_field_names: Vec, + /// Richer per-file metadata captured from each `files[*]` part. + files: Vec, + } + + /// Build an axum stub that captures one or more multipart POSTs to + /// `/channels/test/messages` and records them into `captured`. 
+ async fn spawn_discord_stub(captured: Arc>>) -> String { + use axum::{ + extract::{DefaultBodyLimit, Multipart}, + http::StatusCode, + routing::post, + Extension, Router, + }; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = Router::new() + .route( + "/channels/test/messages", + post( + |Extension(store): Extension>>>, + mut multipart: Multipart| async move { + let mut post = CapturedPost::default(); + while let Ok(Some(field)) = multipart.next_field().await { + let name = field.name().unwrap_or("").to_string(); + if name == "payload_json" { + post.payload_json = field.text().await.unwrap_or_default(); + } else { + let filename = field.file_name().map(str::to_string); + let content_type = field.content_type().map(str::to_string); + // Drain the file bytes so axum doesn't error; the bytes + // are discarded and only the part's metadata is recorded. + let _ = field.bytes().await; + post.files.push(CapturedFile { + field_name: name.clone(), + filename, + content_type, + }); + post.file_field_names.push(name); + } + } + store.lock().await.push(post); + StatusCode::OK + }, + ), + ) + // Default 2 MiB body limit would reject the byte-cap chunking + // test's ~20 MiB chunks; disable it on the stub. + .layer(DefaultBodyLimit::disable()) + .layer(Extension(captured)); + + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + format!("http://{addr}") + } + + fn make_channel_user(channel_id: &str) -> ChannelUser { + ChannelUser { + platform_id: channel_id.to_string(), + display_name: "test-user".to_string(), + openfang_user: None, + } + } + + fn test_adapter_with_base(base: String) -> DiscordAdapter { + let mut a = test_adapter(); + a.api_base_override = Some(base); + a + } + + /// Like `test_adapter_with_base` but installs [`PermissiveFetcher`] so + /// `Image{url}` / `File{url}` blocks pointing at localhost stub servers + /// can flow through the normal `Fetcher::fetch` path without tripping + /// the SSRF preflight. 
+ fn test_adapter_with_base_and_ssrf_bypass(base: String) -> DiscordAdapter { + let mut a = test_adapter_with_base(base); + a.fetcher = Arc::new(PermissiveFetcher); + a + } + + // ---- required test a: caption concatenation -------------------------------- + // + // Two Text blocks ahead of one FileData attachment must arrive as a single + // POST whose `payload_json` `content` is the texts joined by "\n\n". + + #[tokio::test] + async fn test_multipart_outbound_caption_concatenation() { + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let content = ChannelContent::Multipart(vec![ + ChannelContent::Text("hello".to_string()), + ChannelContent::Text("world".to_string()), + ChannelContent::FileData { + data: b"payload".to_vec(), + filename: "file.txt".to_string(), + mime_type: "text/plain".to_string(), + }, + ]); + + adapter.send(&user, content).await.unwrap(); + + let posts = captured.lock().await; + assert_eq!(posts.len(), 1, "expected exactly one POST"); + let v: serde_json::Value = serde_json::from_str(&posts[0].payload_json).unwrap(); + assert_eq!( + v["content"].as_str().unwrap_or(""), + "hello\n\nworld", + "caption should be the two Text blocks joined by \\n\\n" + ); + } + + // ---- required test b: empty/whitespace caption suppressed ------------------ + // + // Image URL fetches go through the SSRF guard which blocks 127.0.0.1, so + // this test uses FileData to avoid the network requirement while still + // exercising the caption-suppression logic path. 
+ + #[tokio::test] + async fn test_multipart_outbound_empty_caption_suppressed() { + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let content = ChannelContent::Multipart(vec![ + ChannelContent::Text("".to_string()), + ChannelContent::Text(" ".to_string()), + ChannelContent::FileData { + data: b"bytes".to_vec(), + filename: "f.bin".to_string(), + mime_type: "application/octet-stream".to_string(), + }, + ]); + adapter.send(&user, content).await.unwrap(); + + let posts = captured.lock().await; + assert_eq!(posts.len(), 1, "expected one POST"); + let v: serde_json::Value = serde_json::from_str(&posts[0].payload_json).unwrap(); + assert!( + v.get("content").is_none(), + "empty/whitespace caption must produce payload_json without 'content' field; got: {}", + posts[0].payload_json + ); + } + + // ---- required test c: chunking >10 ---------------------------------------- + + #[tokio::test] + async fn test_multipart_outbound_chunking_gt10() { + // 23 FileData blocks should produce ceil(23/10) = 3 POSTs. 
+ // First chunk: caption + files[0..10) (10 files) + // Second chunk: no caption + files[0..10) (10 files) + // Third chunk: no caption + files[0..3) (3 files) + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let mut parts = vec![ChannelContent::Text("cap".to_string())]; + for i in 0..23u32 { + parts.push(ChannelContent::FileData { + data: format!("data{i}").into_bytes(), + filename: format!("f{i}.txt"), + mime_type: "text/plain".to_string(), + }); + } + + adapter + .send(&user, ChannelContent::Multipart(parts)) + .await + .unwrap(); + + let posts = captured.lock().await; + assert_eq!( + posts.len(), + 3, + "23 files should produce 3 POSTs (chunks of 10)" + ); + + // First chunk carries the caption. + let v0: serde_json::Value = serde_json::from_str(&posts[0].payload_json).unwrap(); + assert_eq!( + v0["content"].as_str().unwrap_or(""), + "cap", + "first chunk must carry the caption" + ); + assert_eq!( + posts[0].file_field_names.len(), + 10, + "first chunk must have 10 files" + ); + + // Second chunk has no caption. + let v1: serde_json::Value = serde_json::from_str(&posts[1].payload_json).unwrap(); + assert!( + v1.get("content").is_none(), + "second chunk must not carry the caption" + ); + assert_eq!( + posts[1].file_field_names.len(), + 10, + "second chunk must have 10 files" + ); + + // Third chunk has no caption and only 3 files. 
+ let v2: serde_json::Value = serde_json::from_str(&posts[2].payload_json).unwrap(); + assert!( + v2.get("content").is_none(), + "third chunk must not carry the caption" + ); + assert_eq!( + posts[2].file_field_names.len(), + 3, + "third chunk must have 3 files" + ); + } + + // ---- required test d: caption-only fallback -------------------------------- + + /// Checks that a Multipart with only Text blocks sends exactly one plain + /// text message (no multipart POST) via the `api_send_message` path. + #[tokio::test] + async fn test_multipart_outbound_caption_only_fallback() { + // The Discord stub only handles `/channels/test/messages` POSTs. + // api_send_message sends JSON (not multipart), so we use a simple + // axum stub on that same route that accepts any request body and + // records the request's Content-Type header. + use axum::{extract::Request, http::StatusCode, routing::post, Extension, Router}; + + let calls: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let calls_clone = calls.clone(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let app = + Router::new() + .route( + "/channels/test/messages", + post( + |Extension(store): Extension>>>, + req: Request| async move { + let ct = req + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + store.lock().await.push(ct); + StatusCode::OK + }, + ), + ) + .layer(Extension(calls_clone)); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + + let adapter = test_adapter_with_base(format!("http://{addr}")); + let user = make_channel_user("test"); + + let content = + ChannelContent::Multipart(vec![ChannelContent::Text("only text".to_string())]); + adapter.send(&user, content).await.unwrap(); + + let cts = calls.lock().await; + assert_eq!(cts.len(), 1, "expected exactly one POST for caption-only"); + // Plain message (JSON), not multipart. 
+ assert!( + cts[0].contains("application/json"), + "caption-only should send JSON, not multipart; content-type was: {}", + cts[0] + ); + } + + // ---- required test e: mixed Image{url}+File{url} resolver dispatch ---------- + + /// Verifies that a `Multipart([Text, Image{url}, File{url}])` block routes + /// each attachment through the correct resolver branch: + /// + /// - `Image{url}` → `resolve_image_mime` / `resolve_image_filename`: + /// the response Content-Type is used as-is and the filename is derived + /// from the URL path or inferred from the MIME (e.g. `image.png`). + /// + /// - `File{url, filename, mime}` → `resolve_file_mime` / + /// `resolve_file_filename`: the explicitly supplied filename and MIME + /// from the `File{}` block take precedence over the server's + /// Content-Type. + /// + /// The test spins up two local HTTP fixture servers (bypassing the SSRF + /// guard via `test_adapter_with_base_and_ssrf_bypass`), one per URL, then + /// asserts on the per-part filename and Content-Type captured by the + /// Discord stub. + #[tokio::test] + async fn test_multipart_outbound_mixed_types_single_post() { + // Spawn a fixture server for the Image block — serves image/png bytes. + let image_url = spawn_fixture_server( + "image/png", + "photo.png", + bytes::Bytes::from_static(b"\x89PNG\r\n\x1a\n"), // minimal PNG magic + ) + .await; + + // Spawn a fixture server for the File block — serves PDF magic bytes + // under a generic `application/octet-stream` Content-Type. The File + // block supplies an explicit filename and MIME so the resolver must + // prefer those over the server's Content-Type and URL path. 
+ let file_url = spawn_fixture_server( + "application/octet-stream", // server sends generic; resolver should prefer field mime + "ignored-server-name.bin", + bytes::Bytes::from_static(b"%PDF-1.4"), + ) + .await; + + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base_and_ssrf_bypass(base); + let user = make_channel_user("test"); + + let content = ChannelContent::Multipart(vec![ + ChannelContent::Text("mixed".to_string()), + ChannelContent::Image { + url: image_url.clone(), + caption: None, + }, + ChannelContent::File { + url: file_url.clone(), + filename: "report.pdf".to_string(), + mime: Some("application/pdf".to_string()), + size: None, + }, + ]); + + adapter.send(&user, content).await.unwrap(); + + let posts = captured.lock().await; + assert_eq!( + posts.len(), + 1, + "expected exactly one POST for mixed Multipart" + ); + + // Caption preserved. + let v: serde_json::Value = serde_json::from_str(&posts[0].payload_json).unwrap(); + assert_eq!(v["content"].as_str().unwrap_or(""), "mixed"); + + // Both files appeared in a single POST. 
+ assert_eq!( + posts[0].files.len(), + 2, + "expected two file parts in the POST" + ); + + // ---- Image block assertions ---- + // resolve_image_mime: server sent image/png → resolved mime = "image/png" + // resolve_image_filename: URL path tail is "photo.png" → filename = "photo.png" + let img_part = posts[0] + .files + .iter() + .find(|f| f.field_name == "files[0]") + .expect("files[0] must be present"); + assert_eq!( + img_part.content_type.as_deref(), + Some("image/png"), + "Image block must use resolve_image_mime (server Content-Type preserved)" + ); + assert_eq!( + img_part.filename.as_deref(), + Some("photo.png"), + "Image block must use resolve_image_filename (URL path tail)" + ); + + // ---- File block assertions ---- + // resolve_file_filename: explicit filename "report.pdf" takes precedence over URL + // resolve_file_mime: explicit mime "application/pdf" takes precedence over server CT + let file_part = posts[0] + .files + .iter() + .find(|f| f.field_name == "files[1]") + .expect("files[1] must be present"); + assert_eq!( + file_part.content_type.as_deref(), + Some("application/pdf"), + "File block must use resolve_file_mime (explicit mime from File{{}} block)" + ); + assert_eq!( + file_part.filename.as_deref(), + Some("report.pdf"), + "File block must use resolve_file_filename (explicit filename from File{{}} block)" + ); + } + + // ---- should-have: mid-batch fetch failure ---------------------------------- + + #[tokio::test] + async fn test_multipart_outbound_fetch_failure_returns_err() { + // A File block with an SSRF-blocked URL should cause send() to return Err. 
+ let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let content = ChannelContent::Multipart(vec![ + ChannelContent::Text("cap".to_string()), + ChannelContent::File { + url: "http://127.0.0.1/secret".to_string(), + filename: "s.txt".to_string(), + mime: None, + size: None, + }, + ]); + + let result = adapter.send(&user, content).await; + assert!(result.is_err(), "expected Err on fetch failure"); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("Multipart fetch failed") || err.contains("refused"), + "error should mention failing fetch; got: {err}" + ); + // No POST should have been made (fetch failed before send). + let posts = captured.lock().await; + assert!(posts.is_empty(), "no POST should occur if fetch fails"); + } + + // ---- should-have: empty Multipart ------------------------------------------ + + #[tokio::test] + async fn test_multipart_outbound_empty_is_ok_no_posts() { + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let result = adapter.send(&user, ChannelContent::Multipart(vec![])).await; + assert!(result.is_ok(), "empty Multipart should return Ok"); + let posts = captured.lock().await; + assert!( + posts.is_empty(), + "empty Multipart must not produce any POSTs" + ); + } + + // ---- should-have: unknown nested variant is logged, not fatal -------------- + + #[tokio::test] + async fn test_multipart_outbound_unknown_nested_variant_skipped() { + // A Multipart containing a nested Multipart (and a FileData) should + // warn but still send the FileData. 
+ let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let content = ChannelContent::Multipart(vec![ + ChannelContent::Text("x".to_string()), + ChannelContent::Multipart(vec![]), // unknown nesting + ChannelContent::FileData { + data: b"f".to_vec(), + filename: "f.txt".to_string(), + mime_type: "text/plain".to_string(), + }, + ]); + + let result = adapter.send(&user, content).await; + assert!(result.is_ok(), "unknown nested variant must not be fatal"); + let posts = captured.lock().await; + assert_eq!(posts.len(), 1, "FileData should still be sent"); + } + + // ---- multi-file 429 retry -------------------------------------------------- + + /// Spawn a stub at `/channels/test/messages` that returns + /// `first_response` on attempt 0 and 200 OK on every subsequent attempt. + /// Captures every POST's parsed multipart fields into the returned + /// `Arc<...Vec>` for assertions. Used by the 429 retry + /// tests to vary only the 429 response shape (body+header vs header-only) + /// while sharing the rest of the scaffolding. 
+ async fn spawn_429_then_ok_stub( + first_response: Arc< + dyn Fn() -> axum::response::Response + Send + Sync + 'static, + >, + ) -> (String, Arc>>) { + use axum::{ + extract::{DefaultBodyLimit, Multipart}, + http::StatusCode, + response::IntoResponse, + routing::post, + Extension, Router, + }; + use std::sync::atomic::{AtomicUsize, Ordering}; + + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let attempt = Arc::new(AtomicUsize::new(0)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let captured_clone = captured.clone(); + let attempt_clone = attempt.clone(); + let app = Router::new() + .route( + "/channels/test/messages", + post( + move |Extension(_): Extension<()>, mut multipart: Multipart| { + let captured = captured_clone.clone(); + let attempt = attempt_clone.clone(); + let first_response = first_response.clone(); + async move { + let n = attempt.fetch_add(1, Ordering::SeqCst); + let mut post_rec = CapturedPost::default(); + while let Ok(Some(field)) = multipart.next_field().await { + let name = field.name().unwrap_or("").to_string(); + if name == "payload_json" { + post_rec.payload_json = + field.text().await.unwrap_or_default(); + } else { + let filename = field.file_name().map(str::to_string); + let content_type = field.content_type().map(str::to_string); + let _ = field.bytes().await; + post_rec.files.push(CapturedFile { + field_name: name.clone(), + filename, + content_type, + }); + post_rec.file_field_names.push(name); + } + } + captured.lock().await.push(post_rec); + if n == 0 { + first_response() + } else { + StatusCode::OK.into_response() + } + } + }, + ), + ) + .layer(DefaultBodyLimit::disable()) + .layer(Extension(())); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + (format!("http://{addr}"), captured) + } + + /// Build a `ChannelContent::Multipart` with `n` `FileData` blocks named + /// `a.txt`, `b.txt`, … (for tests 
that only care about field-name + /// ordering, not content). `n` must be ≤ 26. + fn caption_plus_n_files(caption: &str, n: usize) -> ChannelContent { + assert!(n <= 26, "caption_plus_n_files: n must fit in a-z"); + let mut parts = vec![ChannelContent::Text(caption.to_string())]; + for i in 0..n { + let ch = (b'a' + i as u8) as char; + parts.push(ChannelContent::FileData { + data: vec![ch as u8], + filename: format!("{ch}.txt"), + mime_type: "text/plain".to_string(), + }); + } + ChannelContent::Multipart(parts) + } + + /// Assert every captured POST's multipart fields are exactly + /// `["files[0]", "files[1]", …, "files[n-1]"]`. + /// + /// Passes vacuously when nothing was captured, so callers should assert + /// the expected POST count before relying on this check. + async fn assert_all_attempts_carry_files( + captured: &Arc>>, + n: usize, + ) { + let expected: Vec = (0..n).map(|i| format!("files[{i}]")).collect(); + let posts = captured.lock().await; + for (i, p) in posts.iter().enumerate() { + assert_eq!( + p.file_field_names, expected, + "attempt {i} must include files[0..{n})" + ); + } + } + + /// 429 response with both `Retry-After: 0` header and a JSON body + /// containing `retry_after: 0.0`. Sending a 3-attachment Multipart + /// must produce exactly 2 POSTs and both must carry the full file set. + /// Locks in the body-aware retry path (body wins over header per the + /// adapter's `body_secs.or(header_secs)`). 
+ #[tokio::test] + async fn test_multipart_outbound_multifile_429_retries_once() { + use axum::{http::StatusCode, response::IntoResponse}; + let first: Arc axum::response::Response + Send + Sync> = Arc::new(|| { + ( + StatusCode::TOO_MANY_REQUESTS, + [(axum::http::header::RETRY_AFTER, "0")], + r#"{"retry_after":0.0,"global":false}"#, + ) + .into_response() + }); + let (base, captured) = spawn_429_then_ok_stub(first).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + adapter + .send(&user, caption_plus_n_files("cap", 3)) + .await + .unwrap(); + + assert_eq!( + captured.lock().await.len(), + 2, + "expected 2 POSTs (one 429-rejected, one 200) for the same chunk" + ); + assert_all_attempts_carry_files(&captured, 3).await; + } + + /// 429 response with **only** the `Retry-After` header (empty body). + /// The header-fallback path (`body_secs.or(header_secs)`) must still + /// trigger the retry, so a regression that drops header parsing fails + /// here independently of the body-present test. + #[tokio::test] + async fn test_multipart_outbound_multifile_429_header_only_retries_once() { + use axum::{http::StatusCode, response::IntoResponse}; + let first: Arc axum::response::Response + Send + Sync> = Arc::new(|| { + ( + StatusCode::TOO_MANY_REQUESTS, + [(axum::http::header::RETRY_AFTER, "0")], + "", + ) + .into_response() + }); + let (base, captured) = spawn_429_then_ok_stub(first).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + adapter + .send(&user, caption_plus_n_files("cap", 2)) + .await + .unwrap(); + + assert_eq!( + captured.lock().await.len(), + 2, + "header-only 429 must still trigger one retry" + ); + assert_all_attempts_carry_files(&captured, 2).await; + } + + // ---- aggregate per-chunk byte cap ------------------------------------------ + + /// Three 10 MiB FileData blocks must split into two chunks under the + /// 24 MiB per-chunk byte cap (20 MiB + 10 MiB). 
The caption rides only + /// on the first chunk; chunk-2 has no caption. + #[tokio::test] + async fn test_multipart_outbound_chunking_by_byte_cap() { + let captured: Arc>> = Arc::new(TokioMutex::new(Vec::new())); + let base = spawn_discord_stub(captured.clone()).await; + let adapter = test_adapter_with_base(base); + let user = make_channel_user("test"); + + let big = vec![0u8; 10 * 1024 * 1024]; + let parts = vec![ + ChannelContent::Text("cap".to_string()), + ChannelContent::FileData { + data: big.clone(), + filename: "a.bin".to_string(), + mime_type: "application/octet-stream".to_string(), + }, + ChannelContent::FileData { + data: big.clone(), + filename: "b.bin".to_string(), + mime_type: "application/octet-stream".to_string(), + }, + ChannelContent::FileData { + data: big, + filename: "c.bin".to_string(), + mime_type: "application/octet-stream".to_string(), + }, + ]; + + adapter + .send(&user, ChannelContent::Multipart(parts)) + .await + .unwrap(); + + let posts = captured.lock().await; + assert_eq!( + posts.len(), + 2, + "3×10 MiB attachments should split into 2 chunks under the 24 MiB cap" + ); + // Chunk 1: caption + 2 files (a, b). + let v0: serde_json::Value = serde_json::from_str(&posts[0].payload_json).unwrap(); + assert_eq!(v0["content"].as_str().unwrap_or(""), "cap"); + assert_eq!(posts[0].files.len(), 2, "first chunk holds first 2 files"); + // Chunk 2: no caption + 1 file (c). + let v1: serde_json::Value = serde_json::from_str(&posts[1].payload_json).unwrap(); + assert!( + v1.get("content").is_none(), + "second chunk must not carry the caption" + ); + assert_eq!(posts[1].files.len(), 1, "second chunk holds the 3rd file"); + } + + /// Direct unit test of the chunking helper: verifies count cap, byte cap, + /// and the oversize-single-attachment edge case (lands in its own chunk + /// instead of stalling progress). + #[test] + fn test_chunk_attachments_count_and_byte_caps() { + // 12 small files → 2 chunks of 10 + 2 (count cap dominates). 
+ let small: Vec<_> = (0..12) + .map(|i| { + ( + bytes::Bytes::from(vec![0u8; 1024]), + format!("f{i}.bin"), + "application/octet-stream".to_string(), + ) + }) + .collect(); + let chunks = chunk_attachments(small); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0].len(), 10); + assert_eq!(chunks[1].len(), 2); + + // Three 10 MiB items → 2 chunks (byte cap dominates: 20 + 10 ≤ 24). + let big_payload = bytes::Bytes::from(vec![0u8; 10 * 1024 * 1024]); + let big = vec![ + (big_payload.clone(), "a".to_string(), "x".to_string()), + (big_payload.clone(), "b".to_string(), "x".to_string()), + (big_payload.clone(), "c".to_string(), "x".to_string()), + ]; + let chunks = chunk_attachments(big); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0].len(), 2); + assert_eq!(chunks[1].len(), 1); + + // One oversized attachment by itself → single chunk holding it. + // (Discord rejects, but the helper mustn't loop forever or drop it.) + let oversized = bytes::Bytes::from(vec![0u8; CHUNK_TOTAL_CAP_BYTES + 1]); + let solo = vec![(oversized, "huge".to_string(), "x".to_string())]; + let chunks = chunk_attachments(solo); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].len(), 1); + + // Empty input → no chunks. + let chunks = chunk_attachments(Vec::new()); + assert!(chunks.is_empty()); + } } diff --git a/crates/openfang-channels/src/lib.rs b/crates/openfang-channels/src/lib.rs index 7b122d2a2..8deb77f6c 100644 --- a/crates/openfang-channels/src/lib.rs +++ b/crates/openfang-channels/src/lib.rs @@ -8,6 +8,7 @@ pub mod discord; pub mod email; pub mod formatter; pub mod google_chat; +pub mod outbound_attach; pub mod irc; pub mod matrix; pub mod mattermost; diff --git a/crates/openfang-channels/src/outbound_attach.rs b/crates/openfang-channels/src/outbound_attach.rs new file mode 100644 index 000000000..002527a60 --- /dev/null +++ b/crates/openfang-channels/src/outbound_attach.rs @@ -0,0 +1,536 @@ +//! Outbound attachment parser. +//! +//! 
Recognises `` markers in agent response text, validates each path +//! against an allow-root, reads the bytes, and produces +//! `ChannelContent::FileData` blocks that the wire layer (`discord::send`, +//! `telegram::send`, …) already knows how to chunk and upload. +//! +//! ## Marker syntax +//! +//! ```text +//! +//! +//! +//! ``` +//! +//! All attribute values use double quotes. The marker is self-closing. +//! Multiple markers per response are supported up to Discord's 10-attachment +//! per-message cap; the wire-layer chunker handles aggregate-size splitting. +//! +//! ## Security +//! +//! Paths are canonicalised (so symlinks are resolved) and must lie under +//! one of the allow-roots — by default `$HOME/.openfang/`. This covers the +//! ephemeral `~/.openfang/tmp/` scratch area and per-agent +//! `~/.openfang/workspaces//` directories without leaking access to +//! the rest of the filesystem in the face of a prompt-injected agent. +//! +//! ## Failure mode +//! +//! Per-directive errors (path missing, outside allow-root, oversized) are +//! logged at WARN and the marker is silently dropped from the outgoing +//! message — partial success rather than failing the whole reply. If every +//! directive fails the caller still gets the stripped text back, so the +//! user sees the prose without the broken markers. + +use crate::types::ChannelContent; +use regex_lite::Regex; +use std::path::{Path, PathBuf}; +use std::sync::OnceLock; +use tracing::warn; + +/// Per-attachment hard cap. Discord allows 25 MiB per request on the free +/// tier; we cap each file at 25 MiB and rely on the wire-layer chunker +/// (24 MiB aggregate, 10 attachments per chunk in `discord::send`) to split +/// large multi-file responses across several messages. +const MAX_FILE_BYTES: u64 = 25 * 1024 * 1024; + +/// Hard cap on directives parsed from a single response. Discord refuses +/// more than 10 attachments per message; the chunker bucket-splits but +/// there's no point parsing further. 
+const MAX_ATTACHMENTS_PER_MESSAGE: usize = 10; + +/// Outcome of parsing an outbound response. +pub enum Parsed { + /// No `` marker present. Caller should take the + /// normal text-only path. + NoMarkers, + /// At least one marker was found. `stripped_text` is the original text + /// with all markers removed and any `caption=` values appended. `files` + /// is the resolved `FileData` blocks (possibly empty if every directive + /// failed validation). + WithAttachments { + stripped_text: String, + files: Vec, + }, +} + +fn marker_regex() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| { + Regex::new(r#"]*?)/>"#).expect("marker regex compiles") + }) +} + +fn attr_regex() -> &'static Regex { + static RE: OnceLock = OnceLock::new(); + RE.get_or_init(|| Regex::new(r#"(\w+)\s*=\s*"([^"]*)""#).expect("attr regex compiles")) +} + +#[derive(Debug)] +struct AttachDirective { + path: String, + name: Option, + spoiler: bool, + caption: Option, +} + +fn parse_directive(attrs: &str) -> Option { + let mut path = None; + let mut name = None; + let mut spoiler = false; + let mut caption = None; + for cap in attr_regex().captures_iter(attrs) { + let key = cap.get(1)?.as_str(); + let val = cap.get(2)?.as_str().to_string(); + match key { + "path" => path = Some(val), + "name" => name = Some(val), + "spoiler" => spoiler = matches!(val.as_str(), "true" | "1" | "yes"), + "caption" => caption = Some(val), + _ => {} + } + } + Some(AttachDirective { + path: path?, + name, + spoiler, + caption, + }) +} + +/// Extension → MIME type. Mirrors the table used by `tool_runner` for +/// `channel_send`'s `file_path` parameter so inbound and outbound paths +/// agree on the wire-format. Unknown extensions fall back to +/// `application/octet-stream`. 
+fn mime_from_extension(path: &Path) -> &'static str { + let ext = path + .extension() + .and_then(|e| e.to_str()) + .unwrap_or("") + .to_lowercase(); + match ext.as_str() { + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + "gif" => "image/gif", + "webp" => "image/webp", + "svg" => "image/svg+xml", + "pdf" => "application/pdf", + "txt" | "md" | "log" => "text/plain", + "csv" => "text/csv", + "json" => "application/json", + "xml" => "application/xml", + "zip" => "application/zip", + "gz" | "gzip" => "application/gzip", + "tar" => "application/x-tar", + "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "mp4" => "video/mp4", + "doc" => "application/msword", + "docx" => "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "xls" => "application/vnd.ms-excel", + "xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + _ => "application/octet-stream", + } +} + +/// Default allow-root: canonicalised `$HOME/.openfang/`. Returns an empty +/// vec if `HOME` is unset or the directory does not exist (in which case +/// every directive will be rejected — fail-closed). 
+fn default_allow_roots() -> Vec { + let mut roots = Vec::new(); + if let Some(home) = std::env::var_os("HOME") { + let mut p = PathBuf::from(home); + p.push(".openfang"); + if let Ok(canon) = std::fs::canonicalize(&p) { + roots.push(canon); + } + } + roots +} + +async fn resolve_directive( + d: &AttachDirective, + allow_roots: &[PathBuf], +) -> Result { + let raw = PathBuf::from(&d.path); + if !raw.is_absolute() { + return Err(format!("path must be absolute: {}", d.path)); + } + let canon = tokio::fs::canonicalize(&raw) + .await + .map_err(|e| format!("canonicalize {}: {e}", raw.display()))?; + if !allow_roots.iter().any(|r| canon.starts_with(r)) { + return Err(format!("path {} outside allow-roots", canon.display())); + } + let metadata = tokio::fs::metadata(&canon) + .await + .map_err(|e| format!("stat {}: {e}", canon.display()))?; + if !metadata.is_file() { + return Err(format!("not a regular file: {}", canon.display())); + } + if metadata.len() > MAX_FILE_BYTES { + return Err(format!( + "{} exceeds {} byte cap (size {})", + canon.display(), + MAX_FILE_BYTES, + metadata.len() + )); + } + let data = tokio::fs::read(&canon) + .await + .map_err(|e| format!("read {}: {e}", canon.display()))?; + let basename = canon + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("file") + .to_string(); + let mut filename = d.name.clone().unwrap_or(basename); + if d.spoiler && !filename.starts_with("SPOILER_") { + // Discord's `SPOILER_` filename prefix flags the attachment as a + // spoiler. Other adapters ignore the prefix harmlessly. + filename = format!("SPOILER_{}", filename); + } + let mime_type = mime_from_extension(&canon).to_string(); + Ok(ChannelContent::FileData { + data, + filename, + mime_type, + }) +} + +/// Parse `text`, resolve every `` marker against +/// `allow_roots_override` (or the default `$HOME/.openfang/` root if +/// `None`), and return either `NoMarkers` or `WithAttachments`. 
+/// +/// The returned `stripped_text` is the original with markers removed and +/// `caption` attribute values appended (each on its own line, in +/// document order). The caller is responsible for running the channel +/// formatter over `stripped_text` — formatting *before* parsing would +/// HTML-escape `<` in markers and break detection. +pub async fn parse(text: &str, allow_roots_override: Option<&[PathBuf]>) -> Parsed { + let re = marker_regex(); + if !re.is_match(text) { + return Parsed::NoMarkers; + } + let owned_default; + let allow_roots: &[PathBuf] = match allow_roots_override { + Some(r) => r, + None => { + owned_default = default_allow_roots(); + &owned_default + } + }; + + let mut stripped = String::with_capacity(text.len()); + let mut last = 0; + let mut directives: Vec = Vec::new(); + let mut captions: Vec = Vec::new(); + + for cap in re.captures_iter(text) { + let m = cap.get(0).unwrap(); + let attrs = cap.get(1).map(|m| m.as_str()).unwrap_or(""); + stripped.push_str(&text[last..m.start()]); + match parse_directive(attrs) { + Some(d) => { + if directives.len() >= MAX_ATTACHMENTS_PER_MESSAGE { + warn!( + "outbound_attach: dropping marker beyond {} attachments cap", + MAX_ATTACHMENTS_PER_MESSAGE + ); + // Keep the marker visible — the agent should see it + // wasn't honoured. + stripped.push_str(m.as_str()); + } else { + if let Some(c) = &d.caption { + captions.push(c.clone()); + } + directives.push(d); + } + } + None => { + // Malformed marker — leave it in place for debuggability. + stripped.push_str(m.as_str()); + } + } + last = m.end(); + } + stripped.push_str(&text[last..]); + + // Append captions on their own lines. 
+ let mut stripped_text = stripped.trim_end().to_string(); + for c in &captions { + if !stripped_text.is_empty() { + stripped_text.push('\n'); + } + stripped_text.push_str(c); + } + + let mut files: Vec = Vec::with_capacity(directives.len()); + for d in &directives { + match resolve_directive(d, allow_roots).await { + Ok(block) => files.push(block), + Err(e) => { + warn!("outbound_attach: skipping {}: {}", d.path, e); + } + } + } + + Parsed::WithAttachments { + stripped_text, + files, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + fn fixture_root() -> (tempfile::TempDir, Vec) { + let tmp = tempfile::tempdir().expect("tempdir"); + let root = std::fs::canonicalize(tmp.path()).expect("canonicalize tmp"); + (tmp, vec![root]) + } + + #[tokio::test] + async fn no_markers_returns_no_markers() { + let result = parse("just some prose, no markers here", None).await; + assert!(matches!(result, Parsed::NoMarkers)); + } + + #[tokio::test] + async fn single_marker_resolves_to_filedata() { + let (tmp, roots) = fixture_root(); + let path = tmp.path().join("hello.txt"); + std::fs::write(&path, b"hi").unwrap(); + let canon = std::fs::canonicalize(&path).unwrap(); + let text = format!( + "Here you go: done.", + canon.display() + ); + + let result = parse(&text, Some(&roots)).await; + match result { + Parsed::WithAttachments { + stripped_text, + files, + } => { + assert_eq!(stripped_text, "Here you go: done."); + assert_eq!(files.len(), 1); + match &files[0] { + ChannelContent::FileData { + data, + filename, + mime_type, + } => { + assert_eq!(data, b"hi"); + assert_eq!(filename, "hello.txt"); + assert_eq!(mime_type, "text/plain"); + } + _ => panic!("expected FileData"), + } + } + _ => panic!("expected WithAttachments"), + } + } + + #[tokio::test] + async fn caption_attribute_is_appended_to_text() { + let (tmp, roots) = fixture_root(); + let path = tmp.path().join("note.pdf"); + std::fs::write(&path, b"%PDF-1.4 stub").unwrap(); + let canon = 
std::fs::canonicalize(&path).unwrap(); + let text = format!( + "", + canon.display() + ); + + let result = parse(&text, Some(&roots)).await; + match result { + Parsed::WithAttachments { + stripped_text, + files, + } => { + assert_eq!(stripped_text, "for the meeting"); + assert_eq!(files.len(), 1); + match &files[0] { + ChannelContent::FileData { + filename, + mime_type, + .. + } => { + assert_eq!(filename, "note.pdf"); + assert_eq!(mime_type, "application/pdf"); + } + _ => panic!("expected FileData"), + } + } + _ => panic!("expected WithAttachments"), + } + } + + #[tokio::test] + async fn spoiler_prefixes_filename() { + let (tmp, roots) = fixture_root(); + let path = tmp.path().join("secret.png"); + std::fs::write(&path, b"\x89PNG").unwrap(); + let canon = std::fs::canonicalize(&path).unwrap(); + let text = format!( + "", + canon.display() + ); + + let result = parse(&text, Some(&roots)).await; + match result { + Parsed::WithAttachments { files, .. } => { + match &files[0] { + ChannelContent::FileData { filename, .. } => { + assert_eq!(filename, "SPOILER_secret.png"); + } + _ => panic!("expected FileData"), + } + } + _ => panic!("expected WithAttachments"), + } + } + + #[tokio::test] + async fn name_attribute_overrides_basename() { + let (tmp, roots) = fixture_root(); + let path = tmp.path().join("ugly-uuid-name.pdf"); + std::fs::write(&path, b"%PDF").unwrap(); + let canon = std::fs::canonicalize(&path).unwrap(); + let text = format!( + "", + canon.display() + ); + + let result = parse(&text, Some(&roots)).await; + match result { + Parsed::WithAttachments { files, .. } => match &files[0] { + ChannelContent::FileData { filename, .. } => { + assert_eq!(filename, "report.pdf"); + } + _ => panic!("expected FileData"), + }, + _ => panic!("expected WithAttachments"), + } + } + + #[tokio::test] + async fn path_outside_allow_root_is_rejected() { + // Use a path in /tmp that we know exists but isn't under our + // synthetic allow-root. 
+ let (_keep, roots) = fixture_root(); + let outside = std::env::temp_dir().join("openfang-outbound-attach-outside.txt"); + std::fs::write(&outside, b"x").unwrap(); + let canon = std::fs::canonicalize(&outside).unwrap(); + + // Sanity: outside isn't under our fixture root. + assert!(!canon.starts_with(&roots[0])); + + let text = format!("", canon.display()); + let result = parse(&text, Some(&roots)).await; + match result { + Parsed::WithAttachments { + stripped_text, + files, + } => { + assert_eq!(stripped_text, ""); + assert!( + files.is_empty(), + "directive outside allow-root must be dropped" + ); + } + _ => panic!("expected WithAttachments (with empty files)"), + } + let _ = std::fs::remove_file(&outside); + } + + #[tokio::test] + async fn relative_path_is_rejected() { + let (_keep, roots) = fixture_root(); + let result = parse( + "", + Some(&roots), + ) + .await; + match result { + Parsed::WithAttachments { files, .. } => { + assert!(files.is_empty(), "relative path must be rejected"); + } + _ => panic!("expected WithAttachments"), + } + } + + #[tokio::test] + async fn multiple_markers_are_all_resolved() { + let (tmp, roots) = fixture_root(); + let p1 = tmp.path().join("a.txt"); + let p2 = tmp.path().join("b.txt"); + std::fs::write(&p1, b"a").unwrap(); + std::fs::write(&p2, b"b").unwrap(); + let c1 = std::fs::canonicalize(&p1).unwrap(); + let c2 = std::fs::canonicalize(&p2).unwrap(); + let text = format!( + "first then end", + c1.display(), + c2.display() + ); + + let result = parse(&text, Some(&roots)).await; + match result { + Parsed::WithAttachments { + stripped_text, + files, + } => { + assert_eq!(stripped_text, "first then end"); + assert_eq!(files.len(), 2); + } + _ => panic!("expected WithAttachments"), + } + } + + #[tokio::test] + async fn malformed_marker_left_in_place() { + // No `path=` attribute → directive is invalid. 
+ let result = parse( + "before after", + None, + ) + .await; + match result { + Parsed::WithAttachments { + stripped_text, + files, + } => { + assert!(files.is_empty()); + assert!( + stripped_text.contains(""), + "malformed marker should be preserved verbatim" + ); + } + _ => panic!("expected WithAttachments (with malformed marker preserved)"), + } + } + + #[test] + fn mime_table_covers_common_extensions() { + assert_eq!(mime_from_extension(Path::new("x.pdf")), "application/pdf"); + assert_eq!(mime_from_extension(Path::new("x.PNG")), "image/png"); + assert_eq!(mime_from_extension(Path::new("x.unknown")), "application/octet-stream"); + assert_eq!(mime_from_extension(Path::new("noext")), "application/octet-stream"); + } +} diff --git a/crates/openfang-channels/src/telegram.rs b/crates/openfang-channels/src/telegram.rs index cb4a5b01b..7e5ddf849 100644 --- a/crates/openfang-channels/src/telegram.rs +++ b/crates/openfang-channels/src/telegram.rs @@ -498,7 +498,7 @@ impl TelegramAdapter { self.api_send_photo(chat_id, &url, caption.as_deref(), thread_id) .await?; } - ChannelContent::File { url, filename } => { + ChannelContent::File { url, filename, .. } => { self.api_send_document(chat_id, &url, &filename, thread_id) .await?; } @@ -521,6 +521,17 @@ impl TelegramAdapter { self.api_send_message(chat_id, text.trim(), thread_id) .await?; } + ChannelContent::Multipart(parts) => { + // Send each child as its own Telegram message. Nested + // Multipart is rejected by adapters; flatten defensively. 
+ for part in parts { + if let ChannelContent::Multipart(_) = part { + debug_assert!(false, "nested Multipart in send_to_user"); + continue; + } + Box::pin(self.send_content(user, part, thread_id)).await?; + } + } } Ok(()) } @@ -934,7 +945,12 @@ async fn parse_telegram_update( .unwrap_or("document") .to_string(); match telegram_get_file_url(token, client, file_id, api_base_url).await { - Some(url) => ChannelContent::File { url, filename }, + Some(url) => ChannelContent::File { + url, + filename, + mime: None, + size: None, + }, None => ChannelContent::Text(format!("[Document received: {filename}]")), } } else if message.get("voice").is_some() { @@ -2051,10 +2067,7 @@ mod tests { body, ) } else { - ( - StatusCode::OK, - r#"{"ok":true,"result":true}"#.to_string(), - ) + (StatusCode::OK, r#"{"ok":true,"result":true}"#.to_string()) } } })); @@ -2131,7 +2144,10 @@ mod tests { // Two-chunk message; first POST fails. Nothing delivered → Err. let big = "a".repeat(5000); // > 4096 → split into two chunks let stub = StubServer::new(vec![ - (500, r#"{"ok":false,"error_code":500,"description":"server"}"#), + ( + 500, + r#"{"ok":false,"error_code":500,"description":"server"}"#, + ), (200, r#"{"ok":true,"result":{}}"#), ]); let base = spawn_stub_server(stub.clone()).await; @@ -2159,7 +2175,10 @@ mod tests { let big = "a".repeat(5000); let stub = StubServer::new(vec![ (200, r#"{"ok":true,"result":{}}"#), - (400, r#"{"ok":false,"error_code":400,"description":"some err"}"#), + ( + 400, + r#"{"ok":false,"error_code":400,"description":"some err"}"#, + ), ]); let base = spawn_stub_server(stub.clone()).await; let adapter = test_adapter(base); @@ -2170,7 +2189,11 @@ mod tests { result.is_ok(), "partial delivery must return Ok (best-effort), got {result:?}" ); - assert_eq!(stub.hit_count(), 2, "both chunks should have been attempted"); + assert_eq!( + stub.hit_count(), + 2, + "both chunks should have been attempted" + ); } // 
----------------------------------------------------------------------- diff --git a/crates/openfang-channels/src/types.rs b/crates/openfang-channels/src/types.rs index 84247b5af..8d3904169 100644 --- a/crates/openfang-channels/src/types.rs +++ b/crates/openfang-channels/src/types.rs @@ -50,6 +50,16 @@ pub enum ChannelContent { File { url: String, filename: String, + /// Best-effort MIME type from the source platform (e.g. Discord's + /// `attachments[].content_type`). `None` if the platform did not + /// provide one; downstream consumers may sniff bytes or fall back + /// to extension-based detection. + #[serde(default, skip_serializing_if = "Option::is_none")] + mime: Option, + /// Size in bytes, when known. Useful for capacity gating before + /// the bridge attempts to materialize or transmit the file. + #[serde(default, skip_serializing_if = "Option::is_none")] + size: Option, }, /// Local file data (bytes read from disk). Used by the proactive `channel_send` /// tool when `file_path` is provided instead of `file_url`. @@ -70,6 +80,12 @@ pub enum ChannelContent { name: String, args: Vec, }, + /// A composite message carrying multiple content blocks (e.g. a Discord + /// message with several attachments, or an image with a separate file + /// sibling). Blocks are flat-mapped by the bridge into multiple LLM + /// content blocks. Implementations should not produce nested `Multipart` + /// values; consumers may `debug_assert!` against nesting. + Multipart(Vec), } /// A unified message from any channel. diff --git a/crates/openfang-channels/src/whatsapp.rs b/crates/openfang-channels/src/whatsapp.rs index 16f37b56d..8f656b559 100644 --- a/crates/openfang-channels/src/whatsapp.rs +++ b/crates/openfang-channels/src/whatsapp.rs @@ -271,7 +271,7 @@ impl ChannelAdapter for WhatsAppAdapter { return Err(format!("WhatsApp API error {status}: {body}").into()); } } - ChannelContent::File { url, filename } => { + ChannelContent::File { url, filename, .. 
} => { let body = serde_json::json!({ "messaging_product": "whatsapp", "to": user.platform_id, diff --git a/crates/openfang-kernel/src/kernel.rs b/crates/openfang-kernel/src/kernel.rs index c91d02a85..28a7d1500 100644 --- a/crates/openfang-kernel/src/kernel.rs +++ b/crates/openfang-kernel/src/kernel.rs @@ -7252,6 +7252,8 @@ impl KernelHandle for OpenFangKernel { "file" => openfang_channels::types::ChannelContent::File { url: media_url.to_string(), filename: filename.unwrap_or("file").to_string(), + mime: None, + size: None, }, _ => { return Err(format!( diff --git a/crates/openfang-runtime/src/agent_loop.rs b/crates/openfang-runtime/src/agent_loop.rs index 7f5f05edb..b62118dc7 100644 --- a/crates/openfang-runtime/src/agent_loop.rs +++ b/crates/openfang-runtime/src/agent_loop.rs @@ -3327,6 +3327,7 @@ mod tests { ContentBlock::Image { media_type: "image/png".to_string(), data: "aGVsbG8=".to_string(), + source_url: None, } } diff --git a/crates/openfang-runtime/src/compactor.rs b/crates/openfang-runtime/src/compactor.rs index fef90c815..c502c0db4 100644 --- a/crates/openfang-runtime/src/compactor.rs +++ b/crates/openfang-runtime/src/compactor.rs @@ -400,8 +400,31 @@ fn build_conversation_text(messages: &[Message], config: &CompactionConfig) -> S conversation_text .push_str(&format!("[Tool result ({status}): {preview}]\n\n")); } - ContentBlock::Image { media_type, .. } => { - conversation_text.push_str(&format!("[Image: {media_type}]\n\n")); + ContentBlock::Image { + media_type, + source_url, + .. + } => { + // Preserve the original CDN URL across compaction so the + // outbound Discord path (PR-C) can re-attach the image by + // re-fetching it. Only http(s) URLs are exposed: local + // `file://` tmpfile paths are an internal materialization + // detail and shouldn't leak into compacted summaries that + // may be persisted, logged, or sent across processes. 
+ match source_url.as_deref() { + Some(url) + if url.starts_with("http://") + || url.starts_with("https://") => + { + conversation_text.push_str(&format!( + "[Image: {media_type} @ {url}]\n\n" + )); + } + _ => { + conversation_text + .push_str(&format!("[Image: {media_type}]\n\n")); + } + } } ContentBlock::Thinking { .. } => {} ContentBlock::Unknown => {} @@ -1266,6 +1289,7 @@ mod tests { content: MessageContent::Blocks(vec![ContentBlock::Image { media_type: "image/png".to_string(), data: "base64data".to_string(), + source_url: None, }]), }, ]; @@ -1278,6 +1302,75 @@ mod tests { assert!(text.contains("[Image: image/png]")); } + #[test] + fn test_build_conversation_text_image_source_url_https() { + // https:// CDN URL is exposed post-compaction so the outbound path + // can re-fetch the image. + let config = CompactionConfig::default(); + let messages = vec![Message { + role: Role::User, + content: MessageContent::Blocks(vec![ContentBlock::Image { + media_type: "image/png".to_string(), + data: "base64data".to_string(), + source_url: Some("https://cdn.discordapp.com/attachments/x/y.png".to_string()), + }]), + }]; + let text = build_conversation_text(&messages, &config); + assert!( + text.contains("[Image: image/png @ https://cdn.discordapp.com/attachments/x/y.png]"), + "https source_url should be preserved, got: {text}" + ); + } + + #[test] + fn test_build_conversation_text_image_source_url_http() { + // Plain http (rare but valid) is also exposed. 
+ let config = CompactionConfig::default(); + let messages = vec![Message { + role: Role::User, + content: MessageContent::Blocks(vec![ContentBlock::Image { + media_type: "image/jpeg".to_string(), + data: "base64data".to_string(), + source_url: Some("http://example.com/foo.jpg".to_string()), + }]), + }]; + let text = build_conversation_text(&messages, &config); + assert!( + text.contains("[Image: image/jpeg @ http://example.com/foo.jpg]"), + "http source_url should be preserved, got: {text}" + ); + } + + #[test] + fn test_build_conversation_text_image_source_url_file_falls_back() { + // file:// URLs (local tmpfile materialization) MUST NOT leak into + // compacted summaries — fall back to the legacy mime-only form. + let config = CompactionConfig::default(); + let messages = vec![Message { + role: Role::User, + content: MessageContent::Blocks(vec![ContentBlock::Image { + media_type: "image/png".to_string(), + data: "base64data".to_string(), + source_url: Some( + "file:///Users/x/.openfang/tmp/images/abc.png".to_string(), + ), + }]), + }]; + let text = build_conversation_text(&messages, &config); + assert!( + text.contains("[Image: image/png]"), + "file:// source_url should fall back to legacy form, got: {text}" + ); + assert!( + !text.contains("file://"), + "file:// path must not leak post-compaction, got: {text}" + ); + assert!( + !text.contains(".openfang"), + "local tmpfile path must not leak post-compaction, got: {text}" + ); + } + #[test] fn test_build_conversation_text_truncates_oversized() { let config = CompactionConfig { diff --git a/crates/openfang-runtime/src/drivers/anthropic.rs b/crates/openfang-runtime/src/drivers/anthropic.rs index e6f10fe58..9ca8a135f 100644 --- a/crates/openfang-runtime/src/drivers/anthropic.rs +++ b/crates/openfang-runtime/src/drivers/anthropic.rs @@ -664,7 +664,7 @@ fn convert_message(msg: &Message) -> ApiMessage { ContentBlock::Text { text, .. 
} => { Some(ApiContentBlock::Text { text: text.clone() }) } - ContentBlock::Image { media_type, data } => Some(ApiContentBlock::Image { + ContentBlock::Image { media_type, data, .. } => Some(ApiContentBlock::Image { source: ApiImageSource { source_type: "base64".to_string(), media_type: media_type.clone(), diff --git a/crates/openfang-runtime/src/drivers/claude_code.rs b/crates/openfang-runtime/src/drivers/claude_code.rs index 21e0fb6d1..b2e530a88 100644 --- a/crates/openfang-runtime/src/drivers/claude_code.rs +++ b/crates/openfang-runtime/src/drivers/claude_code.rs @@ -8,11 +8,13 @@ //! Tracks active subprocess PIDs and enforces message timeouts to prevent //! hung CLI processes from blocking agents indefinitely. +use crate::image_cache::{image_tmp_dir, materialize_image, spawn_sweep_once}; use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent}; use async_trait::async_trait; use dashmap::DashMap; -use openfang_types::message::{ContentBlock, Role, StopReason, TokenUsage}; +use openfang_types::message::{ContentBlock, MessageContent, Role, StopReason, TokenUsage}; use serde::Deserialize; +use std::path::Path; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncReadExt}; use tracing::{debug, info, warn}; @@ -52,6 +54,11 @@ const SENSITIVE_SUFFIXES: &[&str] = &["_SECRET", "_TOKEN", "_PASSWORD"]; /// Default subprocess timeout in seconds (5 minutes). const DEFAULT_MESSAGE_TIMEOUT_SECS: u64 = 300; +// Image materialization helpers (image_tmp_dir, ext_for_mime, +// materialize_image, sweep_old_image_tmpfiles, TTL constants, sweep guard) +// live in crate::image_cache so the outbound file-sharing path can reuse +// the same content-addressed cache without a circular dep on this driver. + /// LLM driver that delegates to the Claude Code CLI. pub struct ClaudeCodeDriver { cli_path: String, @@ -78,6 +85,9 @@ impl ClaudeCodeDriver { ); } + // Best-effort sweep of stale image tmpfiles, once per process. 
+ spawn_sweep_once(); + Self { cli_path: cli_path .filter(|s| !s.is_empty()) @@ -131,6 +141,7 @@ impl ClaudeCodeDriver { /// Build a text prompt from the completion request messages. fn build_prompt(request: &CompletionRequest) -> String { + let tmp_dir = image_tmp_dir(); let mut parts = Vec::new(); for msg in &request.messages { @@ -139,15 +150,88 @@ impl ClaudeCodeDriver { Role::Assistant => "Assistant", Role::System => "System", }; - let text = msg.content.text_content(); - if !text.is_empty() { - parts.push(format!("[{role_label}]\n{text}")); + let rendered = Self::render_content(&msg.content, Some(&tmp_dir)); + if !rendered.is_empty() { + parts.push(format!("[{role_label}]\n{rendered}")); } } parts.join("\n\n") } + /// Render message content for the text-only CLI prompt. + /// + /// Text blocks pass through verbatim. Image blocks are materialized to + /// an on-disk tmpfile (when `image_dir` is provided) so the model can + /// view them via the CLI's `Read` tool — Claude Code is multimodal and + /// will load the file as native image content. We render a directive + /// telling the model exactly which path to read, plus the original + /// `source_url` (e.g. Discord CDN) when known. If materialization + /// fails or `image_dir` is `None` (test path), we fall back to the + /// legacy textual placeholder so the model at least knows an + /// attachment arrived. ToolUse/ToolResult/Thinking are omitted — + /// the CLI manages its own tool loop. + fn render_content(content: &MessageContent, image_dir: Option<&Path>) -> String { + match content { + MessageContent::Text(s) => s.clone(), + MessageContent::Blocks(blocks) => blocks + .iter() + .filter_map(|b| match b { + ContentBlock::Text { text, .. } => { + if text.is_empty() { + None + } else { + Some(text.clone()) + } + } + ContentBlock::Image { + media_type, + data, + source_url, + } => { + // base64 → ~3/4 the length in decoded bytes. 
+ let approx_kb = (data.len().saturating_mul(3) / 4) / 1024; + let url_suffix = match source_url { + Some(u) => format!(" (original: {u})"), + None => String::new(), + }; + if let Some(dir) = image_dir { + // Best-effort filename hint: peel the last path + // segment off the source URL (works for Discord + // CDN, Telegram file API, S3, etc.). Materialize + // appends a sanitized suffix so a human browsing + // ~/.openfang/tmp/images/ can grep for the + // original attachment name. Falls back to a + // pure-hash filename when no URL or no segment. + let name_hint = source_url + .as_deref() + .and_then(filename_hint_from_url); + if let Some(path) = + materialize_image(media_type, data, dir, name_hint.as_deref()) + { + return Some(format!( + "[attachment: {media_type} image, ~{approx_kb} KB — view with the Read tool at {path}{url_suffix}]", + path = path.display() + )); + } + } + // Fallback: at least surface the URL if we have one. + Some(format!( + "[attachment: {media_type} image, ~{approx_kb} KB — not viewable on this provider{url_suffix}]" + )) + } + ContentBlock::ToolUse { .. } + | ContentBlock::ToolResult { .. } + | ContentBlock::Thinking { .. } + | ContentBlock::Unknown => None, + }) + .collect::>() + .join("\n"), + } + } + + // (helper `filename_hint_from_url` lives at module scope below.) + /// Map a model ID like "claude-code/opus" to CLI --model flag value. fn model_flag(model: &str) -> Option { let stripped = model.strip_prefix("claude-code/").unwrap_or(model); @@ -280,6 +364,14 @@ impl LlmDriver for ClaudeCodeDriver { cmd.arg("--model").arg(model); } + // Grant the CLI's Read tool access to our image tmp dir, which lives + // outside the agent's workspace cwd. Without --add-dir the CLI would + // refuse Read on `$HOME/.openfang/tmp/images/*` (unless + // --dangerously-skip-permissions is set) and the materialization would + // be a dead-end. Cheap and idempotent — the dir is per-user and + // content-addressed. 
/// Best-effort: extract a filename hint from a URL's last path segment so
/// the materialized tmpfile carries a human-readable suffix.
///
/// The query string and fragment are dropped first, then the scheme and
/// host, and the final path segment is percent-decoded lossily (malformed
/// `%` escapes are kept verbatim; invalid UTF-8 becomes U+FFFD).
///
/// Returns `None` when:
/// - the URL uses the `file` scheme (compared case-insensitively) — such
///   names are skipped so an already-named tmpfile is not suffixed again;
/// - there is no path, or the path ends in `/` (empty last segment).
///
/// No check is made that the segment "looks like" a filename (e.g.
/// contains a `.`); callers are expected to sanitize the hint. Total —
/// bad input yields `None` or a verbatim segment, never a panic.
fn filename_hint_from_url(url: &str) -> Option<String> {
    // Reject file:// up front; URL schemes are case-insensitive.
    const FILE_SCHEME: &[u8] = b"file://";
    let is_file_url = url
        .as_bytes()
        .get(..FILE_SCHEME.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(FILE_SCHEME));
    if is_file_url {
        return None;
    }

    // Drop query and fragment BEFORE looking for "://": a scheme-less URL
    // such as `cdn/x.png?next=https://other/y.jpg` must not have the
    // separator inside its query mistaken for the scheme separator.
    let without_query = url.split(['?', '#']).next().unwrap_or("");
    // Strip scheme://host. Same shape as the discord adapter's helper but
    // duplicated here to avoid pulling the channels crate into runtime.
    let after_scheme = without_query
        .split_once("://")
        .map_or(without_query, |(_, rest)| rest);
    let (_host, path) = after_scheme.split_once('/')?;
    let last = path.rsplit('/').next().filter(|seg| !seg.is_empty())?;

    let decoded = percent_decode_lossy(last);
    (!decoded.is_empty()).then_some(decoded)
}

/// Lossy percent-decode for things like `photo%20final.png`: every
/// well-formed `%XX` escape becomes its byte, anything else is copied
/// verbatim, and the result is converted with `String::from_utf8_lossy`.
fn percent_decode_lossy(segment: &str) -> String {
    let bytes = segment.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            if let Some([hi, lo]) = bytes.get(i + 1..i + 3) {
                let hi = (*hi as char).to_digit(16);
                let lo = (*lo as char).to_digit(16);
                if let (Some(h), Some(l)) = (hi, lo) {
                    // h and l are both < 16, so the sum always fits in a u8.
                    out.push((h * 16 + l) as u8);
                    i += 3;
                    continue;
                }
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
/// A message whose only content is an image block must still produce a
/// `[User]` section containing the synthetic attachment marker.
#[test]
fn test_build_prompt_image_only_still_emits_marker() {
    use openfang_types::message::{ContentBlock, Message, MessageContent};

    let request = CompletionRequest {
        model: "claude-code/sonnet".to_string(),
        messages: vec![Message {
            role: Role::User,
            content: MessageContent::Blocks(vec![ContentBlock::Image {
                media_type: "image/jpeg".to_string(),
                // "Zm9v" is valid base64 (it decodes to the bytes `foo`).
                data: "Zm9v".to_string(),
                source_url: Some("https://cdn.example/foo.jpg".to_string()),
            }]),
        }],
        tools: vec![],
        max_tokens: 1024,
        temperature: 0.7,
        system: None,
        thinking: None,
    };

    let prompt = ClaudeCodeDriver::build_prompt(&request);
    // The role header is emitted even though there is no text block.
    assert!(
        prompt.contains("[User]"),
        "user role label emitted even with image-only content, got: {prompt}"
    );
    // Only the marker prefix is checked, so both the materialized and the
    // "not viewable" variants satisfy this assertion.
    assert!(
        prompt.contains("[attachment: image/jpeg image"),
        "bare image renders marker, got: {prompt}"
    );
}
} => { parts.push(OaiContentPart::Text { text: text.clone() }); } - ContentBlock::Image { media_type, data } => { + ContentBlock::Image { media_type, data, .. } => { parts.push(OaiContentPart::ImageUrl { image_url: OaiImageUrl { url: format!("data:{media_type};base64,{data}"), diff --git a/crates/openfang-runtime/src/drivers/vertex.rs b/crates/openfang-runtime/src/drivers/vertex.rs index 9f0484163..58d3705c6 100644 --- a/crates/openfang-runtime/src/drivers/vertex.rs +++ b/crates/openfang-runtime/src/drivers/vertex.rs @@ -356,7 +356,7 @@ fn convert_messages( }, }); } - ContentBlock::Image { media_type, data } => { + ContentBlock::Image { media_type, data, .. } => { parts.push(VertexPart::InlineData { inline_data: VertexInlineData { mime_type: media_type.clone(), diff --git a/crates/openfang-runtime/src/image_cache.rs b/crates/openfang-runtime/src/image_cache.rs new file mode 100644 index 000000000..423044f31 --- /dev/null +++ b/crates/openfang-runtime/src/image_cache.rs @@ -0,0 +1,439 @@ +//! Content-addressed image tmpfile cache. +//! +//! Decodes base64 image payloads (the `ContentBlock::Image` shape used by +//! all LLM drivers) and writes them to a content-addressed file under +//! `$HOME/.openfang/tmp/images/` so out-of-process consumers — initially +//! the Claude Code CLI's Read tool, soon the outbound Discord bridge — +//! can reach the bytes by path. +//! +//! Originally lived inside `drivers/claude_code.rs`; lifted here so the +//! outbound file-sharing path can reuse the same cache without a circular +//! dep on the driver crate. Behavior is byte-identical to the previous +//! private implementation. +//! +//! Properties: +//! - **Idempotent.** Filename starts with the first 128 bits of SHA-256(bytes), hex-encoded as 32 characters, so +//! re-rendering the same image hits the cache. +//! - **Atomic publish.** Bytes are written to a unique sibling tmpfile +//! then `rename(2)`-d into place; readers never see a torn file. +//!
/// TTL for materialized image tmpfiles (24 hours). Files older than this
/// are swept on first use.
pub const IMAGE_TMP_TTL_SECS: u64 = 24 * 60 * 60;

/// One-shot guard so the TTL sweep only fires once per process.
static IMAGE_TMP_SWEEP_ONCE: Once = Once::new();

/// Resolve the directory used for materializing image attachments.
///
/// Returns `$HOME/.openfang/tmp/images` when `HOME` is set to a non-empty
/// value, otherwise `<os temp dir>/openfang-images`. The directory is not
/// created here.
///
/// `HOME` is read with `var_os`, so a home directory that is not valid
/// UTF-8 is still honoured instead of silently falling back to the temp
/// dir. An empty `HOME` is treated as unset so the result is never a path
/// relative to the current working directory.
pub fn image_tmp_dir() -> PathBuf {
    match std::env::var_os("HOME").filter(|home| !home.is_empty()) {
        Some(home) => PathBuf::from(home)
            .join(".openfang")
            .join("tmp")
            .join("images"),
        None => std::env::temp_dir().join("openfang-images"),
    }
}

/// Map a MIME type to a sensible filename extension.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace
/// and any parameters after `;` (`image/png; charset=binary` → `png`).
/// Unrecognised types map to `bin`.
pub fn ext_for_mime(media_type: &str) -> &'static str {
    const KNOWN: &[(&str, &str)] = &[
        ("image/png", "png"),
        ("image/jpeg", "jpg"),
        ("image/jpg", "jpg"),
        ("image/gif", "gif"),
        ("image/webp", "webp"),
        ("image/heic", "heic"),
        ("image/heif", "heif"),
        ("image/bmp", "bmp"),
        ("image/svg+xml", "svg"),
    ];
    // Keep only the `type/subtype` part; parameters never affect the extension.
    let essence = media_type.split(';').next().unwrap_or("").trim();
    KNOWN
        .iter()
        .find(|(mime, _)| essence.eq_ignore_ascii_case(mime))
        .map_or("bin", |&(_, ext)| ext)
}
+/// +/// `original_name`, if present, is sanitized and appended to the filename +/// after the content hash (`__.`) so a human +/// browsing `~/.openfang/tmp/images/` can grep/eyeball-match files to the +/// inbound attachment they came from. Cache-hit lookup globs `*.` +/// so a re-render with a new (or no) name reuses the existing file. +pub fn materialize_image( + media_type: &str, + data: &str, + dir: &Path, + original_name: Option<&str>, +) -> Option { + let bytes = base64::engine::general_purpose::STANDARD + .decode(data.as_bytes()) + .ok()?; + let mut hasher = Sha256::new(); + hasher.update(&bytes); + let hash = hasher.finalize(); + let hex: String = hash.iter().take(16).map(|b| format!("{:02x}", b)).collect(); + let ext = ext_for_mime(media_type); + + // Cache-hit: if any file with this hash prefix already exists (with or + // without a name suffix, regardless of which name it carries), reuse + // it. Two callers feeding different names for the same bytes converge + // on whichever named the file first; the alternative would be writing + // multiple copies of identical bytes to disk for cosmetic reasons. + if let Some(existing) = find_existing_for_hash(dir, &hex, ext) { + if let Err(e) = touch_mtime(&existing) { + debug!(path = ?existing, error = %e, "failed to refresh image tmpfile mtime"); + } + return Some(existing); + } + + let filename = match original_name.and_then(sanitize_for_filename) { + Some(sanitized) => format!("{hex}__{sanitized}.{ext}"), + None => format!("{hex}.{ext}"), + }; + let path = dir.join(filename); + // Defensive: post-sanitize collision check (should be subsumed by the + // hash-prefix scan above, but kept so the legacy code path below is + // still safe if `find_existing_for_hash` ever misses). + if path.exists() { + // Refresh mtime on cache hit so the TTL sweep (which gates on + // `meta.modified()`) does not GC a tmpfile still being actively + // referenced. 
Without this, a long-running conversation that + // outlives `IMAGE_TMP_TTL_SECS` would lose its image bytes + // mid-thread, even though the content block is still in scope. + // Best-effort: any failure is debug-logged and the cached path + // is returned anyway — the worst case is the legacy 24h-GC + // behavior we just had. + if let Err(e) = touch_mtime(&path) { + debug!(path = ?path, error = %e, "failed to refresh image tmpfile mtime"); + } + return Some(path); + } + if let Err(e) = std::fs::create_dir_all(dir) { + warn!(dir = ?dir, error = %e, "failed to create openfang image tmp dir"); + return None; + } + // Atomic publish: write to a unique tmp sibling, then rename into place. + // Two concurrent renders of the same image each write their own tmpfile; + // the rename(2) is atomic on the same filesystem, so consumers never see + // a torn or partially-written file. If the destination already exists by + // the time we rename (loser of a race), the rename still succeeds (POSIX + // replaces) — and the contents are identical anyway by construction. + let tmp_path = dir.join(format!( + "{hex}.{pid}.{nanos}.tmp", + pid = std::process::id(), + nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0), + )); + if let Err(e) = std::fs::write(&tmp_path, &bytes) { + warn!(path = ?tmp_path, error = %e, "failed to write openfang image tmpfile"); + return None; + } + if let Err(e) = std::fs::rename(&tmp_path, &path) { + warn!(from = ?tmp_path, to = ?path, error = %e, "failed to rename openfang image tmpfile into place"); + // Best-effort cleanup of the orphan tmpfile. + let _ = std::fs::remove_file(&tmp_path); + return None; + } + Some(path) +} + +/// Sanitize a candidate filename fragment so it is safe to embed in a +/// path under `image_tmp_dir()`. 
/// Sanitize a candidate filename fragment so it is safe to embed in a
/// file name.
///
/// Steps, in order:
/// 1. keep only the last path component (split on `/` and `\`);
/// 2. drop the trailing extension, if any (`foo.tar.gz` → `foo.tar`) — the
///    caller supplies the canonical extension from the MIME type;
/// 3. lowercase ASCII, keep `[a-z0-9.-]`, and replace every other run of
///    characters (including `_` itself and non-ASCII) with a single `_`;
/// 4. strip leading and trailing `_` / `.` (so no hidden files);
/// 5. cap the length at 60 characters, then strip any `_` / `.` the cut
///    left dangling at the end.
///
/// Returns `None` if nothing is left.
pub fn sanitize_for_filename(name: &str) -> Option<String> {
    /// Cap that keeps the total path length reasonable.
    const MAX_LEN: usize = 60;

    fn is_edge_junk(c: char) -> bool {
        c == '_' || c == '.'
    }

    // Drop any path components defensively — Discord filenames shouldn't
    // contain `/`, but a malicious or malformed source could try to.
    let leaf = name.rsplit(['/', '\\']).next().unwrap_or(name);
    // A leading-dot name such as `.hidden` has an empty stem; keep the
    // whole leaf in that case and let the edge trim remove the dot.
    let stem = match leaf.rsplit_once('.') {
        Some((s, _)) if !s.is_empty() => s,
        _ => leaf,
    };

    let mut out = String::with_capacity(stem.len());
    for c in stem.chars() {
        let lc = c.to_ascii_lowercase();
        if lc.is_ascii_alphanumeric() || matches!(lc, '.' | '-') {
            out.push(lc);
        } else if !out.ends_with('_') {
            // Collapse each run of rejected characters into one `_`.
            out.push('_');
        }
    }

    let trimmed = out.trim_matches(is_edge_junk);
    if trimmed.is_empty() {
        return None;
    }
    // `trimmed` starts with a kept character, so the capped prefix can
    // never trim down to nothing.
    let capped: String = trimmed.chars().take(MAX_LEN).collect();
    Some(capped.trim_end_matches(is_edge_junk).to_string())
}
/// Look for a previously-materialized tmpfile carrying the given content
/// hash, regardless of any human-readable name suffix that may have been
/// appended.
///
/// A directory entry matches when its name is exactly `<hex>.<ext>` or
/// starts with `<hex>__` and ends with `.<ext>`. The extension is removed
/// exactly once, so a name like `<hex>.<ext>.<ext>` does not match. The
/// first match in directory order is returned.
///
/// Best-effort: if `dir` cannot be read the result is `None` and the
/// caller falls through to a fresh write; unreadable entries and names
/// that are not valid UTF-8 are skipped.
fn find_existing_for_hash(dir: &Path, hex: &str, ext: &str) -> Option<PathBuf> {
    // Built once; both are loop-invariant.
    let dot_ext = format!(".{ext}");
    let named_prefix = format!("{hex}__");

    std::fs::read_dir(dir)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .find(|path| {
            path.file_name()
                .and_then(|n| n.to_str())
                .and_then(|n| n.strip_suffix(dot_ext.as_str()))
                .is_some_and(|stem| stem == hex || stem.starts_with(named_prefix.as_str()))
        })
}

/// Refresh the mtime of `path` to "now" so it survives the next TTL
/// sweep.
///
/// The file is opened with write access before calling
/// `File::set_modified`, for portability: Windows needs write access on
/// the handle to change timestamps. Returns the I/O error from either the
/// open or the timestamp update.
fn touch_mtime(path: &Path) -> std::io::Result<()> {
    std::fs::OpenOptions::new()
        .write(true)
        .open(path)?
        .set_modified(std::time::SystemTime::now())
}
+pub fn sweep_old_image_tmpfiles(dir: &Path) { + let entries = match std::fs::read_dir(dir) { + Ok(e) => e, + Err(e) => { + debug!(dir = ?dir, error = %e, "image tmp sweep: read_dir failed (likely missing dir, fine)"); + return; + } + }; + let now = std::time::SystemTime::now(); + let ttl = std::time::Duration::from_secs(IMAGE_TMP_TTL_SECS); + let mut removed = 0u32; + for entry in entries.flatten() { + let path = entry.path(); + let Ok(meta) = entry.metadata() else { continue }; + if !meta.is_file() { + continue; + } + let Ok(modified) = meta.modified() else { continue }; + if let Ok(age) = now.duration_since(modified) { + if age > ttl { + if let Err(e) = std::fs::remove_file(&path) { + debug!(path = ?path, error = %e, "image tmp sweep: remove failed"); + } else { + removed += 1; + } + } + } + } + if removed > 0 { + info!(removed, "swept stale openfang image tmpfiles"); + } +} + +/// Spawn the once-per-process TTL sweep in a background thread. Safe to +/// call from any number of driver inits — the [`Once`] guard ensures only +/// the first call does work. +pub fn spawn_sweep_once() { + IMAGE_TMP_SWEEP_ONCE.call_once(|| { + let dir = image_tmp_dir(); + std::thread::spawn(move || sweep_old_image_tmpfiles(&dir)); + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use base64::Engine; + use std::time::{Duration, SystemTime}; + + /// A 1×1 transparent PNG, base64-encoded. Tiny enough to keep tests fast. + const TINY_PNG_B64: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="; + + #[test] + fn materialize_image_refreshes_mtime_on_cache_hit() { + let tmp = tempfile::tempdir().unwrap(); + let dir = tmp.path(); + + // First call materializes. + let path = materialize_image("image/png", TINY_PNG_B64, dir, None) + .expect("first materialization should succeed"); + assert!(path.exists()); + + // Backdate mtime to ~25 hours ago — past IMAGE_TMP_TTL_SECS. 
/// The sweep deletes a file whose mtime is older than the TTL.
#[test]
fn sweep_removes_stale_tmpfiles() {
    // Sanity check that the sweep actually GCs old files — pairs with
    // the mtime-refresh test to prove the refresh is what saves the file.
    let tmp = tempfile::tempdir().unwrap();
    let dir = tmp.path();
    let path = materialize_image("image/png", TINY_PNG_B64, dir, None).unwrap();

    // Backdate the file to one hour past the TTL.
    let stale = SystemTime::now() - Duration::from_secs(IMAGE_TMP_TTL_SECS + 3600);
    let f = std::fs::OpenOptions::new().write(true).open(&path).unwrap();
    f.set_modified(stale).unwrap();
    drop(f);

    sweep_old_image_tmpfiles(dir);
    assert!(!path.exists(), "stale tmpfile should have been swept");
}

/// Known MIME types map to their extension (case-insensitively); anything
/// else maps to `bin`.
#[test]
fn ext_for_mime_known_and_unknown() {
    assert_eq!(ext_for_mime("image/png"), "png");
    assert_eq!(ext_for_mime("IMAGE/JPEG"), "jpg");
    assert_eq!(ext_for_mime("image/webp"), "webp");
    assert_eq!(ext_for_mime("application/octet-stream"), "bin");
}

/// Undecodable base64 yields `None` rather than a panic or a file.
#[test]
fn materialize_image_rejects_invalid_base64() {
    let tmp = tempfile::tempdir().unwrap();
    assert!(
        materialize_image("image/png", "!!!not-base64!!!", tmp.path(), None).is_none()
    );
}

/// A supplied original name is sanitized and appended after the hash.
#[test]
fn materialize_image_appends_sanitized_name() {
    let tmp = tempfile::tempdir().unwrap();
    let dir = tmp.path();

    let path =
        materialize_image("image/png", TINY_PNG_B64, dir, Some("My Vacation Photo.PNG"))
            .expect("first materialization");
    let name = path.file_name().unwrap().to_str().unwrap();
    assert!(
        name.contains("__my_vacation_photo.png"),
        "expected sanitized name suffix, got {name}"
    );
    // Only the first 16 characters of the hash prefix are inspected here.
    let stem_hex: String = name.chars().take(16).collect();
    assert!(
        stem_hex.chars().all(|c| c.is_ascii_hexdigit()),
        "expected leading hex hash, got {stem_hex}"
    );
}

/// The cache lookup is keyed on the hash, not on the name suffix.
#[test]
fn materialize_image_cache_hit_finds_named_file() {
    // Same bytes materialized first WITH a name, then again with no
    // name: the second call must reuse the named file rather than
    // writing a duplicate `<hash>.png`.
    let tmp = tempfile::tempdir().unwrap();
    let dir = tmp.path();

    let first =
        materialize_image("image/png", TINY_PNG_B64, dir, Some("hello.png")).unwrap();
    let second = materialize_image("image/png", TINY_PNG_B64, dir, None).unwrap();
    assert_eq!(first, second, "cache lookup should find the named file");

    // Count regular files only; exactly one must exist.
    let count = std::fs::read_dir(dir)
        .unwrap()
        .filter(|e| {
            e.as_ref()
                .ok()
                .and_then(|e| e.metadata().ok())
                .map(|m| m.is_file())
                .unwrap_or(false)
        })
        .count();
    assert_eq!(count, 1, "no duplicate tmpfile on cache hit");
}

/// Table of representative inputs for `sanitize_for_filename`.
#[test]
fn sanitize_for_filename_basic_cases() {
    assert_eq!(
        sanitize_for_filename("Hello World.png").as_deref(),
        Some("hello_world")
    );
    assert_eq!(
        sanitize_for_filename("/etc/passwd").as_deref(),
        Some("passwd")
    );
    assert_eq!(
        sanitize_for_filename("foo___bar.txt").as_deref(),
        Some("foo_bar")
    );
    // All-punctuation/dot input → None.
    assert_eq!(sanitize_for_filename("...png").as_deref(), None);
    // Trailing extension is stripped.
    assert_eq!(
        sanitize_for_filename("smoke-test.pdf").as_deref(),
        Some("smoke-test")
    );
    // Length cap at 60.
    let long = "a".repeat(200);
    let result = sanitize_for_filename(&format!("{long}.png")).unwrap();
    assert_eq!(result.len(), 60);
    // Non-ASCII bytes collapse to a single `_`.
    let s = sanitize_for_filename("café.jpg").unwrap();
    assert!(s.starts_with("caf"), "got {s}");
}

// Force-reference base64 engine to keep imports tidy in case someone
// refactors and the const is the only consumer.
#[allow(dead_code)]
fn _b64_compile_check() {
    let _ = base64::engine::general_purpose::STANDARD.decode(TINY_PNG_B64);
}