Skip to content

Commit 92fb84c

Browse files
committed
test: add request-ID isolation in concurrent exec
Adds test to ensure request IDs are properly isolated in concurrent Lambda execution with structured logging. Verifies each concurrent invocation maintains unique request ID context.
1 parent 217ecaf commit 92fb84c

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed

lambda-runtime/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ idna_adapter = "=1.2.0"
7373
lambda_runtime = { path = ".", features = ["tracing", "graceful-shutdown"] }
7474
pin-project-lite = { workspace = true }
7575
tracing-appender = "0.2"
76+
tracing-subscriber = { version = "0.3", features = ["registry"] }
7677

7778
[package.metadata.docs.rs]
7879
all-features = true

lambda-runtime/src/runtime.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,4 +908,165 @@ mod endpoint_tests {
908908
server_handle.abort();
909909
Ok(())
910910
}
911+
912+
#[tokio::test]
913+
#[cfg(feature = "experimental-concurrency")]
914+
async fn test_concurrent_structured_logging_isolation() -> Result<(), Error> {
915+
use std::collections::{HashMap, HashSet};
916+
use std::sync::Mutex;
917+
use tracing::{info, subscriber::set_global_default};
918+
use tracing_subscriber::{layer::SubscriberExt, Layer};
919+
920+
#[derive(Clone)]
921+
struct LogCapture {
922+
logs: Arc<Mutex<Vec<HashMap<String, String>>>>,
923+
}
924+
925+
impl LogCapture {
926+
fn new() -> Self {
927+
Self {
928+
logs: Arc::new(Mutex::new(Vec::new())),
929+
}
930+
}
931+
}
932+
933+
impl<S> Layer<S> for LogCapture
934+
where
935+
S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
936+
{
937+
fn on_event(&self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) {
938+
let mut fields = HashMap::new();
939+
struct FieldVisitor<'a>(&'a mut HashMap<String, String>);
940+
impl<'a> tracing::field::Visit for FieldVisitor<'a> {
941+
fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
942+
self.0.insert(field.name().to_string(), format!("{:?}", value).trim_matches('"').to_string());
943+
}
944+
}
945+
event.record(&mut FieldVisitor(&mut fields));
946+
self.logs.lock().unwrap().push(fields);
947+
}
948+
}
949+
950+
let log_capture = LogCapture::new();
951+
let subscriber = tracing_subscriber::registry().with(log_capture.clone());
952+
set_global_default(subscriber).unwrap();
953+
954+
let request_count = Arc::new(AtomicUsize::new(0));
955+
let listener = TcpListener::bind("127.0.0.1:0").await?;
956+
let addr = listener.local_addr()?;
957+
let base: http::Uri = format!("http://{addr}").parse()?;
958+
959+
let server_handle = {
960+
let request_count = request_count.clone();
961+
tokio::spawn(async move {
962+
loop {
963+
let (tcp, _) = match listener.accept().await {
964+
Ok(v) => v,
965+
Err(_) => return,
966+
};
967+
968+
let request_count = request_count.clone();
969+
let service = service_fn(move |req: Request<Incoming>| {
970+
let request_count = request_count.clone();
971+
async move {
972+
let (parts, body) = req.into_parts();
973+
if parts.method == Method::POST {
974+
let _ = body.collect().await;
975+
}
976+
977+
if parts.method == Method::GET && parts.uri.path() == "/2018-06-01/runtime/invocation/next" {
978+
let count = request_count.fetch_add(1, Ordering::SeqCst);
979+
if count < 300 {
980+
let request_id = format!("test-request-{}", count + 1);
981+
let res = Response::builder()
982+
.status(StatusCode::OK)
983+
.header("lambda-runtime-aws-request-id", &request_id)
984+
.header("lambda-runtime-deadline-ms", "9999999999999")
985+
.body(Full::new(Bytes::from_static(b"{}")))
986+
.unwrap();
987+
return Ok::<_, Infallible>(res);
988+
} else {
989+
let res = Response::builder()
990+
.status(StatusCode::NO_CONTENT)
991+
.body(Full::new(Bytes::new()))
992+
.unwrap();
993+
return Ok::<_, Infallible>(res);
994+
}
995+
}
996+
997+
if parts.method == Method::POST && parts.uri.path().contains("/response") {
998+
let res = Response::builder()
999+
.status(StatusCode::OK)
1000+
.body(Full::new(Bytes::new()))
1001+
.unwrap();
1002+
return Ok::<_, Infallible>(res);
1003+
}
1004+
1005+
let res = Response::builder()
1006+
.status(StatusCode::NOT_FOUND)
1007+
.body(Full::new(Bytes::new()))
1008+
.unwrap();
1009+
Ok::<_, Infallible>(res)
1010+
}
1011+
});
1012+
1013+
let io = TokioIo::new(tcp);
1014+
tokio::spawn(async move {
1015+
let _ = ServerBuilder::new(TokioExecutor::new()).serve_connection(io, service).await;
1016+
});
1017+
}
1018+
})
1019+
};
1020+
1021+
async fn test_handler(event: crate::LambdaEvent<serde_json::Value>) -> Result<(), Error> {
1022+
let request_id = &event.context.request_id;
1023+
info!(observed_request_id = request_id);
1024+
tokio::time::sleep(Duration::from_millis(100)).await;
1025+
Ok(())
1026+
}
1027+
1028+
let handler = crate::service_fn(test_handler);
1029+
let client = Arc::new(Client::builder().with_endpoint(base).build()?);
1030+
let runtime = Runtime {
1031+
client: client.clone(),
1032+
config: Arc::new(Config {
1033+
function_name: "test_fn".to_string(),
1034+
memory: 128,
1035+
version: "1".to_string(),
1036+
log_stream: "test_stream".to_string(),
1037+
log_group: "test_log".to_string(),
1038+
}),
1039+
service: wrap_handler(handler, client),
1040+
concurrency_limit: 3,
1041+
};
1042+
1043+
let runtime_handle = tokio::spawn(async move { runtime.run_concurrent().await });
1044+
1045+
loop {
1046+
tokio::time::sleep(Duration::from_millis(100)).await;
1047+
let count = request_count.load(Ordering::SeqCst);
1048+
if count >= 300 {
1049+
tokio::time::sleep(Duration::from_millis(500)).await;
1050+
break;
1051+
}
1052+
}
1053+
1054+
runtime_handle.abort();
1055+
server_handle.abort();
1056+
1057+
let logs = log_capture.logs.lock().unwrap();
1058+
let relevant_logs: Vec<_> = logs.iter().filter(|l| l.contains_key("observed_request_id")).collect();
1059+
1060+
assert!(relevant_logs.len() >= 300, "Should have at least 300 log entries, got {}", relevant_logs.len());
1061+
1062+
let mut seen_ids = HashSet::new();
1063+
for log in &relevant_logs {
1064+
let observed_id = log.get("observed_request_id").unwrap();
1065+
assert!(observed_id.starts_with("test-request-"), "Request ID should match pattern: {}", observed_id);
1066+
assert!(seen_ids.insert(observed_id.clone()), "Request ID should be unique: {}", observed_id);
1067+
}
1068+
1069+
println!("✅ Concurrent structured logging test passed with {} unique request IDs", seen_ids.len());
1070+
Ok(())
1071+
}
9111072
}

0 commit comments

Comments
 (0)