Skip to content

Commit c5032a3

Browse files
fix(stackable-telemetry/webhook): Correctly extract connect and host info in Axum trace layer (#806)
* Add changelog entry * Update PR link in changelog * Add support for connect info * Add support for host info * Fix numeric fields, include host address and port * Revert Cargo.toml changes from b7c6118 * Add port and scheme to error message * Fix typo * Add quotes around field keys * chore: remove lifetimes, leverage Snafu's generated Into parameter --------- Co-authored-by: Nick Larsen <[email protected]>
1 parent cb5943f commit c5032a3

File tree

4 files changed

+143
-48
lines changed

4 files changed

+143
-48
lines changed

crates/stackable-telemetry/src/instrumentation/axum/mod.rs

+109-39
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,20 @@
1010
//!
1111
//! [1]: https://opentelemetry.io/
1212
//! [2]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/
13-
use std::{future::Future, net::SocketAddr, str::FromStr, task::Poll};
13+
use std::{future::Future, net::SocketAddr, num::ParseIntError, task::Poll};
1414

1515
use axum::{
16-
extract::{ConnectInfo, Host, MatchedPath, Request},
17-
http::{header::USER_AGENT, HeaderMap},
16+
extract::{ConnectInfo, MatchedPath, Request},
17+
http::{
18+
header::{HOST, USER_AGENT},
19+
HeaderMap,
20+
},
1821
response::Response,
1922
};
2023
use futures_util::ready;
2124
use opentelemetry::trace::SpanKind;
2225
use pin_project::pin_project;
26+
use snafu::{ResultExt, Snafu};
2327
use tower::{Layer, Service};
2428
use tracing::{debug, field::Empty, instrument, trace_span, Span};
2529
use tracing_opentelemetry::OpenTelemetrySpanExt;
@@ -30,6 +34,10 @@ mod injector;
3034
pub use extractor::*;
3135
pub use injector::*;
3236

37+
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
38+
const DEFAULT_HTTPS_PORT: u16 = 443;
39+
const DEFAULT_HTTP_PORT: u16 = 80;
40+
3341
/// A Tower [`Layer`][1] which decorates [`TraceService`].
3442
///
3543
/// ### Example with Axum
@@ -163,13 +171,25 @@ where
163171
}
164172
}
165173

