diff --git a/tonic/src/transport/channel/service/tls.rs b/tonic/src/transport/channel/service/tls.rs index 54abcbee3..bb94b8f88 100644 --- a/tonic/src/transport/channel/service/tls.rs +++ b/tonic/src/transport/channel/service/tls.rs @@ -1,5 +1,4 @@ use std::fmt; -use std::io::Cursor; use std::sync::Arc; use hyper_util::rt::TokioIo; @@ -10,7 +9,9 @@ use tokio_rustls::{ }; use super::io::BoxedIo; -use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2}; +use crate::transport::service::tls::{ + add_certificate_to_root_store, load_identity, TlsError, ALPN_H2, +}; use crate::transport::tls::{Certificate, Identity}; #[derive(Clone)] @@ -43,7 +44,7 @@ impl TlsConnector { } for cert in ca_certs { - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + add_certificate_to_root_store(cert, &mut roots)?; } let builder = builder.with_root_certificates(roots); diff --git a/tonic/src/transport/server/service/tls.rs b/tonic/src/transport/server/service/tls.rs index d7667b493..4f054c0e4 100644 --- a/tonic/src/transport/server/service/tls.rs +++ b/tonic/src/transport/server/service/tls.rs @@ -1,4 +1,4 @@ -use std::{fmt, io::Cursor, sync::Arc}; +use std::{fmt, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::{ @@ -9,7 +9,7 @@ use tokio_rustls::{ use crate::transport::{ server::Connected, - service::tls::{add_certs_from_pem, load_identity, ALPN_H2}, + service::tls::{add_certificate_to_root_store, load_identity, ALPN_H2}, Certificate, Identity, }; @@ -30,7 +30,7 @@ impl TlsAcceptor { None => builder.with_no_client_auth(), Some(cert) => { let mut roots = RootCertStore::empty(); - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + add_certificate_to_root_store(cert, &mut roots)?; let verifier = if client_auth_optional { WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated() } else { diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index ea7d1fd6b..2c06d7384 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -5,7 +5,7 @@ use tokio_rustls::rustls::{ RootCertStore, }; -use crate::transport::Identity; +use crate::transport::{tls::CertKind, Certificate, Identity}; /// h2 alpn in plain format for rustls. pub(crate) const ALPN_H2: &[u8] = b"h2"; @@ -34,12 +34,22 @@ impl fmt::Display for TlsError { impl std::error::Error for TlsError {} +fn convert_certificate_to_rustls_certificate_der( + certificate: Certificate, +) -> Result>, TlsError> { + let cert = match certificate.kind { + CertKind::Der(der) => vec![der.into()], + CertKind::Pem(pem) => rustls_pemfile::certs(&mut Cursor::new(pem)) + .collect::, _>>() + .map_err(|_| TlsError::CertificateParseError)?, + }; + Ok(cert) +} + pub(crate) fn load_identity( identity: Identity, ) -> Result<(Vec>, PrivateKeyDer<'static>), TlsError> { - let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert)) - .collect::, _>>() - .map_err(|_| TlsError::CertificateParseError)?; + let cert = convert_certificate_to_rustls_certificate_der(identity.cert)?; let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else { return Err(TlsError::PrivateKeyParseError); @@ -48,15 +58,14 @@ pub(crate) fn load_identity( Ok((cert, key)) } -pub(crate) fn add_certs_from_pem( - mut certs: &mut dyn std::io::BufRead, +pub(crate) fn add_certificate_to_root_store( + certificate: Certificate, roots: &mut RootCertStore, -) -> Result<(), crate::Error> { - for cert in rustls_pemfile::certs(&mut certs).collect::, _>>()? { +) -> Result<(), TlsError> { + for cert in convert_certificate_to_rustls_certificate_der(certificate)? { roots .add(cert) .map_err(|_| TlsError::CertificateParseError)?; } - Ok(()) } diff --git a/tonic/src/transport/tls.rs b/tonic/src/transport/tls.rs index c2b7ef23f..feef8aec3 100644 --- a/tonic/src/transport/tls.rs +++ b/tonic/src/transport/tls.rs @@ -1,7 +1,13 @@ /// Represents a X509 certificate. #[derive(Debug, Clone)] pub struct Certificate { - pub(crate) pem: Vec, + pub(super) kind: CertKind, +} + +#[derive(Debug, Clone)] +pub(super) enum CertKind { + Der(Vec), + Pem(Vec), } /// Represents a private key and X509 certificate. @@ -12,39 +18,68 @@ pub struct Identity { } impl Certificate { + fn new(kind: CertKind) -> Self { + Self { kind } + } + + /// Parse a DER encoded X509 Certificate. + /// + /// The provided DER should include at least one PEM encoded certificate. + pub fn from_der(der: impl AsRef<[u8]>) -> Self { + let der = der.as_ref().into(); + Self::new(CertKind::Der(der)) + } + /// Parse a PEM encoded X509 Certificate. /// /// The provided PEM should include at least one PEM encoded certificate. pub fn from_pem(pem: impl AsRef<[u8]>) -> Self { let pem = pem.as_ref().into(); - Self { pem } + Self::new(CertKind::Pem(pem)) } - /// Get a immutable reference to underlying certificate - pub fn get_ref(&self) -> &[u8] { - self.pem.as_slice() + /// Returns whether this is a DER encoded certificate. + pub fn is_der(&self) -> bool { + matches!(self.kind, CertKind::Der(_)) } - /// Get a mutable reference to underlying certificate - pub fn get_mut(&mut self) -> &mut [u8] { - self.pem.as_mut() + /// Returns whether this is a PEM encoded certificate. + pub fn is_pem(&self) -> bool { + matches!(self.kind, CertKind::Pem(_)) } - /// Consumes `self`, returning the underlying certificate - pub fn into_inner(self) -> Vec { - self.pem + /// Returns the reference to DER encoded certificate. + /// Returns `None` When this is not encoded as DER. + pub fn der(&self) -> Option<&[u8]> { + match &self.kind { + CertKind::Der(der) => Some(der), + _ => None, + } } -} -impl AsRef<[u8]> for Certificate { - fn as_ref(&self) -> &[u8] { - self.pem.as_ref() + /// Returns the reference to PEM encoded certificate. + /// Returns `None` When this is not encoded as PEM. + pub fn pem(&self) -> Option<&[u8]> { + match &self.kind { + CertKind::Pem(pem) => Some(pem), + _ => None, + } + } + + /// Turns this value into the DER encoded bytes. + pub fn into_der(self) -> Result, Self> { + match self.kind { + CertKind::Der(der) => Ok(der), + _ => Err(self), + } } -} -impl AsMut<[u8]> for Certificate { - fn as_mut(&mut self) -> &mut [u8] { - self.pem.as_mut() + /// Turns this value into the PEM encoded bytes. + pub fn into_pem(self) -> Result, Self> { + match self.kind { + CertKind::Pem(pem) => Ok(pem), + _ => Err(self), + } } }