Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lambda-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ idna_adapter = "=1.2.0"
lambda_runtime = { path = ".", features = ["tracing", "graceful-shutdown"] }
pin-project-lite = { workspace = true }
tracing-appender = "0.2"
tracing-capture = "0.1.0"
tracing-subscriber = { version = "0.3", features = ["registry"] }

[package.metadata.docs.rs]
all-features = true
172 changes: 172 additions & 0 deletions lambda-runtime/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,4 +908,176 @@ mod endpoint_tests {
server_handle.abort();
Ok(())
}

#[tokio::test]
#[cfg(feature = "experimental-concurrency")]
async fn test_concurrent_structured_logging_isolation() -> Result<(), Error> {
use std::collections::HashSet;
use tracing::info;
use tracing_capture::{CaptureLayer, SharedStorage};
use tracing_subscriber::layer::SubscriberExt;

let storage = SharedStorage::default();
let subscriber = tracing_subscriber::registry().with(CaptureLayer::new(&storage));
tracing::subscriber::set_global_default(subscriber).unwrap();

let request_count = Arc::new(AtomicUsize::new(0));
let done = Arc::new(tokio::sync::Notify::new());
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let base: http::Uri = format!("http://{addr}").parse()?;

let server_handle = {
let request_count = request_count.clone();
let done = done.clone();
tokio::spawn(async move {
loop {
let (tcp, _) = match listener.accept().await {
Ok(v) => v,
Err(_) => return,
};

let request_count = request_count.clone();
let done = done.clone();
let service = service_fn(move |req: Request<Incoming>| {
let request_count = request_count.clone();
let done = done.clone();
async move {
let (parts, body) = req.into_parts();
if parts.method == Method::POST {
let _ = body.collect().await;
}

if parts.method == Method::GET && parts.uri.path() == "/2018-06-01/runtime/invocation/next"
{
let count = request_count.fetch_add(1, Ordering::SeqCst);
if count < 300 {
let request_id = format!("test-request-{}", count + 1);
let res = Response::builder()
.status(StatusCode::OK)
.header("lambda-runtime-aws-request-id", &request_id)
.header("lambda-runtime-deadline-ms", "9999999999999")
.body(Full::new(Bytes::from_static(b"{}")))
.unwrap();
return Ok::<_, Infallible>(res);
} else {
done.notify_one();
let res = Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Full::new(Bytes::new()))
.unwrap();
return Ok::<_, Infallible>(res);
}
}

if parts.method == Method::POST && parts.uri.path().contains("/response") {
let res = Response::builder()
.status(StatusCode::OK)
.body(Full::new(Bytes::new()))
.unwrap();
return Ok::<_, Infallible>(res);
}

let res = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Full::new(Bytes::new()))
.unwrap();
Ok::<_, Infallible>(res)
}
});

let io = TokioIo::new(tcp);
tokio::spawn(async move {
let _ = ServerBuilder::new(TokioExecutor::new())
.serve_connection(io, service)
.await;
});
}
})
};

async fn test_handler(event: crate::LambdaEvent<serde_json::Value>) -> Result<(), Error> {
let request_id = &event.context.request_id;
info!(observed_request_id = request_id);
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(())
}

let handler = crate::service_fn(test_handler);
let client = Arc::new(Client::builder().with_endpoint(base).build()?);

// Add tracing layer to capture span fields
use crate::layers::trace::TracingLayer;
use tower::ServiceBuilder;
let service = ServiceBuilder::new()
.layer(TracingLayer::new())
.service(wrap_handler(handler, client.clone()));

let runtime = Runtime {
client: client.clone(),
config: Arc::new(Config {
function_name: "test_fn".to_string(),
memory: 128,
version: "1".to_string(),
log_stream: "test_stream".to_string(),
log_group: "test_log".to_string(),
}),
service,
concurrency_limit: 3,
};

let runtime_handle = tokio::spawn(async move { runtime.run_concurrent().await });

done.notified().await;
// Give handlers time to complete after server signals done
tokio::time::sleep(Duration::from_millis(500)).await;

runtime_handle.abort();
server_handle.abort();

let storage = storage.lock();
let events: Vec<_> = storage
.all_events()
.filter(|e| e.value("observed_request_id").is_some())
.collect();

assert!(
events.len() >= 300,
"Should have at least 300 log entries, got {}",
events.len()
);

let mut seen_ids = HashSet::new();
for event in &events {
let observed_id = event["observed_request_id"].as_str().unwrap();

// Find the parent "Lambda runtime invoke" span and get its requestId
let span_request_id = event
.ancestors()
.find(|s| s.metadata().name() == "Lambda runtime invoke")
.and_then(|s| s.value("requestId"))
.and_then(|v| v.as_str())
.expect("Event should have a Lambda runtime invoke ancestor with requestId");

assert!(
observed_id.starts_with("test-request-"),
"Request ID should match pattern: {}",
observed_id
);
assert!(
seen_ids.insert(observed_id.to_string()),
"Request ID should be unique: {}",
observed_id
);

// Verify span request ID matches logged request ID
assert_eq!(
observed_id, span_request_id,
"Span request ID should match logged request ID: span={}, logged={}",
span_request_id, observed_id
);
}

Ok(())
}
}