Skip to content

Commit

Permalink
feat(bindings): expose context on cert chain (#5132)
Browse files Browse the repository at this point in the history
Co-authored-by: Sam Clark <[email protected]>
  • Loading branch information
jmayclin and goatgoose authored Feb 26, 2025
1 parent ac1d098 commit 711ee0d
Showing 1 changed file with 160 additions and 4 deletions.
164 changes: 160 additions & 4 deletions bindings/rust/extended/s2n-tls/src/cert_chain.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::error::{Error, Fallible};
use crate::error::{Error, ErrorType, Fallible};
use s2n_tls_sys::*;
use std::{
any::Any,
ffi::c_void,
marker::PhantomData,
ptr::{self, NonNull},
sync::Arc,
Expand All @@ -13,6 +15,7 @@ use std::{
///
/// [CertificateChain] is internally reference counted. The reference counted `T`
/// must have a drop implementation.
#[derive(Debug)]
pub(crate) struct CertificateChainHandle<'a> {
pub(crate) cert: NonNull<s2n_cert_chain_and_key>,
is_owned: bool,
Expand Down Expand Up @@ -45,20 +48,57 @@ impl CertificateChainHandle<'_> {
_lifetime: PhantomData,
}
}

/// Corresponds to [s2n_cert_chain_and_key_get_ctx].
fn context_mut(&mut self) -> Option<&mut Context> {
let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
if context.is_null() {
None
} else {
Some(unsafe { &mut *(context as *mut Context) })
}
}

/// Corresponds to [s2n_cert_chain_and_key_get_ctx].
fn context(&self) -> Option<&Context> {
let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) };
if context.is_null() {
None
} else {
Some(unsafe { &*(context as *const Context) })
}
}
}

impl Drop for CertificateChainHandle<'_> {
/// Corresponds to [s2n_cert_chain_and_key_free].
fn drop(&mut self) {
// ignore failures since there's not much we can do about it
if self.is_owned {
if let Some(internal_context) = self.context_mut() {
drop(unsafe { Box::from_raw(internal_context) });
}
// ignore failures since there's not much we can do about it
unsafe {
// null the cert chain context out of an abundance of caution
let _ = s2n_cert_chain_and_key_set_ctx(self.cert.as_ptr(), std::ptr::null_mut())
.into_result();

let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result();
}
}
}
}

/// An internal container to hold the customer supplied application context.
///
/// We can't directly store the application context on the `s2n_cert_chain_and_key`,
/// because `*mut dyn Any` is a fat pointer (16 bytes) and can not be stored as
/// a c_void (8 bytes).
struct Context {
application_context: Box<dyn Any + Send + Sync>,
}

#[derive(Debug)]
pub struct Builder {
cert_handle: CertificateChainHandle<'static>,
}
Expand Down Expand Up @@ -125,6 +165,39 @@ impl Builder {
Ok(self)
}

/// Associates an arbitrary application context with the CertificateChain to
/// be later retrieved via [`CertificateChain::application_context()`].
///
/// This API will override an existing application context set on the Builder.
///
/// Corresponds to [s2n_cert_chain_and_key_set_ctx].
pub fn set_application_context<T: Send + Sync + 'static>(
&mut self,
app_context: T,
) -> Result<&mut Self, Error> {
match self.cert_handle.context_mut() {
Some(_) => Err(Error::bindings(
ErrorType::UsageError,
"cert builder error",
"set_application_context can only be called once",
)),
None => {
let app_context = Box::new(app_context);
let internal_context = Box::new(Context {
application_context: app_context,
});
unsafe {
s2n_cert_chain_and_key_set_ctx(
self.cert_handle.cert.as_ptr(),
Box::into_raw(internal_context) as *mut c_void,
)
.into_result()
}?;
Ok(self)
}
}
}

/// Return an immutable, internally-reference counted CertificateChain.
pub fn build(self) -> Result<CertificateChain<'static>, Error> {
// This method is currently infallible, but returning a result allows
Expand Down Expand Up @@ -177,6 +250,23 @@ impl CertificateChain<'_> {
}
}

/// Retrieves a reference to the application context associated with the
/// CertificateChain.
///
/// If an application context hasn't been set on the CertificateChain or if
/// the set application context isn't of type `T`, `None` will be returned.
///
/// To set a context on the connection, use [`Builder::set_application_context()`].
///
/// Corresponds to [s2n_cert_chain_and_key_get_ctx].
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
if let Some(internal_context) = self.cert_handle.context() {
internal_context.application_context.downcast_ref()
} else {
None
}
}

/// Return the length of this certificate chain.
///
/// Note that the underlying API currently traverses a linked list, so this is a relatively
Expand Down Expand Up @@ -273,9 +363,12 @@ unsafe impl Send for Certificate<'_> {}
mod tests {
use crate::{
config,
error::{ErrorSource, ErrorType},
error::{Error as S2NError, ErrorSource, ErrorType},
security::DEFAULT_TLS13,
testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair},
testing::{
config_builder, CertKeyPair, InsecureAcceptAllCertificatesHandler, SniTestCerts,
TestPair,
},
};

use super::*;
Expand Down Expand Up @@ -495,4 +588,67 @@ mod tests {
fn assert_send_sync<T: 'static + Send + Sync>() {}
assert_send_sync::<CertificateChain<'static>>();
}

/// sanity check for basic cert chain context interactions
#[test]
fn application_context_workflow() -> Result<(), S2NError> {
let context: Arc<u64> = Arc::new(0xC0FFEE);
let handle = Arc::clone(&context);
assert_eq!(Arc::strong_count(&handle), 2);

let default = CertKeyPair::default();
let mut chain = Builder::new()?;
chain.load_pem(default.cert(), default.key())?;
chain.set_application_context(context)?;
let chain = chain.build()?;

let invalid_type_get = chain.application_context::<u64>();
assert!(invalid_type_get.is_none());

let retrieved_context = chain.application_context::<Arc<u64>>().unwrap();
assert_eq!(*retrieved_context.as_ref(), 0xC0FFEE);
assert_eq!(Arc::strong_count(&handle), 2);
drop(chain);
assert_eq!(Arc::strong_count(&handle), 1);
Ok(())
}

/// When an application context is overridden, it should be error.
#[test]
fn application_context_override() -> Result<(), S2NError> {
let initial: Arc<u64> = Arc::new(0xC0FFEE);
let overridden: Arc<[u8; 6]> = Arc::new(*b"coffee");

let mut builder = Builder::new()?;
builder.set_application_context(initial)?;
let err = builder.set_application_context(overridden).unwrap_err();
assert_eq!(err.kind(), ErrorType::UsageError);

Ok(())
}

/// An application context should be retrievable from a selected cert after
/// the handshake.
#[test]
fn application_context_from_selected_cert() -> Result<(), S2NError> {
let default = CertKeyPair::default();
let mut chain = Builder::new()?;
chain.load_pem(default.cert(), default.key())?;
chain.set_application_context(0xC0FFEE_u64)?;

let mut server_config = config::Builder::new();
server_config.load_chain(chain.build()?)?;

let client_config = config_builder(&crate::security::DEFAULT).unwrap();

let mut test_pair =
TestPair::from_configs(&client_config.build()?, &server_config.build()?);
test_pair.handshake()?;

let selected_cert = test_pair.server.selected_cert().unwrap();
let context = selected_cert.application_context::<u64>();
assert_eq!(context, Some(&0xC0FFEE_u64));

Ok(())
}
}

0 comments on commit 711ee0d

Please sign in to comment.