diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c8d2428..70a3a173 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +### Fixed + +- Fix `slumber generate curl` output for multipart forms with file fields + ## [4.0.1] - 2025-09-14 diff --git a/Cargo.lock b/Cargo.lock index cdafc4d4..78f92c70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1412,13 +1412,16 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710" dependencies = [ + "base64 0.22.1", "bytes", "futures-channel", "futures-util", "http", "http-body", "hyper", + "ipnet", "libc", + "percent-encoding", "pin-project-lite", "socket2", "tokio", @@ -1638,6 +1641,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_executable" version = "1.0.4" @@ -2831,9 +2844,8 @@ checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.12.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" +version = "0.12.23" +source = "git+https://github.com/LucasPickering/reqwest?branch=2374-custom-boundary#ab1b6a5223c4da5c8d0b63207d0255a6e97db9f5" dependencies = [ "base64 0.22.1", "bytes", @@ -2845,18 +2857,14 @@ dependencies = [ "hyper", "hyper-rustls", "hyper-util", - "ipnet", "js-sys", "log", - "mime", "mime_guess", - "once_cell", "percent-encoding", "pin-project-lite", "quinn", "rustls", "rustls-native-certs", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", @@ -2864,14 +2872,16 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower 0.5.2", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", - "webpki-roots 0.26.11", - "windows-registry", + "webpki-roots 1.0.0", ] [[package]] @@ -3015,15 +3025,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" version = "1.12.0" @@ -3477,7 +3478,6 @@ dependencies = [ "pretty_assertions", "proptest", "proptest-derive", - "regex", "reqwest", "rstest", "rusqlite", @@ -4116,6 +4116,24 @@ dependencies = [ "tower-service", ] +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags 2.9.1", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -4473,6 +4491,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -4624,7 +4655,7 @@ dependencies = [ "windows-interface", "windows-link", "windows-result", - "windows-strings 0.4.2", + "windows-strings", ] [[package]] @@ -4655,17 +4686,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" -[[package]] -name = "windows-registry" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" -dependencies = [ - "windows-result", - "windows-strings 0.3.1", - "windows-targets 0.53.0", -] - [[package]] name = "windows-result" version = "0.3.4" @@ -4675,15 +4695,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-strings" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" -dependencies = [ - "windows-link", -] - [[package]] name = "windows-strings" version = "0.4.2" @@ -4744,29 +4755,13 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", + "windows_i686_gnullvm", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] -[[package]] -name = "windows-targets" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" -dependencies = [ - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", -] - [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -4779,12 +4774,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -4797,12 +4786,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" - [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -4815,24 +4798,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" -[[package]] -name = "windows_i686_gnu" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" - [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" - [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -4845,12 +4816,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_i686_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -4863,12 +4828,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" - [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -4881,12 +4840,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -4899,12 +4852,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" - [[package]] name = "winnow" version = "0.7.12" diff --git a/Cargo.toml b/Cargo.toml index 9b8cd116..366f4822 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,10 @@ indexmap = {version = "2.0.0", default-features = false} itertools = "0.13.0" mime = "0.3.17" pretty_assertions = "1.4.0" -reqwest = {version = "0.12.5", default-features = false} +regex = {version = "1.10.5", default-features = false} +# TODO use a stable version +# reqwest = {version = "0.12.23", default-features = false} +reqwest = {git = "https://github.com/LucasPickering/reqwest", branch = "2374-custom-boundary", default-features = false} rstest = {version = "0.24.0", default-features = false} saphyr = "0.0.6" schemars = "1.0.2" diff --git a/crates/cli/src/completions.rs b/crates/cli/src/completions.rs index 5270abab..5af54e5b 100644 --- a/crates/cli/src/completions.rs +++ b/crates/cli/src/completions.rs @@ -161,7 +161,21 @@ mod tests { #[rstest] fn test_complete_recipe(_current_dir: CurrentDirGuard) { let completions = complete(complete_recipe()); - assert_eq!(&completions, &["getUser", "jsonBody", "chained"]); + assert_eq!( + &completions, + &[ + "getUser", + "query", + "headers", + "authBasic", + "authBearer", + "textBody", + "jsonBody", + "fileBody", + "multipart", + "chained" + ] + ); } /// Complete request IDs from the database diff --git a/crates/cli/tests/slumber.yml b/crates/cli/tests/slumber.yml index 4f840e1f..171124cc 100644 --- a/crates/cli/tests/slumber.yml +++ b/crates/cli/tests/slumber.yml @@ -20,6 +20,39 @@ requests: method: GET url: "{{ host }}/users/{{ username }}" + query: + method: GET + url: "{{ host }}/query" + query: + a: 1 + b: [2, 3] + + headers: + method: GET + url: "{{ host }}/headers" + headers: + Content-Type: text/plain + + authBasic: + method: GET + url: "{{ host }}/headers" + authentication: + type: basic + username: "{{ username }}" + password: hunter2 + + authBearer: + method: GET + url: "{{ host }}/headers" + authentication: + type: bearer + token: my-token + + textBody: + method: POST + url: "{{ host }}/text" + body: "This is an HTTP body" + jsonBody: method: POST url: "{{ host }}/json" @@ -27,6 +60,20 @@ requests: type: json data: { "username": "{{ username }}", "name": "Frederick Smidgen" } + fileBody: + method: POST + url: "{{ host }}/file" + body: "{{ file('test_data/data.json') }}" + + multipart: + method: POST + url: "{{ host }}/multipart" + body: + type: form_multipart + data: + username: "{{ username }}" + file: "{{ file('test_data/data.json') }}" + chained: method: GET url: "{{ host }}/chained/{{ response('getUser', trigger='always') | jsonpath('$.username') }}" diff --git a/crates/cli/tests/test_generate.rs b/crates/cli/tests/test_generate.rs index aeb0b12a..035e8d16 100644 --- a/crates/cli/tests/test_generate.rs +++ b/crates/cli/tests/test_generate.rs @@ -2,44 +2,35 @@ mod common; +use rstest::rstest; use serde_json::json; use slumber_core::database::Database; use wiremock::{Mock, MockServer, ResponseTemplate, matchers}; -/// Test generating a curl command with: -/// - URL -/// - Query params -/// - Headers -#[test] -fn test_generate_curl() { - let (mut command, _) = common::slumber(); - command.args(["generate", "curl", "getUser"]); - command - .assert() - .success() - .stdout("curl -XGET --url 'http://server/users/username1'\n"); -} - -/// Make sure the profile option is reflected correctly -#[test] -fn test_generate_curl_profile() { - let (mut command, _) = common::slumber(); - command.args(["generate", "curl", "getUser", "-p", "profile2"]); - command - .assert() - .success() - .stdout("curl -XGET --url 'http://server/users/username2'\n"); -} - -/// Make sure field overrides are applied correctly -#[test] -fn test_generate_curl_override() { +/// Test generating a curl command with different flags. Most of the request +/// components are tested in unit tests in the core crate, so we just need to +/// test CLI behavior here. +#[rstest] +#[case::url( + &["getUser"], + "curl -XGET --url 'http://server/users/username1'\n", +)] +#[case::profile( + &["getUser", "-p", "profile2"], + "curl -XGET --url 'http://server/users/username2'\n", +)] +#[case::overrides( + &["getUser", "-o", "username=username3"], + "curl -XGET --url 'http://server/users/username3'\n", +)] +fn test_generate_curl( + #[case] arguments: &[&str], + #[case] expected: &'static str, +) { let (mut command, _) = common::slumber(); - command.args(["generate", "curl", "getUser", "-o", "username=username3"]); - command - .assert() - .success() - .stdout("curl -XGET --url 'http://server/users/username3'\n"); + command.args(["generate", "curl"]); + command.args(arguments); + command.assert().success().stdout(expected); } /// Test failure when a downstream request is needed but cannot be triggered diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index eef501a6..bf6d84ac 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -22,8 +22,7 @@ futures = {workspace = true} indexmap = {workspace = true, features = ["serde"]} itertools = {workspace = true} mime = {workspace = true} -regex = {version = "1.10.5", default-features = false} -reqwest = {workspace = true, features = ["json", "multipart", "rustls-tls", "rustls-tls-native-roots", "rustls-tls-native-roots-no-provider"]} +reqwest = {workspace = true, features = ["json", "multipart", "rustls-tls", "rustls-tls-native-roots", "rustls-tls-native-roots-no-provider", "stream"]} rstest = {workspace = true, optional = true} rusqlite = {version = "0.35.0", default-features = false, features = ["bundled", "chrono", "uuid"]} rusqlite_migration = "2.1.0" diff --git a/crates/core/src/http.rs b/crates/core/src/http.rs index f93cff30..ac066293 100644 --- a/crates/core/src/http.rs +++ b/crates/core/src/http.rs @@ -63,9 +63,9 @@ use reqwest::{ redirect, }; use slumber_config::HttpEngineConfig; -use slumber_template::Template; +use slumber_template::{Stream, StreamMetadata, Template}; use slumber_util::ResultTraced; -use std::{collections::HashSet, error::Error}; +use std::{collections::HashSet, error::Error, path::PathBuf}; use tracing::{error, info, info_span}; const USER_AGENT: &str = concat!("slumber/", env!("CARGO_PKG_VERSION")); @@ -134,7 +134,7 @@ impl HttpEngine { pub async fn build( &self, seed: RequestSeed, - template_context: &TemplateContext, + context: &TemplateContext, ) -> Result { let RequestSeed { id, @@ -146,18 +146,23 @@ impl HttpEngine { .entered(); let future = async { - let recipe = template_context - .collection - .recipes - .try_get_recipe(recipe_id)?; + let recipe = + context.collection.recipes.try_get_recipe(recipe_id)?; // Render everything up front so we can parallelize it let (url, query, headers, authentication, body) = try_join!( - recipe.render_url(template_context), - recipe.render_query(options, template_context), - recipe.render_headers(options, template_context), - recipe.render_authentication(options, template_context), - recipe.render_body(options, template_context), + recipe.render_url(context), + recipe.render_query(options, context), + recipe.render_headers(options, context), + recipe.render_authentication(options, context), + // Body *has* to go last. Bodies are the only component that + // can be streamed. If a profile field is present in both the + // body and elsewhere, it should *never* be streamed. By + // starting every other component first, we ensure the body + // will never be the one to initiate the render for a multi-use + // profile field, meaning it won't get to render as a stream. + // This is kinda fragile but it's also a rare use case. + recipe.render_body(options, context), )?; // Build the reqwest request first, so we can have it do all the @@ -168,7 +173,7 @@ impl HttpEngine { let mut builder = client.request(recipe.method.into(), url).query(&query); if let Some(body) = body { - builder = body.apply(builder); + builder = body.apply(builder).await?; } // Set headers *after* body so the use can override the Content-Type // header that was set if they want to @@ -180,13 +185,12 @@ impl HttpEngine { let request = builder.build()?; Ok((client, request)) }; - let (client, request) = - seed.run_future(future, template_context).await?; + let (client, request) = seed.run_future(future, context).await?; Ok(RequestTicket { record: RequestRecord::new( seed, - template_context.selected_profile.clone(), + context.selected_profile.clone(), &request, self.large_body_size, ) @@ -200,7 +204,7 @@ impl HttpEngine { pub async fn build_url( &self, seed: RequestSeed, - template_context: &TemplateContext, + context: &TemplateContext, ) -> Result { let RequestSeed { id, @@ -212,15 +216,13 @@ impl HttpEngine { .entered(); let future = async { - let recipe = template_context - .collection - .recipes - .try_get_recipe(recipe_id)?; + let recipe = + context.collection.recipes.try_get_recipe(recipe_id)?; // Parallelization! let (url, query) = try_join!( - recipe.render_url(template_context), - recipe.render_query(options, template_context), + recipe.render_url(context), + recipe.render_query(options, context), )?; // Use RequestBuilder so we can offload the handling of query params @@ -231,7 +233,7 @@ impl HttpEngine { .build()?; Ok(request) }; - let request = seed.run_future(future, template_context).await?; + let request = seed.run_future(future, context).await?; Ok(request.url().clone()) } @@ -240,7 +242,7 @@ impl HttpEngine { pub async fn build_body( &self, seed: RequestSeed, - template_context: &TemplateContext, + context: &TemplateContext, ) -> Result, RequestBuildError> { let RequestSeed { id, @@ -252,14 +254,10 @@ impl HttpEngine { .entered(); let future = async { - let recipe = template_context - .collection - .recipes - .try_get_recipe(recipe_id)?; - - let Some(body) = - recipe.render_body(options, template_context).await? - else { + let recipe = + context.collection.recipes.try_get_recipe(recipe_id)?; + + let Some(body) = recipe.render_body(options, context).await? else { return Ok(None); }; @@ -275,7 +273,7 @@ impl HttpEngine { let url = Url::parse("http://localhost").unwrap(); let client = self.get_client(&url); let mut builder = client.request(reqwest::Method::GET, url); - builder = body.apply(builder); + builder = body.apply(builder).await?; let request = builder.build()?; // We just added a body so we know it's present, and we // know it's not a stream. This requires a clone which sucks @@ -292,7 +290,7 @@ impl HttpEngine { } } }; - seed.run_future(future, template_context).await + seed.run_future(future, context).await } /// Render a recipe into a cURL command that will execute the request. @@ -304,7 +302,7 @@ impl HttpEngine { pub async fn build_curl( &self, seed: RequestSeed, - template_context: &TemplateContext, + context: &TemplateContext, ) -> Result { let RequestSeed { id, @@ -316,18 +314,16 @@ impl HttpEngine { .entered(); let future = async { - let recipe = template_context - .collection - .recipes - .try_get_recipe(recipe_id)?; + let recipe = + context.collection.recipes.try_get_recipe(recipe_id)?; // Render everything up front so we can parallelize it let (url, query, headers, authentication, body) = try_join!( - recipe.render_url(template_context), - recipe.render_query(options, template_context), - recipe.render_headers(options, template_context), - recipe.render_authentication(options, template_context), - recipe.render_body(options, template_context), + recipe.render_url(context), + recipe.render_query(options, context), + recipe.render_headers(options, context), + recipe.render_authentication(options, context), + recipe.render_body(options, context), )?; // Buidl the command @@ -342,7 +338,7 @@ impl HttpEngine { } Ok(builder.build()) }; - seed.run_future(future, template_context).await + seed.run_future(future, context).await } /// Get the appropriate client to use for this request. If the request URL's @@ -368,11 +364,11 @@ impl RequestSeed { async fn run_future( &self, future: impl Future>, - template_context: &TemplateContext, + context: &TemplateContext, ) -> Result { let start_time = Utc::now(); future.await.traced().map_err(|error| RequestBuildError { - profile_id: template_context.selected_profile.clone(), + profile_id: context.selected_profile.clone(), recipe_id: self.recipe_id.clone(), id: self.id, start_time, @@ -466,11 +462,11 @@ impl Recipe { /// Render base URL, *excluding* query params async fn render_url( &self, - template_context: &TemplateContext, + context: &TemplateContext, ) -> anyhow::Result { let url = self .url - .render_string(template_context) + .render_string(context) .await .context("Rendering URL")?; url.parse::() @@ -481,7 +477,7 @@ impl Recipe { async fn render_query( &self, options: &BuildOptions, - template_context: &TemplateContext, + context: &TemplateContext, ) -> anyhow::Result> { let iter = self.query_iter().enumerate().filter_map(|(i, (k, _, v))| { @@ -492,12 +488,9 @@ impl Recipe { Some(async move { Ok::<_, anyhow::Error>(( k.to_owned(), - template - .render_string(template_context) - .await - .context(format!( - "Rendering query parameter `{k}`" - ))?, + template.render_string(context).await.context( + format!("Rendering query parameter `{k}`"), + )?, )) }) }); @@ -509,7 +502,7 @@ impl Recipe { async fn render_headers( &self, options: &BuildOptions, - template_context: &TemplateContext, + context: &TemplateContext, ) -> anyhow::Result { let mut headers = HeaderMap::new(); @@ -521,7 +514,7 @@ impl Recipe { let template = options.headers.get(i, value_template)?; Some(async move { - self.render_header(template_context, header, template).await + self.render_header(context, header, template).await }) }, ); @@ -540,12 +533,12 @@ impl Recipe { /// Render a single key/value header async fn render_header( &self, - template_context: &TemplateContext, + context: &TemplateContext, header: &str, value_template: &Template, ) -> anyhow::Result<(HeaderName, HeaderValue)> { let mut value: Vec = value_template - .render_bytes(template_context) + .render_bytes(context) .await .context(format!("Rendering header `{header}`"))? .into(); @@ -574,7 +567,7 @@ impl Recipe { async fn render_authentication( &self, options: &BuildOptions, - template_context: &TemplateContext, + context: &TemplateContext, ) -> anyhow::Result>> { let authentication = options .authentication @@ -582,28 +575,29 @@ impl Recipe { .or(self.authentication.as_ref()); match authentication { Some(Authentication::Basic { username, password }) => { - let (username, password) = try_join!( - async { - username - .render_string(template_context) + let (username, password) = + try_join!( + async { + username + .render_string(context) + .await + .context("Rendering username") + }, + async { + OptionFuture::from(password.as_ref().map( + |password| password.render_string(context), + )) .await - .context("Rendering username") - }, - async { - OptionFuture::from(password.as_ref().map(|password| { - password.render_string(template_context) - })) - .await - .transpose() - .context("Rendering password") - }, - )?; + .transpose() + .context("Rendering password") + }, + )?; Ok(Some(Authentication::Basic { username, password })) } Some(Authentication::Bearer { token }) => { let token = token - .render_string(template_context) + .render_string(context) .await .context("Rendering bearer token")?; Ok(Some(Authentication::Bearer { token })) @@ -616,7 +610,7 @@ impl Recipe { async fn render_body( &self, options: &BuildOptions, - template_context: &TemplateContext, + context: &TemplateContext, ) -> anyhow::Result> { let Some(body) = options.body.as_ref().or(self.body.as_ref()) else { return Ok(None); @@ -624,14 +618,10 @@ impl Recipe { let rendered = match body { RecipeBody::Raw(body) => RenderedBody::Raw( - body.render_bytes(template_context) - .await - .context("Rendering body")?, + body.render_bytes(context).await.context("Rendering body")?, ), RecipeBody::Json(json) => RenderedBody::Json( - json.render(template_context) - .await - .context("Rendering body")?, + json.render(context).await.context("Rendering body")?, ), RecipeBody::FormUrlencoded(fields) => { let iter = fields.iter().enumerate().filter_map( @@ -639,12 +629,10 @@ impl Recipe { let template = options.form_fields.get(i, value_template)?; Some(async move { - let value = template - .render_string(template_context) - .await - .context(format!( - "Rendering form field `{field}`" - ))?; + let value = + template.render_string(context).await.context( + format!("Rendering form field `{field}`"), + )?; Ok::<_, anyhow::Error>((field.clone(), value)) }) }, @@ -658,14 +646,13 @@ impl Recipe { let template = options.form_fields.get(i, value_template)?; Some(async move { - let value = template - .render_bytes(template_context) - .await - .context(format!( - "Rendering form field `{field}`" - ))? - .into(); - Ok::<_, anyhow::Error>((field.clone(), value)) + let value = + template.render_stream(context).await.context( + format!("Rendering form field `{field}`"), + )?; + + let part = Self::stream(value); + Ok::<_, anyhow::Error>((field.clone(), part)) }) }, ); @@ -675,6 +662,18 @@ impl Recipe { }; Ok(Some(rendered)) } + + /// Convert a template stream to a multipart form part + fn stream(stream: Stream) -> FormPart { + match stream { + Stream::Value(value) => FormPart::Bytes(value.into_bytes()), + // If the stream is a file, we can pass that directly to reqwest + Stream::Stream { + metadata: StreamMetadata::File { path }, + .. + } => FormPart::File(path), + } + } } impl Authentication { @@ -699,23 +698,56 @@ enum RenderedBody { /// URL-encoded FormUrlencoded(Vec<(String, String)>), /// Field:value mapping. Values can be arbitrary bytes - FormMultipart(Vec<(String, Vec)>), + FormMultipart(Vec<(String, FormPart)>), } impl RenderedBody { - fn apply(self, builder: RequestBuilder) -> RequestBuilder { + /// Add this body to the builder + async fn apply( + self, + builder: RequestBuilder, + ) -> anyhow::Result { // Set body. The variant tells us _how_ to set it match self { - RenderedBody::Raw(bytes) => builder.body(bytes), - RenderedBody::Json(json) => builder.json(&json), - RenderedBody::FormUrlencoded(fields) => builder.form(&fields), + RenderedBody::Raw(bytes) => Ok(builder.body(bytes)), + RenderedBody::Json(json) => Ok(builder.json(&json)), + RenderedBody::FormUrlencoded(fields) => Ok(builder.form(&fields)), RenderedBody::FormMultipart(fields) => { let mut form = Form::new(); - for (field, value) in fields { - let part = Part::bytes(value); - form = form.part(field, part); + + // Use a static boundary in tests for assertions. Test-only + // code can be dangerous, but in non-test we're just using the + // default library behavior. There's also plenty of tests in + // other crates that hit this code path, and cfg(test) won't + // be enabled for those. + if cfg!(test) { + form.set_boundary("BOUNDARY"); } - builder.multipart(form) + + for (field, part) in fields { + form = form.part(field, part.into_reqwest().await?); + } + Ok(builder.multipart(form)) + } + } + } +} + +/// Form field value for a multipart form +#[derive(Debug)] +pub enum FormPart { + /// Data will be raw bytes + Bytes(Bytes), + /// Data will be streamed from a file. The path should be absolute + File(PathBuf), +} + +impl FormPart { + async fn into_reqwest(self) -> anyhow::Result { + match self { + Self::Bytes(bytes) => Ok(Part::bytes(>::from(bytes))), + Self::File(path) => { + Part::file(path).await.map_err(anyhow::Error::from) } } } diff --git a/crates/core/src/http/curl.rs b/crates/core/src/http/curl.rs index 82434c7a..cac346b0 100644 --- a/crates/core/src/http/curl.rs +++ b/crates/core/src/http/curl.rs @@ -1,6 +1,6 @@ use crate::{ collection::Authentication, - http::{HttpMethod, RenderedBody}, + http::{FormPart, HttpMethod, RenderedBody}, }; use anyhow::Context; use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue}; @@ -100,8 +100,15 @@ impl CurlBuilder { } } RenderedBody::FormMultipart(form) => { - for (field, value) in form { - let value = as_text(value)?; + for (field, part) in form { + let value = match part { + FormPart::Bytes(bytes) => as_text(bytes)?, + // Use curl's file path syntax + FormPart::File(path) => { + let path = path.to_string_lossy(); + &format!("@{path}") + } + }; write!(&mut self.command, " -F '{field}={value}'").unwrap(); } } diff --git a/crates/core/src/http/models.rs b/crates/core/src/http/models.rs index de1dc9b8..6948df71 100644 --- a/crates/core/src/http/models.rs +++ b/crates/core/src/http/models.rs @@ -435,6 +435,8 @@ pub struct RequestRecord { pub headers: HeaderMap, /// Body content as bytes. This should be decoded as needed. This will /// **not** be populated for bodies that are above the "large" threshold. + /// - `Some(empty bytes)`: There was no body (e.g. GET request) + /// - `None`: Body couldn't be stored (stream or too large) pub body: Option, } diff --git a/crates/core/src/http/tests.rs b/crates/core/src/http/tests.rs index d236d90b..7db925ea 100644 --- a/crates/core/src/http/tests.rs +++ b/crates/core/src/http/tests.rs @@ -7,19 +7,18 @@ use crate::{ }; use indexmap::{IndexMap, indexmap}; use pretty_assertions::assert_eq; -use regex::Regex; use reqwest::{Body, StatusCode, header}; use rstest::rstest; use serde_json::json; use slumber_util::{Factory, assert_err, test_data_dir}; -use std::ptr; +use std::{path, ptr}; use wiremock::{Mock, MockServer, ResponseTemplate, matchers}; /// Create a template context. Take a set of extra recipes to add to the created /// collection -fn template_context(recipe: Recipe) -> TemplateContext { +fn template_context(recipe: Recipe, host: Option<&str>) -> TemplateContext { let profile_data = indexmap! { - "host".into() => "http://localhost".into(), + "host".into() => host.unwrap_or("http://localhost").into(), "mode".into() => "sudo".into(), "user_id".into() => "1".into(), "group_id".into() => "3".into(), @@ -28,6 +27,9 @@ fn template_context(recipe: Recipe) -> TemplateContext { "token".into() => "tokenzzz".into(), "test_data_dir".into() => test_data_dir().to_str().unwrap().into(), "prompt".into() => "{{ prompt() }}".into(), + "stream".into() => "{{ file('data.json') }}".into(), + // Streamed value that we can use to test deduping + "stream_prompt".into() => "{{ file(concat([prompt(), '.txt'])) }}".into(), "error".into() => "{{ fake_fn() }}".into(), }; let profile = Profile { @@ -36,6 +38,7 @@ fn template_context(recipe: Recipe) -> TemplateContext { }; TemplateContext { prompter: Box::new(TestPrompter::new(["first", "second"])), + root_dir: test_data_dir(), ..TemplateContext::factory((by_id([profile]), by_id([recipe]))) } } @@ -91,7 +94,7 @@ async fn test_build_request(http_engine: HttpEngine) { ..Recipe::factory(()) }; let recipe_id = recipe.id.clone(); - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed(&context, BuildOptions::default()); let ticket = http_engine.build(seed, &context).await.unwrap(); @@ -144,7 +147,7 @@ async fn test_build_url(http_engine: HttpEngine) { }, ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed(&context, BuildOptions::default()); let url = http_engine.build_url(seed, &context).await.unwrap(); @@ -171,10 +174,13 @@ async fn test_build_body( #[case] body: RecipeBody, #[case] expected_body: &[u8], ) { - let context = template_context(Recipe { - body: Some(body), - ..Recipe::factory(()) - }); + let context = template_context( + Recipe { + body: Some(body), + ..Recipe::factory(()) + }, + None, + ); let seed = seed(&context, BuildOptions::default()); let body = http_engine.build_body(seed, &context).await.unwrap(); @@ -213,7 +219,7 @@ async fn test_authentication( ..Recipe::factory(()) }; let recipe_id = recipe.id.clone(); - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed(&context, BuildOptions::default()); let ticket = http_engine.build(seed, &context).await.unwrap(); @@ -245,36 +251,32 @@ async fn test_authentication( #[case::json( RecipeBody::json(json!({"group_id": "{{ group_id }}"})).unwrap(), None, - Some(r#"{"group_id":"3"}"#), - "^application/json$", - &[], + "application/json", + r#"{"group_id":"3"}"#, )] // Content-Type has been overridden by an explicit header #[case::json_content_type_override( RecipeBody::json(json!({"group_id": "{{ group_id }}"})).unwrap(), Some("text/plain"), - Some(r#"{"group_id":"3"}"#), - "^text/plain$", - &[], + "text/plain", + r#"{"group_id":"3"}"#, )] #[case::json_unpack( // Single-chunk templates should get unpacked to the actual JSON value // instead of returned as a string RecipeBody::json(json!("{{ [1,2,3] }}")).unwrap(), None, - Some("[1,2,3]"), - "^application/json$", - &[], + "application/json", + "[1,2,3]", )] #[case::json_no_unpack( // This template doesn't get unpacked because it is multiple chunks RecipeBody::json(json!("no: {{ [1,2,3] }}")).unwrap(), None, + "application/json", // Spaces are added because this uses the template Value stringification // instead of serde_json stringification - Some(r#""no: [1, 2, 3]""#), - "^application/json$", - &[], + r#""no: [1, 2, 3]""#, )] #[case::json_string_from_file( // JSON data is loaded as a string and NOT unpacked. file() returns bytes @@ -283,9 +285,8 @@ async fn test_authentication( "{{ file(concat([test_data_dir, '/data.json'])) | trim() }}" )).unwrap(), None, - Some(r#""{ \"a\": 1, \"b\": 2 }""#), - "^application/json$", - &[], + "application/json", + r#""{ \"a\": 1, \"b\": 2 }""#, )] #[case::json_from_file_parsed( // Pipe to json_parse() to parse it @@ -293,9 +294,8 @@ async fn test_authentication( "{{ file(concat([test_data_dir, '/data.json'])) | json_parse() }}" )).unwrap(), None, - Some(r#"{"a":1,"b":2}"#), - "^application/json$", - &[], + "application/json", + r#"{"a":1,"b":2}"#, )] #[case::form_urlencoded( RecipeBody::FormUrlencoded(indexmap! { @@ -303,101 +303,125 @@ async fn test_authentication( "token".into() => "{{ token }}".into() }), None, - Some("user_id=1&token=tokenzzz"), - "^application/x-www-form-urlencoded$", - &[], + "application/x-www-form-urlencoded", + "user_id=1&token=tokenzzz", )] // reqwest sets the content type when initializing the body, so make sure // that doesn't override the user's value #[case::form_urlencoded_content_type_override( RecipeBody::FormUrlencoded(Default::default()), Some("text/plain"), - Some(""), - "^text/plain$", - &[], + "text/plain", + "" )] #[case::form_multipart( RecipeBody::FormMultipart(indexmap! { "user_id".into() => "{{ user_id }}".into(), - "binary".into() => invalid_utf8(), }), None, - // multipart bodies are automatically turned into streams by reqwest, - // and we don't store stream bodies atm - // https://github.com/LucasPickering/slumber/issues/256 + // Normally the boundary is random, but we make it static for testing + "multipart/form-data; boundary=BOUNDARY", + "--BOUNDARY\r +Content-Disposition: form-data; name=\"user_id\"\r +\r +1\r +--BOUNDARY--\r +", +)] +#[case::form_multipart_file( + RecipeBody::FormMultipart(indexmap! { + "file".into() => + "{{ file(concat([test_data_dir, '/data.json'])) }}".into(), + }), + None, + "multipart/form-data; boundary=BOUNDARY", + "--BOUNDARY\r +Content-Disposition: form-data; name=\"file\"; filename=\"data.json\"\r +Content-Type: application/json\r +\r +{ \"a\": 1, \"b\": 2 }\r +--BOUNDARY--\r +", +)] +#[case::form_multipart_file_not_streamed( + RecipeBody::FormMultipart(indexmap! { + // This file does *not* get streamed because it's not a single-chunk + // template + "file".into() => + "data: {{ file(concat([test_data_dir, '/data.json'])) }}".into(), + }), None, - "^multipart/form-data; boundary=[a-f0-9-]{67}$", - &[("content-length", "321")], + "multipart/form-data; boundary=BOUNDARY", + "--BOUNDARY\r +Content-Disposition: form-data; name=\"file\"\r +\r +data: { \"a\": 1, \"b\": 2 }\r +--BOUNDARY--\r +", )] #[tokio::test] async fn test_structured_body( http_engine: HttpEngine, #[case] body: RecipeBody, #[case] content_type: Option<&str>, - #[case] expected_body: Option<&'static str>, - // For multipart bodies, the content type includes random content - #[case] expected_content_type: Regex, - #[case] extra_headers: &[(&str, &str)], + #[case] expected_content_type: &str, + #[case] expected_body: &'static str, ) { + // We're going to actually send the request so we can get the full body. + // Reqwest doesn't expose the body for multipart requests because it may be + // streamed + let server = MockServer::start().await; + Mock::given(matchers::method("POST")) + .and(matchers::path("/post")) + .respond_with(move |request: &wiremock::Request| { + // Echo back the Content-Type and body so we can assert on it + ResponseTemplate::new(StatusCode::OK) + .append_header( + header::CONTENT_TYPE, + request + .headers + .get(header::CONTENT_TYPE) + .expect("Missing Content-Type header"), + ) + .set_body_bytes(request.body.clone()) + }) + .mount(&server) + .await; + let headers = if let Some(content_type) = content_type { indexmap! {"content-type".into() => content_type.into()} } else { IndexMap::default() }; let recipe = Recipe { + method: HttpMethod::Post, + url: "{{ host }}/post".into(), headers, body: Some(body), ..Recipe::factory(()) }; - let recipe_id = recipe.id.clone(); - let context = template_context(recipe); + let context = template_context(recipe, Some(&server.uri())); let seed = seed(&context, BuildOptions::default()); let ticket = http_engine.build(seed, &context).await.unwrap(); + let exchange = ticket.send().await.unwrap(); - // Assert on the actual built request *and* the record, to make sure - // they're consistent with each other - let actual_content_type = ticket - .request - .headers() + // Mocker echoes the Content-Type header and body, assert on them + assert_eq!(exchange.response.status, StatusCode::OK); + let actual_content_type = exchange + .response + .headers .get(header::CONTENT_TYPE) .expect("Missing Content-Type header") .to_str() .expect("Invalid Content-Type header"); - assert!( - expected_content_type.is_match(actual_content_type), - "Expected Content-Type `{actual_content_type}` \ - to match `{expected_content_type}`" - ); - assert_eq!( - ticket - .request - .body() - .and_then(Body::as_bytes) - // We know all the bodies are UTF-8. This gives better errors - .map(|bytes| std::str::from_utf8(bytes).unwrap()), - expected_body - ); - - assert_eq!( - *ticket.record, - RequestRecord { - id: ticket.record.id, - body: expected_body.map(Bytes::from), - // Use the actual content type here, because the expected - // content type maybe be a pattern and we need an exactl string. - // We checked actual=expected above so this is fine - headers: header_map( - [("content-type", actual_content_type)] - .into_iter() - .chain(extra_headers.iter().copied()) - ), - ..RequestRecord::factory(( - Some(context.collection.first_profile_id().clone()), - recipe_id - )) - } - ); + assert_eq!(actual_content_type, expected_content_type); + if let Some(body) = exchange.response.body.text() { + assert_eq!(body, expected_body); + } else { + // We expect all bodies to be text + panic!("Non UTF-8 body: {:?}", exchange.response.body.bytes()); + } } /// Test overriding authentication in BuildOptions @@ -411,7 +435,7 @@ async fn test_override_authentication(http_engine: HttpEngine) { }), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed( &context, @@ -451,7 +475,7 @@ async fn test_override_headers(http_engine: HttpEngine) { }, ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed( &context, BuildOptions { @@ -495,7 +519,7 @@ async fn test_override_query(http_engine: HttpEngine) { }, ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed( &context, BuildOptions { @@ -525,7 +549,7 @@ async fn test_override_body_raw(http_engine: HttpEngine) { body: Some(RecipeBody::Raw("{{ username }}".into())), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed( &context, BuildOptions { @@ -555,7 +579,7 @@ async fn test_override_body_json(http_engine: HttpEngine) { ), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed( &context, @@ -594,7 +618,7 @@ async fn test_override_body_form(http_engine: HttpEngine) { ..Recipe::factory(()) }; let recipe_id = recipe.id.clone(); - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed( &context, @@ -631,28 +655,88 @@ async fn test_override_body_form(http_engine: HttpEngine) { /// Using the same profile field in two different templates should be /// deduplicated, so that the expression is only evaluated once #[rstest] +#[case::url_body( + // Dedupe happens within a single template AND across templates + "{{ host }}/{{ prompt }}/{{ prompt }}", + "{{ prompt }}".into(), + "first", +)] +#[case::url_multipart_body( + "{{ host }}/{{ stream_prompt }}/{{ stream_prompt }}", + // The body should *not* be streamed because is cached from the URL. This + // works by rendering the body last + RecipeBody::FormMultipart(indexmap!{ + "file".into() => "{{ stream_prompt }}".into(), + }), + "--BOUNDARY\r +Content-Disposition: form-data; name=\"file\"\r +\r +first\r +--BOUNDARY--\r +", +)] +#[case::multipart_body_multiple( + "{{ host }}/first/first", + // Field is used twice in the same body. The stream source gets cloned so + // they reference the same file but are both streamed + RecipeBody::FormMultipart(indexmap!{ + "f1".into() => "{{ stream_prompt }}".into(), + "f2".into() => "{{ stream_prompt }}".into(), + }), + "--BOUNDARY\r +Content-Disposition: form-data; name=\"f1\"; filename=\"first.txt\"\r +Content-Type: text/plain\r +\r +first\r +--BOUNDARY\r +Content-Disposition: form-data; name=\"f2\"; filename=\"first.txt\"\r +Content-Type: text/plain\r +\r +first\r +--BOUNDARY--\r +", +)] #[tokio::test] -async fn test_profile_duplicate(http_engine: HttpEngine) { +async fn test_profile_duplicate( + http_engine: HttpEngine, + #[case] url: Template, + #[case] body: RecipeBody, + #[case] expected_body: &str, +) { + // We're going to actually send the request so we can get the full body. + // Reqwest doesn't expose the body for multipart requests because it may be + // streamed + let server = MockServer::start().await; + let host = server.uri(); + Mock::given(matchers::method("POST")) + .and(matchers::path("/first/first")) + .respond_with(|request: &wiremock::Request| { + ResponseTemplate::new(StatusCode::OK) + .set_body_bytes(request.body.clone()) + }) + .mount(&server) + .await; + let recipe = Recipe { method: HttpMethod::Post, - url: "{{ host }}/{{ prompt }}/{{ prompt }}".into(), - body: Some("{{ prompt }}".into()), + url, + body: Some(body), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, Some(&host)); let seed = seed(&context, BuildOptions::default()); let ticket = http_engine.build(seed, &context).await.unwrap(); - let expected_url: Url = "http://localhost/first/first".parse().unwrap(); - let expected_body = "first"; + // Make sure the URL rendered correctly before sending + let expected_url: Url = format!("{host}/first/first").parse().unwrap(); + let exchange = ticket.send().await.unwrap(); - let request = &ticket.request; - assert_eq!(request.url(), &expected_url); + assert_eq!(exchange.response.status, StatusCode::OK); + assert_eq!(exchange.request.url, expected_url); assert_eq!( - request - .body() - .and_then(|body| std::str::from_utf8(body.as_bytes()?).ok()), + // The response body is the same as the request body + std::str::from_utf8(exchange.response.body.bytes()).ok(), Some(expected_body) ); } @@ -669,7 +753,7 @@ async fn test_profile_duplicate_error(http_engine: HttpEngine) { ..Recipe::factory(()) }; let recipe_id = recipe.id.clone(); - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = RequestSeed::new(recipe_id, BuildOptions::default()); assert_err!( @@ -697,10 +781,10 @@ async fn test_send_request(http_engine: HttpEngine) { .await; let recipe = Recipe { - url: format!("{host}/get", host = server.uri()).as_str().into(), + url: "{{ host }}/get".into(), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, Some(&server.uri())); let seed = seed(&context, BuildOptions::default()); // Build+send the request @@ -745,7 +829,7 @@ async fn test_render_headers_strip() { }, ..Recipe::factory(()) }; - let context = template_context(Recipe::factory(())); + let context = template_context(Recipe::factory(()), None); let rendered = recipe .render_headers(&BuildOptions::default(), &context) .await @@ -788,7 +872,7 @@ async fn test_build_curl(http_engine: HttpEngine) { }, ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed(&context, BuildOptions::default()); let command = http_engine.build_curl(seed, &context).await.unwrap(); @@ -829,7 +913,7 @@ async fn test_build_curl_authentication( authentication: Some(authentication), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed(&context, BuildOptions::default()); let command = http_engine.build_curl(seed, &context).await.unwrap(); let expected_command = format!( @@ -859,6 +943,12 @@ async fn test_build_curl_authentication( }), "-F 'user_id=1' -F 'token=tokenzzz'" )] +#[case::form_multipart_file( + RecipeBody::FormMultipart(indexmap! { + "file".into() => "{{ file('data.json') }}".into(), + }), + "-F 'file=@{ROOT}{SEP}data.json'" +)] #[tokio::test] async fn test_build_curl_body( http_engine: HttpEngine, @@ -869,10 +959,14 @@ async fn test_build_curl_body( body: Some(body), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, None); let seed = seed(&context, BuildOptions::default()); let command = http_engine.build_curl(seed, &context).await.unwrap(); + let expected_arguments = expected_arguments + // Dynamic replacements for system-specific contents + .replace("{ROOT}", &context.root_dir.to_string_lossy()) + .replace("{SEP}", path::MAIN_SEPARATOR_STR); let expected_command = format!("curl -XGET --url 'http://localhost/url' {expected_arguments}"); assert_eq!(command, expected_command); @@ -909,10 +1003,10 @@ async fn test_follow_redirects( ..Default::default() }); let recipe = Recipe { - url: format!("{host}/redirect").as_str().into(), + url: "{{ host }}/redirect".into(), ..Recipe::factory(()) }; - let context = template_context(recipe); + let context = template_context(recipe, Some(&host)); let seed = seed(&context, BuildOptions::default()); // Build+send the request diff --git a/crates/core/src/render.rs b/crates/core/src/render.rs index d8db0be5..ce2111a0 100644 --- a/crates/core/src/render.rs +++ b/crates/core/src/render.rs @@ -12,7 +12,6 @@ use crate::{ collection::{Collection, Profile, ProfileId, RecipeId}, http::{Exchange, RequestSeed, ResponseRecord, TriggeredRequestError}, render::functions::RequestTrigger, - util::{FutureCache, FutureCacheOutcome}, }; use anyhow::anyhow; use async_trait::async_trait; @@ -21,7 +20,9 @@ use derive_more::From; use indexmap::IndexMap; use itertools::Itertools; use serde_json_path::JsonPath; -use slumber_template::{Arguments, Identifier, RenderError}; +use slumber_template::{ + Arguments, FieldCache, Identifier, RenderError, Stream, +}; use slumber_util::ResultTraced; use std::{ fmt::Debug, io, iter, path::PathBuf, process::ExitStatus, sync::Arc, @@ -82,102 +83,91 @@ impl TemplateContext { }); } - // Defer loading the most recent exchange until we know we'll need it - let get_latest = || async { - self.http_provider - .get_latest_request(self.selected_profile.as_ref(), recipe_id) - .await - .map_err(FunctionError::Database) - }; - - // Helper to execute the request, if triggered - let send_request = || async { - // There are 3 different ways we can generate the build optoins: - // 1. Default (enable all query params/headers) - // 2. Load from UI state for both TUI and CLI - // 3. Load from UI state for TUI, enable all for CLI - // These all have their own issues: - // 1. Triggered request doesn't necessarily match behavior if user - // were to execute the request themself - // 2. CLI behavior is silently controlled by UI state - // 3. TUI and CLI behavior may not match - // All 3 options are unintuitive in some way, but 1 is the easiest - // to implement so I'm going with that for now. - let build_options = Default::default(); - - self.http_provider - .send_request( - RequestSeed::new(recipe_id.clone(), build_options), - self, - ) - .await - .map_err(|error| FunctionError::Trigger { - recipe_id: recipe_id.clone(), - error, - }) - }; - let exchange = match trigger { - RequestTrigger::Never => { - get_latest().await?.ok_or(FunctionError::ResponseMissing)? - } + RequestTrigger::Never => self + .get_latest_cached(recipe_id) + .await? + .ok_or(FunctionError::ResponseMissing)?, RequestTrigger::NoHistory => { // If a exchange is present in history, use that. If not, fetch - if let Some(exchange) = get_latest().await? { + if let Some(exchange) = + self.get_latest_cached(recipe_id).await? + { exchange } else { - send_request().await? + self.send_request(recipe_id).await? } } - RequestTrigger::Expire { duration } => match get_latest().await? { - Some(exchange) - if exchange.end_time + duration.inner() >= Utc::now() => - { - exchange + RequestTrigger::Expire { duration } => { + match self.get_latest_cached(recipe_id).await? { + Some(exchange) + if exchange.end_time + duration.inner() + >= Utc::now() => + { + exchange + } + _ => self.send_request(recipe_id).await?, } - _ => send_request().await?, - }, - RequestTrigger::Always => send_request().await?, + } + RequestTrigger::Always => self.send_request(recipe_id).await?, }; Ok(exchange.response) } + + /// Get the most recent cached exchange for the given recipe + async fn get_latest_cached( + &self, + recipe_id: &RecipeId, + ) -> Result, FunctionError> { + self.http_provider + .get_latest_request(self.selected_profile.as_ref(), recipe_id) + .await + .map_err(FunctionError::Database) + } + + /// Send a request for the recipe and return the exchange + async fn send_request( + &self, + recipe_id: &RecipeId, + ) -> Result { + // There are 3 different ways we can generate the build optoins: + // 1. Default (enable all query params/headers) + // 2. Load from UI state for both TUI and CLI + // 3. Load from UI state for TUI, enable all for CLI + // These all have their own issues: + // 1. Triggered request doesn't necessarily match behavior if user + // were to execute the request themself + // 2. CLI behavior is silently controlled by UI state + // 3. TUI and CLI behavior may not match + // All 3 options are unintuitive in some way, but 1 is the easiest + // to implement so I'm going with that for now. + let build_options = Default::default(); + + self.http_provider + .send_request( + RequestSeed::new(recipe_id.clone(), build_options), + self, + ) + .await + .map_err(|error| FunctionError::Trigger { + recipe_id: recipe_id.clone(), + error, + }) + } } impl slumber_template::Context for TemplateContext { - async fn get( + async fn get_field( &self, - field: &slumber_template::Identifier, - ) -> Result { + field: &Identifier, + ) -> Result { // Check overrides first. The override value is NOT treated as a // template if let Some(value) = self.overrides.get(field.as_str()) { return Ok(value.clone().into()); } - // Check the field cache to see if this value is already being computed - // somewhere else. If it is, we'll block on that and re-use the result. - // If not, we get a guard back, meaning we're responsible for the - // computation. At the end, we'll write back to the guard so everyone - // else can copy our homework. - let cache = &self.state.profile_results; - let guard = match cache.get_or_init(field.clone()).await { - FutureCacheOutcome::Hit(value) => return Ok(value), - FutureCacheOutcome::Miss(guard) => guard, - // The future responsible for writing to the guard failed. All the - // places that render multiple templates use try_join so one failure - // should cause all other futures to stop being polled. So - // theoretically this error will never be seen, but we return an - // error instead of panicking just to be safe. We could clone the - // error and share it between calls, but implementing Clone for an - // error type can be pretty annoying. - FutureCacheOutcome::NoResponse => { - return Err(RenderError::CacheFailed { - field: field.clone(), - }); - } - }; - // Then check the current profile let template = self .current_profile() @@ -187,24 +177,26 @@ impl slumber_template::Context for TemplateContext { })?; // Render the nested template - let bytes = template.render_bytes(self).await.map_err(|error| { + template.render_stream(self).await.map_err(|error| { + // We *could* just return the error, but wrap it to give additional + // context FunctionError::ProfileNested { field: field.clone(), error, } - })?; - let value: slumber_template::Value = bytes.into(); + .into() + }) + } - // Store value in the cache so other instances of this chain can use it - guard.set(value.clone()); - Ok(value) + fn field_cache(&self) -> &FieldCache { + &self.state.field_cache } async fn call( &self, function_name: &Identifier, arguments: Arguments<'_, Self>, - ) -> Result { + ) -> Result { match function_name.as_str() { "base64" => functions::base64(arguments), "boolean" => functions::boolean(arguments), @@ -212,7 +204,7 @@ impl slumber_template::Context for TemplateContext { "concat" => functions::concat(arguments), "debug" => functions::debug(arguments), "env" => functions::env(arguments), - "file" => functions::file(arguments).await, + "file" => functions::file(arguments), "float" => functions::float(arguments), "integer" => functions::integer(arguments), "json_parse" => functions::json_parse(arguments), @@ -290,7 +282,7 @@ pub struct RenderGroupState { /// times. If a field fails to render, the guard holder should drop the /// guard without entering a result. This will kill the entire render so /// other renderers of that field will be cancelled. - profile_results: FutureCache, + field_cache: FieldCache, } /// An abstraction that provides behavior for chained HTTP requests. This diff --git a/crates/core/src/render/functions.rs b/crates/core/src/render/functions.rs index f7fe57cb..52fdb7b5 100644 --- a/crates/core/src/render/functions.rs +++ b/crates/core/src/render/functions.rs @@ -7,13 +7,14 @@ use crate::{ use base64::{Engine, prelude::BASE64_STANDARD}; use bytes::Bytes; use derive_more::FromStr; +use futures::FutureExt; use itertools::Itertools; use serde::{Deserialize, de::value::SeqDeserializer}; use serde_json_path::NodeList; use slumber_macros::template; use slumber_template::{ - Expected, TryFromValue, Value, ValueError, WithValue, - impl_try_from_value_str, + Expected, Stream, StreamMetadata, TryFromValue, Value, ValueError, + WithValue, impl_try_from_value_str, }; use slumber_util::{TimeSpan, paths::expand_home}; use std::{env, fmt::Debug, path::PathBuf, process::Stdio, sync::Arc}; @@ -253,15 +254,25 @@ pub fn env(variable: String) -> String { /// {{ file("config.json") }} => Contents of config.json file /// ``` #[template] -pub async fn file( - #[context] context: &TemplateContext, - path: String, -) -> Result { +pub fn file(#[context] context: &TemplateContext, path: String) -> Stream { let path = context.root_dir.join(expand_home(PathBuf::from(path))); - let bytes = fs::read(&path) - .await - .map_err(|error| FunctionError::File { path, error })?; - Ok(bytes.into()) + // Return the file as a stream. If streaming isn't available here, it will + // be resolved immediately instead + Stream::Stream { + metadata: StreamMetadata::File { path: path.clone() }, + f: Arc::new(move || { + // This possible this function gets called multiple times, and each + // future has to be 'static + let path = path.clone(); + async move { + fs::read(&path) + .await + .map(Bytes::from) + .map_err(|error| FunctionError::File { path, error }.into()) + } + .boxed() + }), + } } /// Convert a value to a float diff --git a/crates/core/src/render/tests.rs b/crates/core/src/render/tests.rs index 9f4a8257..1dfdf50a 100644 --- a/crates/core/src/render/tests.rs +++ b/crates/core/src/render/tests.rs @@ -15,9 +15,10 @@ use chrono::{DateTime, Utc}; use indexmap::{IndexMap, indexmap}; use rstest::rstest; use serde_json::json; -use slumber_template::{Expression, Literal, Template}; +use slumber_template::{Expression, Literal, Stream, Template, Value}; use slumber_util::{ - Factory, TempDir, assert_result, paths::get_repo_root, temp_dir, + Factory, TempDir, assert_matches, assert_result, paths::get_repo_root, + temp_dir, test_data_dir, }; use std::time::Duration; use tokio::fs; @@ -114,7 +115,7 @@ async fn test_boolean(#[case] input: Expression, #[case] expected: bool) { let template = Template::function_call("boolean", [input], []); assert_result( template.render_value(&TemplateContext::factory(())).await, - Ok(slumber_template::Value::Boolean(expected)), + Ok(Value::Boolean(expected)), ); } @@ -313,7 +314,7 @@ async fn test_float( let template = Template::function_call("float", [input], []); assert_result( template.render_value(&TemplateContext::factory(())).await, - expected.map(slumber_template::Value::from), + expected.map(Value::from), ); } @@ -341,7 +342,7 @@ async fn test_integer( let template = Template::function_call("integer", [input], []); assert_result( template.render_value(&TemplateContext::factory(())).await, - expected.map(slumber_template::Value::from), + expected.map(Value::from), ); } @@ -355,7 +356,7 @@ async fn test_integer( #[tokio::test] async fn test_json_parse( #[case] json: &'static [u8], - #[case] expected: Result, + #[case] expected: Result, ) { let template = Template::function_call("json_parse", [json.into()], []); assert_result( @@ -769,6 +770,47 @@ async fn test_trim( ); } +/// Test different conditions where streaming is/isn't allowed +#[rstest] +#[case::stream_direct("{{ file('data.json') }}", true, true)] +#[case::stream_piped("{{ 'data.json' | file() }}", true, true)] +#[case::stream_via_profile("{{ file_field }}", true, true)] +#[case::no_stream_direct("{{ file('data.json') }}", false, false)] +#[case::no_stream_via_profile("{{ file_field }}", false, false)] +#[case::no_stream_not_root("data: {{ file('data.json') }}", true, false)] +#[case::no_stream_not_root_via_profile("data: {{ file_field }}", true, false)] +#[tokio::test] +async fn test_stream( + #[case] template: Template, + #[case] can_stream: bool, + #[case] expect_stream: bool, +) { + // Put some profile data in the context + let profile_data = indexmap! { + "file_field".into() => "{{ file('data.json') }}".into(), + }; + let profile = Profile { + data: profile_data, + ..Profile::factory(()) + }; + let context = TemplateContext { + root_dir: test_data_dir(), + ..TemplateContext::factory((by_id([profile]), IndexMap::new())) + }; + + let stream = if can_stream { + template.render_stream(&context).await + } else { + template.render_value(&context).await.map(Stream::Value) + } + .unwrap(); + if expect_stream { + assert_matches!(stream, Stream::Stream { .. }); + } else { + assert_matches!(stream, Stream::Value(_)); + } +} + /// Bytes that can't be converted to a string fn invalid_utf8() -> &'static [u8] { b"\xc3\x28" diff --git a/crates/core/src/util.rs b/crates/core/src/util.rs index b6948621..51e08ccf 100644 --- a/crates/core/src/util.rs +++ b/crates/core/src/util.rs @@ -2,15 +2,7 @@ use derive_more::Display; use dialoguer::Confirm; -use std::{ - collections::{HashMap, hash_map::Entry}, - fmt, - hash::Hash, - ops::DerefMut, - sync::Arc, -}; -use tokio::sync::{Mutex, OwnedRwLockWriteGuard, RwLock}; -use tracing::error; +use std::fmt; /// Show the user a confirmation prompt pub fn confirm(prompt: impl Into) -> bool { @@ -50,94 +42,3 @@ impl Display for MaybeStr<'_> { } } } - -/// A cache of values that either have been computed, or are asynchronously -/// being computed. This allows multiple computers of the same async values to -/// deduplicate their work. -#[derive(Debug)] -pub(crate) struct FutureCache { - /// Cache each value by key. The outer mutex will only be held open for as - /// long as it takes to check if the value is in the cache or not. The - /// inner lock will be blocked on until the value is available. - cache: Mutex>>>>, -} - -impl FutureCache { - /// Get a value from the cache, or if not present, insert a placeholder - /// value and return a guard that can be used to insert the completed value - /// later. The placeholder will tell subsequent accessors of this key that - /// the value is being computed, and will be present later. If the - /// placeholder is present and the final value being computed, **this block - /// will not return until the value is available**. - pub async fn get_or_init(&self, key: K) -> FutureCacheOutcome { - let mut cache = self.cache.lock().await; - match cache.entry(key) { - Entry::Occupied(entry) => { - let lock = Arc::clone(entry.get()); - drop(cache); // Drop the outer lock as quickly as possible - - match &*lock.read_owned().await { - Some(value) => FutureCacheOutcome::Hit(value.clone()), - None => FutureCacheOutcome::NoResponse, - } - } - Entry::Vacant(entry) => { - let lock = Arc::new(RwLock::new(None)); - entry.insert(Arc::clone(&lock)); - // Grab the write lock and hold it as long as the parent is - // working to compute the value - let guard = lock - .try_write_owned() - .expect("Lock was just created, who the hell grabbed it??"); - // Drop the root cache lock *after* we acquire the lock for our - // own future, to prevent other tasks grabbing it first - drop(cache); - - FutureCacheOutcome::Miss(FutureCacheGuard(guard)) - } - } - } -} - -impl Default for FutureCache { - fn default() -> Self { - Self { - cache: Default::default(), - } - } -} - -/// Outcome of check a future cache for a particular key -pub(crate) enum FutureCacheOutcome { - /// The value is already in the cache - Hit(V), - /// The value is not in the cache. Caller is responsible for inserting it - /// by calling [FutureCacheGuard::set] once computed. - Miss(FutureCacheGuard), - /// The first entrant dropped their write guard without writing to it, so - /// there's no response to return - NoResponse, -} - -/// A handle for writing a computed future value back into the cache. This is -/// returned once per key, to the first caller of that key. The caller is then -/// responsible for calling [FutureCacheGuard::set] to insert the value for -/// everyone else. Subsequent callers to the cache will block until `set` is -/// called. -pub(crate) struct FutureCacheGuard(OwnedRwLockWriteGuard>); - -impl FutureCacheGuard { - pub fn set(mut self, value: V) { - *self.0.deref_mut() = Some(value); - } -} - -impl Drop for FutureCacheGuard { - fn drop(&mut self) { - if self.0.is_none() { - // Friendly little error logging. We don't have a good way of - // identifying *which* lock this happened to :( - error!("Future cache guard dropped without setting a value"); - } - } -} diff --git a/crates/doc_utils/src/template_functions.rs b/crates/doc_utils/src/template_functions.rs index f168b419..97705f18 100644 --- a/crates/doc_utils/src/template_functions.rs +++ b/crates/doc_utils/src/template_functions.rs @@ -320,6 +320,9 @@ fn type_map() -> HashMap { ]), ), (parse_quote!(String), TypeDef::String), + // We're hiding streams from the type system, since they will + // transparently convert to bytes + (parse_quote!(Stream), TypeDef::Bytes), (parse_quote!(TrimMode), union!("start" | "end" | "both")), (parse_quote!(slumber_template::Value), TypeDef::Value), (parse_quote!(Value), TypeDef::Value), diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index ac20c2c5..0587a155 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -96,7 +96,7 @@ pub fn template(_attr: TokenStream, item: TokenStream) -> TokenStream { #[allow(unused_mut)] mut arguments: ::slumber_template::Arguments<'_, #context_type_param> ) -> ::core::result::Result< - ::slumber_template::Value, + ::slumber_template::Stream, ::slumber_template::RenderError > { #inner_fn diff --git a/crates/template/Cargo.toml b/crates/template/Cargo.toml index ad499944..ef6e591c 100644 --- a/crates/template/Cargo.toml +++ b/crates/template/Cargo.toml @@ -12,17 +12,18 @@ version = {workspace = true} [dependencies] bytes = {workspace = true, features = ["serde"]} -derive_more = {workspace = true, features = ["deref", "display", "from"]} +derive_more = {workspace = true, features = ["debug", "deref", "display", "from"]} futures = {workspace = true} indexmap = {workspace = true, features = ["serde"]} itertools = {workspace = true} -regex = {version = "1.10.5", default-features = false} +regex = {workspace = true} saphyr = {workspace = true} schemars = {workspace = true, optional = true} serde = {workspace = true, features = ["derive"]} serde_json = {workspace = true} slumber_util = {workspace = true} thiserror = {workspace = true} +tokio = {workspace = true, features = ["sync"]} tracing = {workspace = true} winnow = {workspace = true} @@ -35,7 +36,7 @@ rstest = {workspace = true} serde_json = {workspace = true} serde_yaml = {workspace = true} slumber_util = {workspace = true, features = ["test"]} -tokio = {workspace = true, features = ["macros", "rt"]} +tokio = {workspace = true, features = ["fs", "macros", "rt"]} [features] schema = ["dep:schemars"] diff --git a/crates/template/src/error.rs b/crates/template/src/error.rs index 770eddd7..97061156 100644 --- a/crates/template/src/error.rs +++ b/crates/template/src/error.rs @@ -35,11 +35,11 @@ impl From> for TemplateParseError { /// they should be pretty brief. #[derive(Debug, Error)] pub enum RenderError { - /// 2+ futures were rendering the same profile field. One future was doing - /// the actual rendering and the rest were waiting on the first. If the - /// first one fails, the rest will return this error. Theoretically this - /// will never actually be emitted because `try_join` should return after - /// the initial error, so this is a placeholder. + /// 2+ futures were rendering the same field. One future was doing the + /// actual rendering and the rest were waiting on the first. If the first + /// one fails, the rest will return this error. Theoretically this will + /// never actually be emitted because `try_join` should return after the + /// initial error, so this is a placeholder. #[error("Error rendering cached profile field `{field}`")] CacheFailed { field: Identifier }, diff --git a/crates/template/src/expression.rs b/crates/template/src/expression.rs index ab8c2601..6a3e1507 100644 --- a/crates/template/src/expression.rs +++ b/crates/template/src/expression.rs @@ -3,7 +3,8 @@ #[cfg(test)] use crate::test_util; use crate::{ - Arguments, Context, RenderError, Value, error::RenderErrorContext, + Arguments, Context, RenderError, Stream, Value, error::RenderErrorContext, + util::FieldCacheOutcome, }; use bytes::Bytes; use derive_more::{Deref, Display, From}; @@ -13,7 +14,7 @@ use futures::{ }; use indexmap::IndexMap; -type RenderResult = Result; +type RenderResult = Result; /// A dynamic segment of a template that will be computed at render time. /// Expressions are derived from the template context and may include external @@ -52,44 +53,92 @@ impl Expression { Self::Literal(literal) => Ok(literal.into()), Self::Array(expressions) => { // Render each inner expression - let values = future::try_join_all( - expressions - .iter() - .map(|expression| expression.render(context)), - ) - .boxed() // Box for recursion - .await?; - Ok(Value::Array(values)) + let values = + future::try_join_all(expressions.iter().map( + |expression| expression.render_value(context), + )) + // Box for recursion + .boxed() + .await?; + Ok(Value::Array(values).into()) } Self::Object(entries) => { let pairs: Vec<(String, Value)> = future::try_join_all( entries.iter().map(|(key, value)| { let key_future = async move { - let key = key.render(context).await?; + let key = key.render_value(context).await?; // Keys must be strings, so convert here key.try_into_string().map_err(|error| { RenderError::Value(error.error) }) }; - try_join(key_future, value.render(context)) + try_join(key_future, value.render_value(context)) }), ) .boxed() // Box for recursion .await?; // Keys will be deduped here, with the last taking priority - Ok(Value::Object(IndexMap::from_iter(pairs))) + Ok(Value::Object(IndexMap::from_iter(pairs)).into()) } - Self::Field(identifier) => context.get(identifier).await, + Self::Field(field) => Self::render_field(field, context).await, Self::Call(call) => call.call(context, None).await, Self::Pipe { expression, call } => { // Compute the left hand side first. Box for recursion - let value = expression.render(context).boxed().await?; + let value = + expression.render_value(context).boxed().await?; call.call(context, Some(value)).await } } } } + /// Render the value of a field. This will apply caching, so that a field + /// never has to be rendered more than once for a given context. + async fn render_field( + field: &Identifier, + context: &Ctx, + ) -> RenderResult { + // Check the field cache to see if this value is already being computed + // somewhere else. If it is, we'll block on that and re-use the result. + // If not, we get a guard back, meaning we're responsible for the + // computation. At the end, we'll write back to the guard so everyone + // else can copy our homework. + let cache = context.field_cache(); + let guard = match cache.get_or_init(field.clone()).await { + FieldCacheOutcome::Hit(stream) => return Ok(stream), + FieldCacheOutcome::Miss(guard) => guard, + // The future responsible for writing to the guard failed. Cloning + // errors is annoying so we return an empty response here. The + // initial error should've been returned elsewhere so that can be + // used instead. + FieldCacheOutcome::NoResponse => { + return Err(RenderError::CacheFailed { + field: field.clone(), + }); + } + }; + + // This value hasn't been rendered yet - ask the context to evaluate it + let mut stream = context.get_field(field).await?; + // If streaming isn't supported here, convert to a value before caching, + // so that the stream isn't evaluated multiple times unless necessary + if !context.can_stream() { + stream = stream.resolve().await?.into(); + } + + // Store value in the cache so other references to this field can use it + guard.set(stream.clone()); + Ok(stream) + } + + /// Render this expression, resolving any stream to a concrete value. + async fn render_value( + &self, + context: &Ctx, + ) -> Result { + self.render(context).await?.resolve().await + } + /// Build a function call expression. Any keyword arguments with `None` /// values will be omitted pub fn call( @@ -248,7 +297,7 @@ impl FunctionCall { &self, context: &Ctx, piped_argument: Option, - ) -> Result { + ) -> RenderResult { // Provide context to the error let map_error = |error: RenderError| { error.context(RenderErrorContext::Function(self.function.clone())) @@ -277,7 +326,7 @@ impl FunctionCall { let position_future = future::try_join_all(self.position.iter().enumerate().map( |(index, expression)| async move { - expression.render(context).await.map_err(|error| { + expression.render_value(context).await.map_err(|error| { error.context(RenderErrorContext::ArgumentRender { argument: index.to_string(), expression: expression.clone(), @@ -287,13 +336,14 @@ impl FunctionCall { )); let keyword_future = future::try_join_all(self.keyword.iter().map( |(name, expression)| async { - let value = - expression.render(context).await.map_err(|error| { + let value = expression.render_value(context).await.map_err( + |error| { error.context(RenderErrorContext::ArgumentRender { argument: name.to_string(), expression: expression.clone(), }) - })?; + }, + )?; Ok((name.to_string(), value)) }, )); diff --git a/crates/template/src/lib.rs b/crates/template/src/lib.rs index 5629d29c..2804779a 100644 --- a/crates/template/src/lib.rs +++ b/crates/template/src/lib.rs @@ -12,16 +12,19 @@ mod parse; mod test_util; #[cfg(test)] mod tests; +mod util; mod value; pub use error::{ Expected, RenderError, TemplateParseError, ValueError, WithValue, }; pub use expression::{Expression, FunctionCall, Identifier, Literal}; -pub use value::{Arguments, FunctionOutput, TryFromValue, Value}; +pub use util::FieldCache; +pub use value::{ + Arguments, FunctionOutput, Stream, StreamMetadata, TryFromValue, Value, +}; use bytes::{Bytes, BytesMut}; -use derive_more::From; use futures::future; use itertools::Itertools; #[cfg(test)] @@ -32,22 +35,76 @@ use std::{fmt::Debug, sync::Arc}; /// `Context` defines how template fields and functions are resolved. Both /// field resolution and function calls can be asynchronous. pub trait Context: Sized + Send + Sync { + /// Does the render target support streaming? Typically this should return + /// `false`. To enable streaming, just call [Template::render_stream] and + /// the context will be wrapped to enable streaming. + /// + /// This is a method on the context to avoid plumbing around a second object + /// to all render locations. + fn can_stream(&self) -> bool { + false + } + /// Get the value of a field from the context. The implementor can decide /// where fields are derived from. Fields can also be computed dynamically /// and be `async`. For example, fields can be loaded from a map of nested /// templates, in which case the nested template would need to be rendered - /// before this can be returned. - fn get( + /// before this can be returned. Rendered fields will be cached via the + /// cache returned by [Self::field_cache], so the same field will never be + /// requested twice for this context object. + fn get_field( &self, identifier: &Identifier, - ) -> impl Future> + Send; + ) -> impl Future> + Send; + + /// A cache to store the outcome of rendered fields. + fn field_cache(&self) -> &FieldCache; /// Call a function by name fn call( &self, function_name: &Identifier, arguments: Arguments<'_, Self>, - ) -> impl Future> + Send; + ) -> impl Future> + Send; +} + +/// A wrapper for a [Context] implementation that enables streaming all other +/// behavior is forwarded to the inner context. This is automatically applied by +/// [Template::render_stream], but can also be used manually to control the +/// output of [Template::render_chunks]. +pub struct StreamContext<'a, T>(&'a T); + +impl<'a, T> StreamContext<'a, T> { + pub fn new(context: &'a T) -> Self { + Self(context) + } +} + +impl Context for StreamContext<'_, T> { + fn can_stream(&self) -> bool { + true + } + + async fn get_field( + &self, + identifier: &Identifier, + ) -> Result { + self.0.get_field(identifier).await + } + + fn field_cache(&self) -> &FieldCache { + self.0.field_cache() + } + + async fn call( + &self, + function_name: &Identifier, + arguments: Arguments<'_, Self>, + ) -> Result { + self.0 + .call(function_name, arguments.map_context(|ctx| ctx.0)) + .await + } } /// A parsed template, which can contain raw and/or templated content. The @@ -165,59 +222,81 @@ impl Template { self.chunks.is_empty() } - /// Render the template using values from the given context. If any chunk - /// failed to render, return an error. The render output is converted to a - /// [Value] by these rules: - /// - If the template is a single dynamic chunk, the output value will be - /// directly converted to JSON, allowing non-string JSON values + /// Render the template. If any chunk fails to render, return an error. The + /// render output is converted to a [Value] by these rules: + /// - If the template is a single dynamic chunk, return the output of that + /// chunk, which may be any type of [Value] /// - Any other template will be rendered to a string by stringifying each /// dynamic chunk and concatenating them all together /// - If rendering to a string fails because the bytes are not valid UTF-8, /// concatenate into a bytes object instead /// - /// Return an error iff any chunk failed to render. This will never fail on + /// Return an error iff any chunk fails to render. This will never fail on /// output conversion because it can always fall back to returning raw /// bytes. pub async fn render_value( &self, context: &Ctx, ) -> Result { - let mut chunks = self.render_chunks(context).await; + let stream = self.render_stream(context).await?; + let value = stream.resolve().await?; + // Try to convert bytes to string, because that's generally more useful + // to the consumer + match value { + Value::Bytes(bytes) => match String::from_utf8(bytes.into()) { + Ok(s) => Ok(Value::String(s)), + Err(error) => Ok(Value::Bytes(error.into_bytes().into())), + }, + _ => Ok(value), + } + } + + /// Render the template. The output may be a concrete value *or* a + /// streamable value, following these rules: + /// - If the template is a single dynamic chunk, return the output of that + /// chunk, which may be a stream or any type of [Value] + /// - Any other template will be rendered to bytes by concatenating all + /// chunks together + /// + /// Since streams return bytes, concrete values are also returned as bytes + /// because we know the consumer will accept bytes. + pub async fn render_stream( + &self, + context: &Ctx, + ) -> Result { + let context = StreamContext(context); + let mut chunks = self.render_chunks(&context).await; // If we have a single dynamic chunk, return its value directly instead // of stringifying if let &[RenderedChunk::Rendered(_)] = chunks.as_slice() { - let Some(RenderedChunk::Rendered(value)) = chunks.pop() else { + let Some(RenderedChunk::Rendered(stream)) = chunks.pop() else { // Checked pattern above unreachable!() }; - return Ok(value); + return Ok(stream); } - // Stitch together into bytes. Attempt to convert that UTF-8, but if + // Stitch together into bytes. Attempt to convert that to UTF-8, but if // that fails fall back to just returning the bytes - let bytes = chunks_to_bytes(chunks)?; - match String::from_utf8(bytes.into()) { - Ok(s) => Ok(Value::String(s)), - Err(error) => Ok(Value::Bytes(error.into_bytes().into())), - } + chunks_to_bytes(chunks).await.map(Stream::from) } - /// Render the template using values from the given context. If any chunk - /// failed to render, return an error. The output is returned as bytes, - /// meaning it can safely render to non-UTF-8 content. Use - /// [Self::render_string] if you want the bytes converted to a string. + /// Render the template. If any chunk fails to render, return an error. The + /// output is returned as bytes, meaning it can safely render to non-UTF-8 + /// content. Use [Self::render_string] if you want the bytes converted to a + /// string. pub async fn render_bytes( &self, context: &Ctx, ) -> Result { let chunks = self.render_chunks(context).await; - chunks_to_bytes(chunks) + chunks_to_bytes(chunks).await } - /// Render the template using values from the given context. If any chunk - /// failed to render, return an error. The output will be converted from raw - /// bytes to UTF-8. If it is not valid UTF-8, return an error. + /// Render the template. If any chunk fails to render, return an error. The + /// output will be converted from raw bytes to UTF-8. If it is not valid + /// UTF-8, return an error. pub async fn render_string( &self, context: &Ctx, @@ -226,11 +305,11 @@ impl Template { String::from_utf8(bytes.into()).map_err(RenderError::other) } - /// Render the template using values from the given context, returning the - /// individual rendered chunks rather than stitching them together into a - /// string. If any individual chunk fails to render, its error will be - /// returned inline as [RenderedChunk::Error] and the rest of the template - /// will still be rendered. + /// Render the template, returning the individual rendered chunks rather + /// than stitching them together into a string. If any individual chunk + /// fails to render, its error will be returned inline as + /// [RenderedChunk::Error] and the rest of the template will still be + /// rendered. pub async fn render_chunks( &self, context: &Ctx, @@ -243,10 +322,22 @@ impl Template { TemplateChunk::Raw(text) => { RenderedChunk::Raw(Arc::clone(text)) } - TemplateChunk::Expression(expression) => expression - .render(context) - .await - .map_or_else(RenderedChunk::Error, RenderedChunk::Rendered), + TemplateChunk::Expression(expression) => { + match expression.render(context).await { + Ok(stream) if context.can_stream() => { + RenderedChunk::Rendered(stream) + } + // If the context doesn't support streaming, resolve the + // stream now + Ok(stream) => stream + .resolve() + .await + .map_or_else(RenderedChunk::Error, |value| { + RenderedChunk::Rendered(value.into()) + }), + Err(error) => RenderedChunk::Error(error), + } + } } }); @@ -317,8 +408,8 @@ pub enum RenderedChunk { /// stored in an `Arc` so we can reference the text in the parsed input /// without having to clone it. Raw(Arc), - /// Outcome of rendering a template key - Rendered(Value), + /// A dynamic chunk of a template, rendered to a stream/value + Rendered(Stream), /// An error occurred while rendering a template key Error(RenderError), } @@ -328,9 +419,12 @@ impl PartialEq for RenderedChunk { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::Raw(raw1), Self::Raw(raw2)) => raw1 == raw2, - (Self::Rendered(value1), Self::Rendered(value2)) => { - value1 == value2 - } + ( + Self::Rendered(Stream::Value(value1)), + Self::Rendered(Stream::Value(value2)), + ) => value1 == value2, + // Streams are never equal + (Self::Rendered(_), Self::Rendered(_)) => false, (Self::Error(error1), Self::Error(error2)) => { // RenderError doesn't have a PartialEq impl, so we have to // do a string comparison. @@ -342,31 +436,37 @@ impl PartialEq for RenderedChunk { } /// Concatenate rendered chunks into bytes. If any chunk is an error, return an -/// error -fn chunks_to_bytes(chunks: Vec) -> Result { +/// error. This is async because the chunk may be a stream, in which case it +/// will be resolved. +async fn chunks_to_bytes( + chunks: Vec, +) -> Result { // Take an educated guess at the needed capacity to avoid reallocations let capacity = chunks .iter() .map(|chunk| match chunk { RenderedChunk::Raw(s) => s.len(), - RenderedChunk::Rendered(Value::Bytes(bytes)) => bytes.len(), - RenderedChunk::Rendered(Value::String(s)) => s.len(), + RenderedChunk::Rendered(Stream::Value(Value::Bytes(bytes))) => { + bytes.len() + } + RenderedChunk::Rendered(Stream::Value(Value::String(s))) => s.len(), // Take a rough guess for anything other than bytes/string RenderedChunk::Rendered(_) => 5, RenderedChunk::Error(_) => 0, }) .sum(); - chunks - .into_iter() - .try_fold(BytesMut::with_capacity(capacity), |mut acc, chunk| { - match chunk { - RenderedChunk::Raw(s) => acc.extend(s.as_bytes()), - RenderedChunk::Rendered(value) => { - acc.extend(value.into_bytes()); - } - RenderedChunk::Error(error) => return Err(error), + + let mut bytes = BytesMut::with_capacity(capacity); + for chunk in chunks { + match chunk { + RenderedChunk::Raw(s) => bytes.extend(s.as_bytes()), + RenderedChunk::Rendered(stream) => { + // If the chunk is still a stream, resolve to a value now + let value = stream.resolve().await?; + bytes.extend(value.into_bytes()); } - Ok(acc) - }) - .map(Bytes::from) + RenderedChunk::Error(error) => return Err(error), + } + } + Ok(bytes.into()) } diff --git a/crates/template/src/tests.rs b/crates/template/src/tests.rs index 5b67ecfb..103eb877 100644 --- a/crates/template/src/tests.rs +++ b/crates/template/src/tests.rs @@ -1,7 +1,17 @@ -use crate::{Arguments, Context, Identifier, RenderError, Template, Value}; +use crate::{ + Arguments, Context, FieldCache, Identifier, RenderError, Stream, Template, + Value, value::StreamMetadata, +}; +use bytes::Bytes; +use futures::FutureExt; use indexmap::indexmap; use rstest::rstest; -use slumber_util::assert_err; +use slumber_util::{assert_err, assert_matches, test_data_dir}; +use std::sync::{ + Arc, + atomic::{AtomicI64, Ordering}, +}; +use tokio::fs; /// Test simple expression rendering #[rstest] @@ -21,7 +31,13 @@ use slumber_util::assert_err; )] #[tokio::test] async fn test_expression(#[case] template: Template, #[case] expected: Value) { - assert_eq!(template.render_value(&TestContext).await.unwrap(), expected); + assert_eq!( + template + .render_value(&TestContext::default()) + .await + .unwrap(), + expected + ); } /// Render to a value. Templates with a single dynamic chunk are allowed to @@ -34,12 +50,52 @@ async fn test_expression(#[case] template: Template, #[case] expected: Value) { "my name is {{ invalid_utf8 }}", Value::Bytes(b"my name is \xc3\x28".as_slice().into(), ))] +// Stream gets resolved to bytes, then converted to a string +#[case::stream("{{ stream() }}", "{ \"a\": 1, \"b\": 2 }".into())] #[tokio::test] async fn test_render_value( #[case] template: Template, #[case] expected: Value, ) { - assert_eq!(template.render_value(&TestContext).await.unwrap(), expected); + assert_eq!( + template + .render_value(&TestContext::default()) + .await + .unwrap(), + expected + ); +} + +/// Render to a stream +#[rstest] +#[case::stream("{{ stream() }}", b"{ \"a\": 1, \"b\": 2 }", true)] +#[case::text("text: {{ stream() }}", b"text: { \"a\": 1, \"b\": 2 }", false)] +#[case::binary( + "{{ invalid_utf8 }} {{ stream() }}", + b"\xc3\x28 { \"a\": 1, \"b\": 2 }", + false +)] +#[tokio::test] +async fn test_render_stream( + #[case] template: Template, + #[case] expected_resolved: &[u8], + #[case] expected_is_stream: bool, +) { + let stream = template + .render_stream(&TestContext::default()) + .await + .unwrap(); + + if expected_is_stream { + assert_matches!(stream, Stream::Stream { .. }); + } else { + assert_matches!(stream, Stream::Value(Value::Bytes(_))); + } + + assert_eq!( + stream.resolve().await.unwrap().into_bytes(), + expected_resolved + ); } /// Convert JSON values to template values @@ -104,7 +160,10 @@ fn test_from_json(#[case] json: serde_json::Value, #[case] expected: Value) { #[tokio::test] async fn test_pipe(#[case] template: Template, #[case] expected: &str) { assert_eq!( - template.render_string(&TestContext).await.unwrap(), + template + .render_string(&TestContext::default()) + .await + .unwrap(), expected ); } @@ -135,37 +194,74 @@ async fn test_function_error( assert_err!( // Use anyhow to get the error message to include the whole chain template - .render_string(&TestContext) + .render_string(&TestContext::default()) .await .map_err(anyhow::Error::from), expected_error ); } -struct TestContext; +/// Using the same field multiple times should be deduplicated, so that the +/// expression is only evaluated once +#[tokio::test] +async fn test_field_duplicate() { + let context = TestContext::default(); + let template: Template = "{{ increment }} + {{ increment }}".into(); + + // Should deduplicate multiple uses in the same template + assert_eq!(template.render_string(&context).await.unwrap(), "1 + 1"); + // Rendering again with the same context should retain the caching + assert_eq!(template.render_string(&context).await.unwrap(), "1 + 1"); +} + +#[derive(Debug, Default)] +struct TestContext { + increment: AtomicI64, + field_cache: FieldCache, +} impl Context for TestContext { - async fn get(&self, identifier: &Identifier) -> Result { + fn can_stream(&self) -> bool { + true + } + + async fn get_field( + &self, + identifier: &Identifier, + ) -> Result { match identifier.as_str() { "name" => Ok("Mike".into()), "array" => Ok(vec!["a", "b", "c"].into()), - "invalid_utf8" => Ok(Value::Bytes(b"\xc3\x28".as_slice().into())), + // A field that increments each time it's evaluated, to test for + // deduplication + "increment" => { + let previous_incrs = + self.increment.fetch_add(1, Ordering::Relaxed); + // Return the number of times this has been evaluated, including + // this call + Ok((previous_incrs + 1).into()) + } + "invalid_utf8" => Ok(b"\xc3\x28".into()), _ => Err(RenderError::FieldUnknown { field: identifier.clone(), }), } } + fn field_cache(&self) -> &FieldCache { + &self.field_cache + } + async fn call( &self, function_name: &Identifier, mut arguments: Arguments<'_, Self>, - ) -> Result { + ) -> Result { match function_name.as_str() { "identity" => { let value: Value = arguments.pop_position()?; arguments.ensure_consumed()?; - Ok(value) + Ok(value.into()) } "add" => { let a: i64 = arguments.pop_position()?; @@ -185,6 +281,22 @@ impl Context for TestContext { Ok(a.into()) } } + "stream" => { + let path = test_data_dir().join("data.json"); + Ok(Stream::Stream { + metadata: StreamMetadata::File { path: path.clone() }, + f: Arc::new(move || { + let path = path.clone(); + async move { + fs::read(path) + .await + .map(Bytes::from) + .map_err(RenderError::other) + } + .boxed() + }), + }) + } _ => Err(RenderError::FunctionUnknown), } } diff --git a/crates/template/src/util.rs b/crates/template/src/util.rs new file mode 100644 index 00000000..8ef8d5e7 --- /dev/null +++ b/crates/template/src/util.rs @@ -0,0 +1,94 @@ +use crate::{Identifier, Stream}; +use std::{ + collections::{HashMap, hash_map::Entry}, + ops::DerefMut, + sync::Arc, +}; +use tokio::sync::{Mutex, OwnedRwLockWriteGuard, RwLock}; +use tracing::error; + +/// A cache of template values that either have been computed, or are +/// asynchronously being computed. This allows multiple references to the same +/// template field to share their work. +#[derive(Debug, Default)] +pub struct FieldCache { + /// Cache each value by key. The outer mutex will only be held open for as + /// long as it takes to check if the value is in the cache or not. The + /// inner lock will be blocked on until the value is available. + cache: Mutex>>>>, +} + +impl FieldCache { + /// Get a value from the cache, or if not present, insert a placeholder + /// value and return a guard that can be used to insert the completed value + /// later. The placeholder will tell subsequent accessors of this key that + /// the value is being computed, and will be present later. If the + /// placeholder is present and the final value being computed, **this block + /// will not return until the value is available**. + pub(crate) async fn get_or_init( + &self, + field: Identifier, + ) -> FieldCacheOutcome { + let mut cache = self.cache.lock().await; + match cache.entry(field) { + Entry::Occupied(entry) => { + let lock = Arc::clone(entry.get()); + drop(cache); // Drop the outer lock as quickly as possible + + match &*lock.read_owned().await { + Some(value) => FieldCacheOutcome::Hit(value.clone()), + None => FieldCacheOutcome::NoResponse, + } + } + Entry::Vacant(entry) => { + let lock = Arc::new(RwLock::new(None)); + entry.insert(Arc::clone(&lock)); + // Grab the write lock and hold it as long as the parent is + // working to compute the value + let guard = lock + .try_write_owned() + .expect("Lock was just created, who the hell grabbed it??"); + // Drop the root cache lock *after* we acquire the lock for our + // own future, to prevent other tasks grabbing it first + drop(cache); + + FieldCacheOutcome::Miss(FutureCacheGuard(guard)) + } + } + } +} + +/// Outcome of check a future cache for a particular key +pub(crate) enum FieldCacheOutcome { + /// The value is already in the cache + Hit(Stream), + /// The value is not in the cache. Caller is responsible for inserting it + /// by calling [FutureCacheGuard::set] once computed. + Miss(FutureCacheGuard), + /// The first entrant dropped their write guard without writing to it, so + /// there's no response to return + NoResponse, +} + +/// A handle for writing a computed future value back into the cache. This is +/// returned once per key, to the first caller of that key. The caller is then +/// responsible for calling [FutureCacheGuard::set] to insert the value for +/// everyone else. Subsequent callers to the cache will block until `set` is +/// called. +pub(crate) struct FutureCacheGuard(OwnedRwLockWriteGuard>); + +impl FutureCacheGuard { + pub fn set(mut self, stream: Stream) { + *self.0.deref_mut() = Some(stream); + } +} + +impl Drop for FutureCacheGuard { + fn drop(&mut self) { + if self.0.is_none() { + // Friendly little error logging. We don't have a good way of + // identifying *which* lock this happened to :( + error!("Future cache guard dropped without setting a value"); + } + } +} diff --git a/crates/template/src/value.rs b/crates/template/src/value.rs index 33db08b2..7658d0da 100644 --- a/crates/template/src/value.rs +++ b/crates/template/src/value.rs @@ -7,9 +7,10 @@ use crate::{ }; use bytes::Bytes; use derive_more::From; +use futures::future::BoxFuture; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use std::{collections::VecDeque, fmt::Debug}; +use std::{collections::VecDeque, fmt::Debug, path::PathBuf, sync::Arc}; /// A runtime template value. This very similar to a JSON value, except: /// - Numbers do not support arbitrary size @@ -122,6 +123,13 @@ impl From<&str> for Value { } } +// Convert from byte literals +impl From<&'static [u8; N]> for Value { + fn from(value: &'static [u8; N]) -> Self { + Self::Bytes(value.as_slice().into()) + } +} + impl From> for Value where Value: From, @@ -152,6 +160,60 @@ impl From for Value { } } +/// A source of a template value. This can be a concrete [Value] or a streamable +/// source such as a file. This is used widely within rendering because it's a +/// superset of all values. Not all renders accept streams as results though, +/// so it's a separate type rather than a variant on [Value]. To convert a +/// stream into a value, call [Self::resolve]. +#[derive(Clone, derive_more::Debug)] +pub enum Stream { + /// A pre-resolved value + Value(Value), + /// Stream data from a (potentially) large source such as a file + Stream { + /// Additional information about the source of the stream + metadata: StreamMetadata, + /// Function returning the stream future. This can be cloned so that it + /// can be called multiple times, as the stream may be cloned by the + /// field cache. + #[debug(skip)] + f: Arc< + dyn Fn() -> BoxFuture<'static, Result> + + Send + + Sync, + >, + }, +} + +impl Stream { + /// Resolve this stream to a concrete [Value]. If it's already a value, just + /// return it. Otherwise the stream will be awaited and collected into + /// bytes. + pub async fn resolve(self) -> Result { + match self { + Self::Value(value) => Ok(value), + Self::Stream { f, .. } => f().await.map(Value::Bytes), + } + } +} + +impl> From for Stream { + fn from(value: T) -> Self { + Self::Value(value.into()) + } +} + +/// Metadata about the source of a [Stream]. This helps consumers present the +/// stream to the user, e.g. in a template preview +#[derive(Clone, Debug)] +pub enum StreamMetadata { + /// Data is being streamed from a file + File { + /// **Absolute** path to the file + path: PathBuf, + }, +} + /// Convert [Value] to a type fallibly /// /// This is used for converting function arguments to the static types expected @@ -395,6 +457,19 @@ impl<'ctx, Ctx> Arguments<'ctx, Ctx> { } } + /// Replace the context by mapping it through a function + pub fn map_context( + self, + f: impl FnOnce(&'ctx Ctx) -> &'ctx Ctx2, + ) -> Arguments<'ctx, Ctx2> { + Arguments { + context: f(self.context), + position: self.position, + num_popped: self.num_popped, + keyword: self.keyword, + } + } + /// Push a piped argument onto the back of the positional argument list pub(crate) fn push_piped(&mut self, argument: Value) { self.position.push_back(argument); @@ -405,30 +480,27 @@ impl<'ctx, Ctx> Arguments<'ctx, Ctx> { /// /// This is used for converting function outputs back to template values. pub trait FunctionOutput { - fn into_result(self) -> Result; + fn into_result(self) -> Result; } -impl FunctionOutput for T -where - Value: From, -{ - fn into_result(self) -> Result { +impl> FunctionOutput for T { + fn into_result(self) -> Result { Ok(self.into()) } } impl FunctionOutput for Result where - T: Into + Send + Sync, + T: Into + Send + Sync, E: Into + Send + Sync, { - fn into_result(self) -> Result { + fn into_result(self) -> Result { self.map(T::into).map_err(E::into) } } impl FunctionOutput for Option { - fn into_result(self) -> Result { - self.map(T::into_result).unwrap_or(Ok(Value::Null)) + fn into_result(self) -> Result { + self.map(T::into_result).unwrap_or(Ok(Value::Null.into())) } } diff --git a/crates/tui/src/message.rs b/crates/tui/src/message.rs index d51960b5..cc47668a 100644 --- a/crates/tui/src/message.rs +++ b/crates/tui/src/message.rs @@ -166,6 +166,9 @@ pub enum Message { /// way back down the component tree. TemplatePreview { template: Template, + /// Does the consumer support streaming? If so, the output chunks may + /// contain streams + can_stream: bool, #[debug(skip)] on_complete: Callback>, }, diff --git a/crates/tui/src/state.rs b/crates/tui/src/state.rs index 7f02317c..e53d31c1 100644 --- a/crates/tui/src/state.rs +++ b/crates/tui/src/state.rs @@ -18,7 +18,7 @@ use slumber_core::{ http::{Exchange, RequestError, RequestId, RequestSeed}, render::{Prompter, TemplateContext}, }; -use slumber_template::{RenderedChunk, Template}; +use slumber_template::{RenderedChunk, StreamContext, Template}; use slumber_util::ResultTraced; use std::{ path::{Path, PathBuf}, @@ -384,6 +384,7 @@ impl LoadedState { } Message::TemplatePreview { template, + can_stream, on_complete, } => { self.render_template_preview( @@ -394,6 +395,7 @@ impl LoadedState { // and this shortcut saves us a lot of plumbing so it's // worth it self.view.selected_profile_id().cloned(), + can_stream, on_complete, ); } @@ -646,12 +648,17 @@ impl LoadedState { &self, template: Template, profile_id: Option, + can_stream: bool, on_complete: Callback>, ) { let context = self.template_context(profile_id, true); util::spawn(async move { // Render chunks, then write them to the output destination - let chunks = template.render_chunks(&context).await; + let chunks = if can_stream { + template.render_chunks(&StreamContext::new(&context)).await + } else { + template.render_chunks(&context).await + }; on_complete(chunks); }); } diff --git a/crates/tui/src/view/common/template_preview.rs b/crates/tui/src/view/common/template_preview.rs index 3cad06ee..4f5682a7 100644 --- a/crates/tui/src/view/common/template_preview.rs +++ b/crates/tui/src/view/common/template_preview.rs @@ -11,7 +11,7 @@ use ratatui::{ widgets::Widget, }; use slumber_core::http::content_type::ContentType; -use slumber_template::{RenderedChunk, Template}; +use slumber_template::{RenderedChunk, Stream, StreamMetadata, Template}; use std::{ ops::Deref, sync::{Arc, Mutex}, @@ -42,10 +42,22 @@ impl TemplatePreview { /// defines which profile to use for the render. Optionally provide content /// type to enable syntax highlighting, which will be applied to both /// unrendered and rendered content. + /// + /// ## Params + /// + /// - `template`: Template to render + /// - `content_type`: Content-Type of the output, which can be used to apply + /// syntax highlighting + /// - `overridden`: Has the template been overridden by the user in the + /// current session? Applies additional styling + /// - `can_stream`: Does this component of the recipe support streaming? If + /// so, the template will be rendered to a stream if possible and its + /// metadata will be displayed rather than the resolved value. pub fn new( template: Template, content_type: Option, overridden: bool, + can_stream: bool, ) -> Self { let tui_context = TuiContext::get(); let style = if overridden { @@ -83,6 +95,7 @@ impl TemplatePreview { ViewContext::send_message(Message::TemplatePreview { template, + can_stream, on_complete: Box::new(on_complete), }); } @@ -114,7 +127,7 @@ impl TemplatePreview { impl From