174+
/// Errors which can occur while determining the server host and port of a
/// request via `RequestExt::server_host`.
#[derive(Debug, Snafu)]
175+
pub enum ServerHostError {
176+
/// The host value contained a `:port` suffix, but the port could not be
/// parsed as a `u16`. The offending port string is kept for the message.
#[snafu(display("failed to parse port {port:?} as u16 from string"))]
177+
ParsePort { source: ParseIntError, port: String },
178+
179+
/// The host value contained no port and the request scheme was neither
/// `http` nor `https`, so no default port could be chosen. `scheme` is
/// empty when the request URI carried no scheme at all.
#[snafu(display("encountered invalid request scheme {scheme:?}"))]
180+
InvalidScheme { scheme: String },
181+
182+
/// Neither the `X-Forwarded-Host` header, the `Host` header, nor the
/// request URI (host and port) provided usable host information.
#[snafu(display("failed to extract any host information from request"))]
183+
ExtractHost,
184+
}
185+
166186
/// This trait provides various helper functions to extract data from a
167187
/// HTTP [`Request`].
168188
pub trait RequestExt {
169189
/// Returns the client socket address, if available.
170190
fn client_socket_address(&self) -> Option<SocketAddr>;
171191

172-
/// Returns the server socket address, if available.
192+
/// Returns the server host and port, or an error if they cannot be determined.
173193
///
174194
/// ### Value Selection Strategy
175195
///
@@ -186,7 +206,7 @@ pub trait RequestExt {
186206
/// > - The Host header.
187207
///
188208
/// [1]: https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
189-
fn server_socket_address(&self) -> Option<SocketAddr>;
209+
fn server_host(&self) -> Result<(String, u16), ServerHostError>;
190210

191211
/// Returns the matched path, like `/object/:object_id/tags`.
192212
///
@@ -211,9 +231,33 @@ pub trait RequestExt {
211231
}
212232

213233
impl RequestExt for Request {
214-
fn server_socket_address(&self) -> Option<SocketAddr> {
215-
let host = self.extensions().get::<Host>()?;
216-
SocketAddr::from_str(&host.0).ok()
234+
fn server_host(&self) -> Result<(String, u16), ServerHostError> {
235+
// There is currently no obvious way to use the Host extractor from Axum
236+
// directly. Using that extractor would require impractical code (async
237+
// in the Service's call function, unnecessary cloning, or consuming self
238+
// and returning a newly created request). That's why the following
239+
// section mirrors the Axum extractor implementation. The implementation
240+
// currently only looks for the X-Forwarded-Host / Host header and falls
241+
// back to the request URI host. The Axum implementation also extracts
242+
// data from the Forwarded header.
243+
244+
if let Some(host) = self
245+
.headers()
246+
.get(X_FORWARDED_HOST_HEADER_KEY)
247+
.and_then(|host| host.to_str().ok())
248+
{
249+
return server_host_to_tuple(host, self.uri().scheme_str());
250+
}
251+
252+
if let Some(host) = self.headers().get(HOST).and_then(|host| host.to_str().ok()) {
253+
return server_host_to_tuple(host, self.uri().scheme_str());
254+
}
255+
256+
if let (Some(host), Some(port)) = (self.uri().host(), self.uri().port_u16()) {
257+
return Ok((host.to_owned(), port));
258+
}
259+
260+
ExtractHostSnafu.fail()
217261
}
218262

219263
fn client_socket_address(&self) -> Option<SocketAddr> {
@@ -242,6 +286,29 @@ impl RequestExt for Request {
242286
}
243287
}
244288

289+
fn server_host_to_tuple(
290+
host: &str,
291+
scheme: Option<&str>,
292+
) -> Result<(String, u16), ServerHostError> {
293+
if let Some((host, port)) = host.split_once(':') {
294+
// First, see if the host header value contains a colon indicating that
295+
// it includes a non-default port.
296+
let port: u16 = port.parse().context(ParsePortSnafu { port })?;
297+
Ok((host.to_owned(), port))
298+
} else {
299+
// If there is no port included in the header value, the port is implied.
300+
// Port 443 for HTTPS and port 80 for HTTP.
301+
let port = match scheme {
302+
Some("https") => DEFAULT_HTTPS_PORT,
303+
Some("http") => DEFAULT_HTTP_PORT,
304+
Some(scheme) => return InvalidSchemeSnafu { scheme }.fail(),
305+
_ => return InvalidSchemeSnafu { scheme: "" }.fail(),
306+
};
307+
308+
Ok((host.to_owned(), port))
309+
}
310+
}
311+
245312
/// This trait provides various helper functions to create a [`Span`] out of
246313
/// an HTTP [`Request`].
247314
pub trait SpanExt {
@@ -319,7 +386,7 @@ impl SpanExt for Span {
319386
// - https://github.com/tokio-rs/tracing/pull/732
320387
//
321388
// Additionally we cannot use consts for field names. There was an
322-
// upstream PR to add support for it, but it was unexpectingly closed.
389+
// upstream PR to add support for it, but it was unexpectedly closed.
323390
// See https://github.com/tokio-rs/tracing/pull/2254.
324391
//
325392
// If this is eventually supported (maybe with our efforts), we can use
@@ -332,22 +399,21 @@ impl SpanExt for Span {
332399
debug!("create http span");
333400
let span = trace_span!(
334401
"HTTP request",
335-
otel.name = span_name,
336-
otel.kind = ?SpanKind::Server,
337-
otel.status_code = Empty,
338-
otel.status_message = Empty,
339-
http.request.method = http_method,
340-
http.response.status_code = Empty,
341-
url.path = url.path(),
342-
url.query = url.query(),
343-
url.scheme = url.scheme_str().unwrap_or_default(),
344-
user_agent.original = Empty,
345-
server.address = Empty,
346-
server.port = Empty,
347-
client.address = Empty,
348-
client.port = Empty,
349-
http.route = Empty,
350-
http.response.status_code = Empty,
402+
"otel.name" = span_name,
403+
"otel.kind" = ?SpanKind::Server,
404+
"otel.status_code" = Empty,
405+
"otel.status_message" = Empty,
406+
"http.request.method" = http_method,
407+
"http.response.status_code" = Empty,
408+
"http.route" = Empty,
409+
"url.path" = url.path(),
410+
"url.query" = url.query(),
411+
"url.scheme" = url.scheme_str().unwrap_or_default(),
412+
"user_agent.original" = Empty,
413+
"server.address" = Empty,
414+
"server.port" = Empty,
415+
"client.address" = Empty,
416+
"client.port" = Empty,
351417
// TODO (@Techassi): Add network.protocol.version
352418
);
353419

@@ -363,9 +429,12 @@ impl SpanExt for Span {
363429
// Setting server.address and server.port
364430
// See https://opentelemetry.io/docs/specs/semconv/http/http-spans/#setting-serveraddress-and-serverport-attributes
365431

366-
if let Some(server_socket_address) = req.server_socket_address() {
367-
span.record("server.address", server_socket_address.ip().to_string())
368-
.record("server.port", server_socket_address.port());
432+
if let Ok((host, port)) = req.server_host() {
433+
// NOTE (@Techassi): We cast to i64, because otherwise the field
434+
// will NOT be recorded as a number but as a string. This is likely
435+
// an issue in the tracing-opentelemetry crate.
436+
span.record("server.address", host)
437+
.record("server.port", port as i64);
369438
}
370439

371440
// Setting fields according to the HTTP server semantic conventions
@@ -375,25 +444,22 @@ impl SpanExt for Span {
375444
span.record("client.address", client_socket_address.ip().to_string());
376445

377446
if opt_in {
378-
span.record("client.port", client_socket_address.port());
447+
// NOTE (@Techassi): We cast to i64, because otherwise the field
448+
// will NOT be recorded as a number but as a string. This is
449+
// likely an issue in the tracing-opentelemetry crate.
450+
span.record("client.port", client_socket_address.port() as i64);
379451
}
380452
}
381453

382454
// Only include the headers if the user opted in, because this might
383455
// potentially be an expensive operation when many different headers
384456
// are present. The OpenTelemetry spec also marks this as opt-in.
385457

386-
// NOTE (@Techassi): Currently, tracing doesn't support recording
387-
// fields which are not registered at span creation which thus makes
388-
// it impossible to record request headers at runtime.
458+
// FIXME (@Techassi): Currently, tracing doesn't support recording
459+
// fields which are not registered at span creation which thus makes it
460+
// impossible to record request headers at runtime.
389461
// See: https://github.com/tokio-rs/tracing/issues/1343
390462

391-
// FIXME (@Techassi): Add support for this when tracing allows dynamic
392-
// fields.
393-
// if opt_in {
394-
// span.add_header_fields(req.headers())
395-
// }
396-
397463
if let Some(http_route) = req.matched_path() {
398464
span.record("http.route", http_route.as_str());
399465
}
@@ -420,7 +486,11 @@ impl SpanExt for Span {
420486

421487
fn finalize_with_response(&self, response: &mut Response) {
422488
let status_code = response.status();
423-
self.record("http.response.status_code", status_code.as_u16());
489+
490+
// NOTE (@Techassi): We cast to i64, because otherwise the field will
491+
// NOT be recorded as a number but as a string. This is likely an issue
492+
// in the tracing-opentelemetry crate.
493+
self.record("http.response.status_code", status_code.as_u16() as i64);
424494

425495
// Only set the span status to "Error" when we encountered a server
426496
// error. See:

crates/stackable-webhook/CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@ All notable changes to this project will be documented in this file.
44

55
## [Unreleased]
66

7+
### Fixed
8+
9+
- Fix the extraction of `ConnectInfo` (data about the connection client) and
10+
the `Host` info (data about the server) in the `AxumTraceLayer`. This was
11+
previously not extracted correctly and thus not included in the OpenTelemetry
12+
compatible traces ([#806]).
13+
14+
[#806]: https://github.com/stackabletech/operator-rs/pull/806
15+
716
## [0.3.0] - 2024-05-08
817

918
### Added

crates/stackable-webhook/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ impl WebhookServer {
173173
async fn run_server(self) -> Result<()> {
174174
debug!("run webhook server");
175175

176+
// TODO (@Techassi): Make opt-in configurable from the outside
176177
// Create an OpenTelemetry tracing layer
177178
debug!("create tracing service (layer)");
178179
let trace_layer = AxumTraceLayer::new().with_opt_in();

crates/stackable-webhook/src/tls.rs

+24-9
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use tokio_rustls::{
1818
},
1919
TlsAcceptor,
2020
};
21-
use tower::Service;
21+
use tower::{Service, ServiceExt};
2222
use tracing::{instrument, trace, warn};
2323

2424
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -139,10 +139,24 @@ impl TlsServer {
139139
socket_addr: self.socket_addr,
140140
})?;
141141

142+
// To be able to extract the connect info from incoming requests, it is
143+
// required to turn the router into a Tower service which is capable of
144+
// doing that. Calling `into_make_service_with_connect_info` returns a
145+
// new struct `IntoMakeServiceWithConnectInfo` which implements the
146+
// Tower Service trait. This service is called after the TCP connection
147+
// has been accepted.
148+
//
149+
// Inspired by:
150+
// - https://github.com/tokio-rs/axum/discussions/2397
151+
// - https://github.com/tokio-rs/axum/blob/b02ce307371a973039018a13fa012af14775948c/examples/serve-with-hyper/src/main.rs#L98
152+
153+
let mut router = self
154+
.router
155+
.into_make_service_with_connect_info::<SocketAddr>();
156+
142157
pin_mut!(tcp_listener);
143158
loop {
144159
let tls_acceptor = tls_acceptor.clone();
145-
let router = self.router.clone();
146160

147161
// Wait for new tcp connection
148162
let (tcp_stream, remote_addr) = match tcp_listener.accept().await {
@@ -153,6 +167,10 @@ impl TlsServer {
153167
}
154168
};
155169

170+
// Here, the connect info is extracted by calling Tower's Service
171+
// trait function on `IntoMakeServiceWithConnectInfo`
172+
let tower_service = router.call(remote_addr).await.unwrap();
173+
156174
tokio::spawn(async move {
157175
// Wait for tls handshake to happen
158176
let Ok(tls_stream) = tls_acceptor.accept(tcp_stream).await else {
@@ -167,16 +185,13 @@ impl TlsServer {
167185
// Hyper also has its own `Service` trait and doesn't use tower. We can use
168186
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
169187
// `tower::Service::call`.
170-
let service = service_fn(move |request: Request<Incoming>| {
171-
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
172-
// tower's `Service` requires `&mut self`.
173-
//
174-
// We don't need to call `poll_ready` since `Router` is always ready.
175-
router.clone().call(request)
188+
let hyper_service = service_fn(move |request: Request<Incoming>| {
189+
// We need to clone here, because oneshot consumes self
190+
tower_service.clone().oneshot(request)
176191
});
177192

178193
if let Err(err) = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
179-
.serve_connection_with_upgrades(tls_stream, service)
194+
.serve_connection_with_upgrades(tls_stream, hyper_service)
180195
.await
181196
{
182197
warn!(%err, %remote_addr, "failed to serve connection");

0 commit comments

Comments
 (0)