diff --git a/srtp/src/cipher/cipher_aead_aes_gcm.rs b/srtp/src/cipher/cipher_aead_aes_gcm.rs index cc880694b..fe9cbb961 100644 --- a/srtp/src/cipher/cipher_aead_aes_gcm.rs +++ b/srtp/src/cipher/cipher_aead_aes_gcm.rs @@ -1,6 +1,10 @@ +use std::marker::PhantomData; + +use aead::consts::{U12, U16}; +use aes::cipher::{BlockEncrypt, BlockSizeUser, Unsigned}; use aes_gcm::aead::generic_array::GenericArray; use aes_gcm::aead::{Aead, Payload}; -use aes_gcm::{Aes128Gcm, KeyInit, Nonce}; +use aes_gcm::{AesGcm, KeyInit, Nonce}; use byteorder::{BigEndian, ByteOrder}; use bytes::{Bytes, BytesMut}; use util::marshal::*; @@ -8,22 +12,43 @@ use util::marshal::*; use super::Cipher; use crate::error::{Error, Result}; use crate::key_derivation::*; +use crate::protection_profile::ProtectionProfile; pub const CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN: usize = 16; const RTCP_ENCRYPTION_FLAG: u8 = 0x80; /// AEAD Cipher based on AES. -pub(crate) struct CipherAeadAesGcm { - srtp_cipher: aes_gcm::Aes128Gcm, - srtcp_cipher: aes_gcm::Aes128Gcm, +pub(crate) struct CipherAeadAesGcm +where + NonceSize: Unsigned, +{ + profile: ProtectionProfile, + srtp_cipher: aes_gcm::AesGcm, + srtcp_cipher: aes_gcm::AesGcm, srtp_session_salt: Vec, srtcp_session_salt: Vec, + _tag: PhantomData, } -impl Cipher for CipherAeadAesGcm { - fn auth_tag_len(&self) -> usize { - CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN +impl Cipher for CipherAeadAesGcm +where + NS: Unsigned, + AES: BlockEncrypt + KeyInit + BlockSizeUser + 'static, + AesGcm: Aead, +{ + fn rtp_auth_tag_len(&self) -> usize { + self.profile.rtp_auth_tag_len() + } + + /// Get RTCP authenticated tag length. + fn rtcp_auth_tag_len(&self) -> usize { + self.profile.rtcp_auth_tag_len() + } + + /// Get AEAD auth key length of the cipher. + fn aead_auth_tag_len(&self) -> usize { + self.profile.aead_auth_tag_len() } fn encrypt_rtp( @@ -34,7 +59,7 @@ impl Cipher for CipherAeadAesGcm { ) -> Result { // Grow the given buffer to fit the output. let header_len = header.marshal_size(); - let mut writer = BytesMut::with_capacity(payload.len() + self.auth_tag_len()); + let mut writer = BytesMut::with_capacity(payload.len() + self.aead_auth_tag_len()); // Copy header unencrypted. writer.extend_from_slice(&payload[..header_len]); @@ -59,7 +84,7 @@ impl Cipher for CipherAeadAesGcm { header: &rtp::header::Header, roc: u32, ) -> Result { - if ciphertext.len() < self.auth_tag_len() { + if ciphertext.len() < self.aead_auth_tag_len() { return Err(Error::ErrFailedToVerifyAuthTag); } @@ -101,7 +126,7 @@ impl Cipher for CipherAeadAesGcm { } fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { - if encrypted.len() < self.auth_tag_len() + SRTCP_INDEX_SIZE { + if encrypted.len() < self.aead_auth_tag_len() + SRTCP_INDEX_SIZE { return Err(Error::ErrFailedToVerifyAuthTag); } @@ -131,10 +156,31 @@ impl Cipher for CipherAeadAesGcm { } } -impl CipherAeadAesGcm { +impl CipherAeadAesGcm +where + NS: Unsigned, + AES: BlockEncrypt + KeyInit + BlockSizeUser + 'static, + AesGcm: Aead, +{ /// Create a new AEAD instance. - pub(crate) fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let srtp_session_key = aes_cm_key_derivation( + pub(crate) fn new( + profile: ProtectionProfile, + master_key: &[u8], + master_salt: &[u8], + ) -> Result> { + assert_eq!(profile.aead_auth_tag_len(), AES::block_size()); + assert_eq!(profile.key_len(), AES::key_size()); + assert_eq!(profile.salt_len(), master_salt.len()); + + type Kdf = fn(u8, &[u8], &[u8], usize, usize) -> Result>; + let kdf: Kdf = match profile { + ProtectionProfile::AeadAes128Gcm => aes_cm_key_derivation, + // AES_256_GCM must use AES_256_CM_PRF as per https://datatracker.ietf.org/doc/html/rfc7714#section-11 + ProtectionProfile::AeadAes256Gcm => aes_256_cm_key_derivation, + _ => unreachable!(), + }; + + let srtp_session_key = kdf( LABEL_SRTP_ENCRYPTION, master_key, master_salt, @@ -144,9 +190,9 @@ impl CipherAeadAesGcm { let srtp_block = GenericArray::from_slice(&srtp_session_key); - let srtp_cipher = Aes128Gcm::new(srtp_block); + let srtp_cipher = AesGcm::::new(srtp_block); - let srtcp_session_key = aes_cm_key_derivation( + let srtcp_session_key = kdf( LABEL_SRTCP_ENCRYPTION, master_key, master_salt, @@ -156,29 +202,31 @@ impl CipherAeadAesGcm { let srtcp_block = GenericArray::from_slice(&srtcp_session_key); - let srtcp_cipher = Aes128Gcm::new(srtcp_block); + let srtcp_cipher = AesGcm::::new(srtcp_block); - let srtp_session_salt = aes_cm_key_derivation( + let srtp_session_salt = kdf( LABEL_SRTP_SALT, master_key, master_salt, 0, - master_key.len(), + master_salt.len(), )?; - let srtcp_session_salt = aes_cm_key_derivation( + let srtcp_session_salt = kdf( LABEL_SRTCP_SALT, master_key, master_salt, 0, - master_key.len(), + master_salt.len(), )?; Ok(CipherAeadAesGcm { + profile, srtp_cipher, srtcp_cipher, srtp_session_salt, srtcp_session_salt, + _tag: PhantomData, }) } @@ -245,3 +293,52 @@ impl CipherAeadAesGcm { aad } } + +#[cfg(test)] +mod tests { + use aes::{Aes128, Aes256}; + + use super::*; + + #[test] + fn test_aead_aes_gcm_128() { + let profile = ProtectionProfile::AeadAes128Gcm; + let master_key = vec![0u8; profile.key_len()]; + let master_salt = vec![0u8; 12]; + + let mut cipher = + CipherAeadAesGcm::::new(profile, &master_key, &master_salt).unwrap(); + + let header = rtp::header::Header { + ssrc: 0x12345678, + ..Default::default() + }; + + let payload = vec![0u8; 100]; + let encrypted = cipher.encrypt_rtp(&payload, &header, 0).unwrap(); + + let decrypted = cipher.decrypt_rtp(&encrypted, &header, 0).unwrap(); + assert_eq!(&decrypted[..], &payload[..]); + } + + #[test] + fn test_aead_aes_gcm_256() { + let profile = ProtectionProfile::AeadAes256Gcm; + let master_key = vec![0u8; profile.key_len()]; + let master_salt = vec![0u8; 12]; + + let mut cipher = + CipherAeadAesGcm::::new(profile, &master_key, &master_salt).unwrap(); + + let header = rtp::header::Header { + ssrc: 0x12345678, + ..Default::default() + }; + + let payload = vec![0u8; 100]; + let encrypted = cipher.encrypt_rtp(&payload, &header, 0).unwrap(); + + let decrypted = cipher.decrypt_rtp(&encrypted, &header, 0).unwrap(); + assert_eq!(&decrypted[..], &payload[..]); + } +} diff --git a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs index 38ebb3088..f351ef64c 100644 --- a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs +++ b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs @@ -8,6 +8,7 @@ use util::marshal::*; use super::{Cipher, CipherInner}; use crate::error::{Error, Result}; use crate::key_derivation::*; +use crate::protection_profile::ProtectionProfile; type Aes128Ctr = ctr::Ctr128BE; @@ -18,8 +19,8 @@ pub(crate) struct CipherAesCmHmacSha1 { } impl CipherAesCmHmacSha1 { - pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let inner = CipherInner::new(master_key, master_salt)?; + pub fn new(profile: ProtectionProfile, master_key: &[u8], master_salt: &[u8]) -> Result { + let inner = CipherInner::new(profile, master_key, master_salt)?; let srtp_session_key = aes_cm_key_derivation( LABEL_SRTP_ENCRYPTION, @@ -45,8 +46,19 @@ impl CipherAesCmHmacSha1 { } impl Cipher for CipherAesCmHmacSha1 { - fn auth_tag_len(&self) -> usize { - self.inner.auth_tag_len() + /// Get RTP authenticated tag length. + fn rtp_auth_tag_len(&self) -> usize { + self.inner.profile.rtp_auth_tag_len() + } + + /// Get RTCP authenticated tag length. + fn rtcp_auth_tag_len(&self) -> usize { + self.inner.profile.rtcp_auth_tag_len() + } + + /// Get AEAD auth key length of the cipher. + fn aead_auth_tag_len(&self) -> usize { + self.inner.profile.aead_auth_tag_len() } fn get_rtcp_index(&self, input: &[u8]) -> usize { @@ -59,7 +71,7 @@ impl Cipher for CipherAesCmHmacSha1 { header: &rtp::header::Header, roc: u32, ) -> Result { - let mut writer = Vec::with_capacity(plaintext.len() + self.auth_tag_len()); + let mut writer = Vec::with_capacity(plaintext.len() + self.rtp_auth_tag_len()); // Write the plaintext to the destination buffer. writer.extend_from_slice(plaintext); @@ -77,7 +89,7 @@ impl Cipher for CipherAesCmHmacSha1 { stream.apply_keystream(&mut writer[header.marshal_size()..]); // Generate the auth tag. - let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.auth_tag_len()]; + let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.rtp_auth_tag_len()]; writer.extend(auth_tag); Ok(Bytes::from(writer)) @@ -90,19 +102,19 @@ impl Cipher for CipherAesCmHmacSha1 { roc: u32, ) -> Result { let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() { - return Err(Error::SrtpTooSmall(encrypted_len, self.auth_tag_len())); + if encrypted_len < self.rtp_auth_tag_len() { + return Err(Error::SrtpTooSmall(encrypted_len, self.rtp_auth_tag_len())); } - let mut writer = Vec::with_capacity(encrypted_len - self.auth_tag_len()); + let mut writer = Vec::with_capacity(encrypted_len - self.rtp_auth_tag_len()); // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; + let actual_tag = &encrypted[encrypted_len - self.rtp_auth_tag_len()..]; + let cipher_text = &encrypted[..encrypted_len - self.rtp_auth_tag_len()]; // Generate the auth tag we expect to see from the ciphertext. let expected_tag = - &self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.auth_tag_len()]; + &self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.rtp_auth_tag_len()]; // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. @@ -132,7 +144,7 @@ impl Cipher for CipherAesCmHmacSha1 { fn encrypt_rtcp(&mut self, decrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { let mut writer = - Vec::with_capacity(decrypted.len() + SRTCP_INDEX_SIZE + self.auth_tag_len()); + Vec::with_capacity(decrypted.len() + SRTCP_INDEX_SIZE + self.rtcp_auth_tag_len()); // Write the decrypted to the destination buffer. writer.extend_from_slice(decrypted); @@ -155,7 +167,7 @@ impl Cipher for CipherAesCmHmacSha1 { writer.put_u32(srtcp_index as u32 | (1u32 << 31)); // Generate the auth tag. - let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.auth_tag_len()]; + let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.rtcp_auth_tag_len()]; writer.extend(auth_tag); Ok(Bytes::from(writer)) @@ -163,14 +175,14 @@ impl Cipher for CipherAesCmHmacSha1 { fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() + SRTCP_INDEX_SIZE { + if encrypted_len < self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE { return Err(Error::SrtcpTooSmall( encrypted_len, - self.auth_tag_len() + SRTCP_INDEX_SIZE, + self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE, )); } - let tail_offset = encrypted_len - (self.auth_tag_len() + SRTCP_INDEX_SIZE); + let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE); if tail_offset < 8 { return Err(Error::ErrTooShortRtcp); } @@ -185,18 +197,19 @@ impl Cipher for CipherAesCmHmacSha1 { } // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - if actual_tag.len() != self.auth_tag_len() { + let actual_tag = &encrypted[encrypted_len - self.rtcp_auth_tag_len()..]; + if actual_tag.len() != self.rtcp_auth_tag_len() { return Err(Error::RtcpInvalidLengthAuthTag( actual_tag.len(), - self.auth_tag_len(), + self.rtcp_auth_tag_len(), )); } - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; + let cipher_text = &encrypted[..encrypted_len - self.rtcp_auth_tag_len()]; // Generate the auth tag we expect to see from the ciphertext. - let expected_tag = &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.auth_tag_len()]; + let expected_tag = + &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.rtcp_auth_tag_len()]; // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. diff --git a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs index 00bae63ae..d36fb3477 100644 --- a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs +++ b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/mod.rs @@ -24,6 +24,7 @@ type HmacSha1 = Hmac; pub const CIPHER_AES_CM_HMAC_SHA1AUTH_TAG_LEN: usize = 10; pub(crate) struct CipherInner { + profile: ProtectionProfile, srtp_session_salt: Vec, srtp_session_auth: HmacSha1, srtcp_session_salt: Vec, @@ -31,7 +32,7 @@ pub(crate) struct CipherInner { } impl CipherInner { - pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result { + pub fn new(profile: ProtectionProfile, master_key: &[u8], master_salt: &[u8]) -> Result { let srtp_session_salt = aes_cm_key_derivation( LABEL_SRTP_SALT, master_key, @@ -70,6 +71,7 @@ impl CipherInner { .map_err(|e| Error::Other(e.to_string()))?; Ok(Self { + profile, srtp_session_salt, srtp_session_auth, srtcp_session_salt, @@ -121,12 +123,8 @@ impl CipherInner { signer.finalize().into_bytes().into() } - fn auth_tag_len(&self) -> usize { - CIPHER_AES_CM_HMAC_SHA1AUTH_TAG_LEN - } - fn get_rtcp_index(&self, input: &[u8]) -> usize { - let tail_offset = input.len() - (self.auth_tag_len() + SRTCP_INDEX_SIZE); + let tail_offset = input.len() - (self.profile.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE); (BigEndian::read_u32(&input[tail_offset..tail_offset + SRTCP_INDEX_SIZE]) & !(1 << 31)) as usize } diff --git a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs index 9b6ef27ad..48e36ddbf 100644 --- a/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs +++ b/srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs @@ -5,6 +5,7 @@ use subtle::ConstantTimeEq; use util::marshal::*; use super::{Cipher, CipherInner}; +use crate::protection_profile::ProtectionProfile; use crate::{ error::{Error, Result}, key_derivation::*, @@ -17,8 +18,8 @@ pub(crate) struct CipherAesCmHmacSha1 { } impl CipherAesCmHmacSha1 { - pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result { - let inner = CipherInner::new(master_key, master_salt)?; + pub fn new(profile: ProtectionProfile, master_key: &[u8], master_salt: &[u8]) -> Result { + let inner = CipherInner::new(profile, master_key, master_salt)?; let srtp_session_key = aes_cm_key_derivation( LABEL_SRTP_ENCRYPTION, @@ -56,8 +57,19 @@ impl CipherAesCmHmacSha1 { } impl Cipher for CipherAesCmHmacSha1 { - fn auth_tag_len(&self) -> usize { - self.inner.auth_tag_len() + /// Get RTP authenticated tag length. + fn rtp_auth_tag_len(&self) -> usize { + self.inner.profile.rtp_auth_tag_len() + } + + /// Get RTCP authenticated tag length. + fn rtcp_auth_tag_len(&self) -> usize { + self.inner.profile.rtcp_auth_tag_len() + } + + /// Get AEAD auth key length of the cipher. + fn aead_auth_tag_len(&self) -> usize { + self.inner.profile.aead_auth_tag_len() } fn get_rtcp_index(&self, input: &[u8]) -> usize { @@ -71,7 +83,7 @@ impl Cipher for CipherAesCmHmacSha1 { roc: u32, ) -> Result { let header_len = header.marshal_size(); - let mut writer = Vec::with_capacity(plaintext.len() + self.auth_tag_len()); + let mut writer = Vec::with_capacity(plaintext.len() + self.rtp_auth_tag_len()); // Copy the header unencrypted. writer.extend_from_slice(&plaintext[..header_len]); @@ -94,7 +106,7 @@ impl Cipher for CipherAesCmHmacSha1 { .unwrap(); // Generate and write the auth tag. - let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.auth_tag_len()]; + let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.rtp_auth_tag_len()]; writer.extend_from_slice(auth_tag); Ok(Bytes::from(writer)) @@ -107,20 +119,20 @@ impl Cipher for CipherAesCmHmacSha1 { roc: u32, ) -> Result { let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() { - return Err(Error::SrtpTooSmall(encrypted_len, self.auth_tag_len())); + if encrypted_len < self.rtp_auth_tag_len() { + return Err(Error::SrtpTooSmall(encrypted_len, self.rtp_auth_tag_len())); } let header_len = header.marshal_size(); - let mut writer = Vec::with_capacity(encrypted_len - self.auth_tag_len()); + let mut writer = Vec::with_capacity(encrypted_len - self.rtp_auth_tag_len()); // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; + let actual_tag = &encrypted[encrypted_len - self.rtp_auth_tag_len()..]; + let cipher_text = &encrypted[..encrypted_len - self.rtp_auth_tag_len()]; // Generate the auth tag we expect to see from the ciphertext. let expected_tag = - &self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.auth_tag_len()]; + &self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.rtp_auth_tag_len()]; // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. @@ -139,7 +151,7 @@ impl Cipher for CipherAesCmHmacSha1 { &self.inner.srtp_session_salt, ); - writer.resize(encrypted_len - self.auth_tag_len(), 0); + writer.resize(encrypted_len - self.rtp_auth_tag_len(), 0); self.rtp_ctx.decrypt_init(None, None, Some(&nonce)).unwrap(); let count = self .rtp_ctx @@ -155,7 +167,8 @@ impl Cipher for CipherAesCmHmacSha1 { fn encrypt_rtcp(&mut self, decrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { let decrypted_len = decrypted.len(); - let mut writer = Vec::with_capacity(decrypted_len + SRTCP_INDEX_SIZE + self.auth_tag_len()); + let mut writer = + Vec::with_capacity(decrypted_len + SRTCP_INDEX_SIZE + self.rtcp_auth_tag_len()); // Write the decrypted to the destination buffer. writer.extend_from_slice(&decrypted[..HEADER_LENGTH + SSRC_LENGTH]); @@ -187,7 +200,7 @@ impl Cipher for CipherAesCmHmacSha1 { writer.put_u32(srtcp_index as u32 | (1u32 << 31)); // Generate the auth tag. - let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.auth_tag_len()]; + let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.rtcp_auth_tag_len()]; writer.extend(auth_tag); Ok(Bytes::from(writer)) @@ -196,14 +209,14 @@ impl Cipher for CipherAesCmHmacSha1 { fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result { let encrypted_len = encrypted.len(); - if encrypted_len < self.auth_tag_len() + SRTCP_INDEX_SIZE { + if encrypted_len < self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE { return Err(Error::SrtcpTooSmall( encrypted_len, - self.auth_tag_len() + SRTCP_INDEX_SIZE, + self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE, )); } - let tail_offset = encrypted_len - (self.auth_tag_len() + SRTCP_INDEX_SIZE); + let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE); if tail_offset < 8 { return Err(Error::ErrTooShortRtcp); } @@ -218,18 +231,19 @@ impl Cipher for CipherAesCmHmacSha1 { } // Split the auth tag and the cipher text into two parts. - let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..]; - if actual_tag.len() != self.auth_tag_len() { + let actual_tag = &encrypted[encrypted_len - self.rtcp_auth_tag_len()..]; + if actual_tag.len() != self.rtcp_auth_tag_len() { return Err(Error::RtcpInvalidLengthAuthTag( actual_tag.len(), - self.auth_tag_len(), + self.rtcp_auth_tag_len(), )); } - let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()]; + let cipher_text = &encrypted[..encrypted_len - self.rtcp_auth_tag_len()]; // Generate the auth tag we expect to see from the ciphertext. - let expected_tag = &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.auth_tag_len()]; + let expected_tag = + &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.rtcp_auth_tag_len()]; // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. diff --git a/srtp/src/cipher/mod.rs b/srtp/src/cipher/mod.rs index 181d7328d..351a8fecf 100644 --- a/srtp/src/cipher/mod.rs +++ b/srtp/src/cipher/mod.rs @@ -31,8 +31,14 @@ use crate::error::Result; /// Cipher represents a implementation of one /// of the SRTP Specific ciphers. pub(crate) trait Cipher { - /// Get authenticated tag length. - fn auth_tag_len(&self) -> usize; + /// Get RTP authenticated tag length. + fn rtp_auth_tag_len(&self) -> usize; + + /// Get RTCP authenticated tag length. + fn rtcp_auth_tag_len(&self) -> usize; + + /// Get AEAD auth key length of the cipher. + fn aead_auth_tag_len(&self) -> usize; /// Retrieved RTCP index. fn get_rtcp_index(&self, input: &[u8]) -> usize; diff --git a/srtp/src/context/mod.rs b/srtp/src/context/mod.rs index 045aaaec4..3c8f79a96 100644 --- a/srtp/src/context/mod.rs +++ b/srtp/src/context/mod.rs @@ -7,6 +7,8 @@ mod srtp_test; use std::collections::HashMap; +use aes::Aes128; +use aes::Aes256; use util::replay_detector::*; use crate::cipher::cipher_aead_aes_gcm::*; @@ -119,13 +121,21 @@ impl Context { } let cipher: Box = match profile { - ProtectionProfile::Aes128CmHmacSha1_80 => { - Box::new(CipherAesCmHmacSha1::new(master_key, master_salt)?) + ProtectionProfile::Aes128CmHmacSha1_32 | ProtectionProfile::Aes128CmHmacSha1_80 => { + Box::new(CipherAesCmHmacSha1::new(profile, master_key, master_salt)?) } - ProtectionProfile::AeadAes128Gcm => { - Box::new(CipherAeadAesGcm::new(master_key, master_salt)?) - } + ProtectionProfile::AeadAes128Gcm => Box::new(CipherAeadAesGcm::::new( + profile, + master_key, + master_salt, + )?), + + ProtectionProfile::AeadAes256Gcm => Box::new(CipherAeadAesGcm::::new( + profile, + master_key, + master_salt, + )?), }; let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt { diff --git a/srtp/src/context/srtcp_test.rs b/srtp/src/context/srtcp_test.rs index 6f76abee9..c7389bf24 100644 --- a/srtp/src/context/srtcp_test.rs +++ b/srtp/src/context/srtcp_test.rs @@ -130,7 +130,7 @@ fn test_rtcp_lifecycle() -> Result<()> { #[test] fn test_rtcp_invalid_auth_tag() -> Result<()> { - let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); + let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.rtcp_auth_tag_len(); let mut decrypt_context = Context::new( &RTCP_TEST_MASTER_KEY, @@ -217,7 +217,7 @@ fn test_encrypt_rtcp_separation() -> Result<()> { None, )?; - let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); + let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.rtcp_auth_tag_len(); let mut decrypt_context = Context::new( &RTCP_TEST_MASTER_KEY, diff --git a/srtp/src/context/srtp.rs b/srtp/src/context/srtp.rs index 810fb5643..2ed7a259f 100644 --- a/srtp/src/context/srtp.rs +++ b/srtp/src/context/srtp.rs @@ -10,7 +10,7 @@ impl Context { encrypted: &[u8], header: &rtp::header::Header, ) -> Result { - let auth_tag_len = self.cipher.auth_tag_len(); + let auth_tag_len = self.cipher.rtp_auth_tag_len(); if encrypted.len() < header.marshal_size() + auth_tag_len { return Err(Error::ErrTooShortRtp); } diff --git a/srtp/src/context/srtp_test.rs b/srtp/src/context/srtp_test.rs index 79390aa9f..c9cfe0d52 100644 --- a/srtp/src/context/srtp_test.rs +++ b/srtp/src/context/srtp_test.rs @@ -121,7 +121,7 @@ fn test_rtp_invalid_auth() -> Result<()> { fn test_rtp_lifecycle() -> Result<()> { let mut encrypt_context = build_test_context()?; let mut decrypt_context = build_test_context()?; - let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); + let auth_tag_len = ProtectionProfile::Aes128CmHmacSha1_80.rtp_auth_tag_len(); for test_case in RTP_TEST_CASES.iter() { let decrypted_pkt = rtp::packet::Packet { diff --git a/srtp/src/error.rs b/srtp/src/error.rs index da39333d3..bb7e1ee0f 100644 --- a/srtp/src/error.rs +++ b/srtp/src/error.rs @@ -48,6 +48,12 @@ pub enum Error { #[error("index_over_kdr > 0 is not supported yet")] UnsupportedIndexOverKdr, + #[error("invalid master key length for aes_256_cm")] + InvalidMasterKeyLength, + #[error("invalid master salt length for aes_256_cm")] + InvalidMasterSaltLength, + #[error("out_len > 32 is not supported for aes_256_cm")] + UnsupportedOutLength, #[error("SRTP Master Key must be len {0}, got {1}")] SrtpMasterKeyLength(usize, usize), #[error("SRTP Salt must be len {0}, got {1}")] diff --git a/srtp/src/key_derivation.rs b/srtp/src/key_derivation.rs index 343450937..4eaf8b222 100644 --- a/srtp/src/key_derivation.rs +++ b/srtp/src/key_derivation.rs @@ -1,6 +1,6 @@ -use aes::cipher::generic_array::GenericArray; use aes::cipher::BlockEncrypt; -use aes::Aes128; +use aes::Aes256; +use aes::{cipher::generic_array::GenericArray, Aes128}; use aes_gcm::KeyInit; use crate::error::{Error, Result}; @@ -50,7 +50,58 @@ pub(crate) fn aes_cm_key_derivation( prf_in[n_master_key - 1] = (i & 0xFF) as u8; out[n..n + n_master_key].copy_from_slice(&prf_in); - let out_key = GenericArray::from_mut_slice(&mut out[n..n + n_master_key]); + let out_key = GenericArray::from_mut_slice(&mut out[n..n + 16]); + block.encrypt_block(out_key); + } + + Ok(out[..out_len].to_vec()) +} + +// As per https://datatracker.ietf.org/doc/html/rfc6188 +// The key derivation rate is zero as per https://datatracker.ietf.org/doc/html/rfc5764 hence index_over-kdr is 0 +const AES_256_BS: usize = 16; +pub(crate) fn aes_256_cm_key_derivation( + label: u8, + master_key: &[u8], + master_salt: &[u8], + index_over_kdr: usize, + out_len: usize, +) -> Result> { + if index_over_kdr != 0 { + // 24-bit "index DIV kdr" must be xored to prf input. + return Err(Error::UnsupportedIndexOverKdr); + } + + if master_key.len() != 32 { + return Err(Error::InvalidMasterKeyLength); + } + + if master_salt.len() > 14 { + return Err(Error::InvalidMasterSaltLength); + } + + if out_len > 32 { + return Err(Error::UnsupportedOutLength); + } + + let mut prf_in = [0; AES_256_BS]; + prf_in[7] = label; + prf_in[8..12].copy_from_slice((index_over_kdr as u32).to_be_bytes().as_slice()); + + for (i, x) in prf_in.iter_mut().enumerate() { + *x ^= master_salt.get(i).unwrap_or(&0); + } + + //The resulting value is then AES encrypted using the master key to get the cipher key. + let key = GenericArray::from_slice(master_key); + let block = Aes256::new(key); + + let mut out = vec![0u8; ((out_len + AES_256_BS) / AES_256_BS) * AES_256_BS]; + for (i, n) in (0..out_len).step_by(AES_256_BS).enumerate() { + prf_in[AES_256_BS - 2..].copy_from_slice(&((i as u16).to_be_bytes())); + + out[n..n + AES_256_BS].copy_from_slice(&prf_in); + let out_key = GenericArray::from_mut_slice(&mut out[n..n + 16]); block.encrypt_block(out_key); } @@ -170,4 +221,70 @@ mod test { Ok(()) } + + #[test] + fn test_aes_256_cm_key_derivation() -> Result<()> { + // Key Derivation Test Vectors from https://datatracker.ietf.org/doc/html/rfc6188#section-7.2 + let master_key = vec![ + 0xF0, 0xF0, 0x49, 0x14, 0xB5, 0x13, 0xF2, 0x76, 0x3A, 0x1B, 0x1F, 0xA1, 0x30, 0xF1, + 0x0E, 0x29, 0x98, 0xF6, 0xF6, 0xE4, 0x3E, 0x43, 0x09, 0xD1, 0xE6, 0x22, 0xA0, 0xE3, + 0x32, 0xB9, 0xF1, 0xB6, + ]; + let master_salt = vec![ + 0x3B, 0x04, 0x80, 0x3D, 0xE5, 0x1E, 0xE7, 0xC9, 0x64, 0x23, 0xAB, 0x5B, 0x78, 0xD2, + ]; + + let expected_session_key = vec![ + 0x5B, 0xA1, 0x06, 0x4E, 0x30, 0xEC, 0x51, 0x61, 0x3C, 0xAD, 0x92, 0x6C, 0x5A, 0x28, + 0xEF, 0x73, 0x1E, 0xC7, 0xFB, 0x39, 0x7F, 0x70, 0xA9, 0x60, 0x65, 0x3C, 0xAF, 0x06, + 0x55, 0x4C, 0xD8, 0xC4, + ]; + let expected_session_salt = vec![ + 0xFA, 0x31, 0x79, 0x16, 0x85, 0xCA, 0x44, 0x4A, 0x9E, 0x07, 0xC6, 0xC6, 0x4E, 0x93, + ]; + let expected_session_auth_tag = vec![ + 0xFD, 0x9C, 0x32, 0xD3, 0x9E, 0xD5, 0xFB, 0xB5, 0xA9, 0xDC, 0x96, 0xB3, 0x08, 0x18, + 0x45, 0x4D, 0x13, 0x13, 0xDC, 0x05, + ]; + + let session_key = aes_256_cm_key_derivation( + LABEL_SRTP_ENCRYPTION, + &master_key, + &master_salt, + 0, + master_key.len(), + )?; + assert_eq!( + session_key, expected_session_key, + "Session Key:\n{session_key:?} \ndoes not match expected:\n{expected_session_key:?}\nMaster Key:\n{master_key:?}\nMaster Salt:\n{master_salt:?}\n", + ); + + let session_salt = aes_256_cm_key_derivation( + LABEL_SRTP_SALT, + &master_key, + &master_salt, + 0, + master_salt.len(), + )?; + assert_eq!( + session_salt, expected_session_salt, + "Session Salt {session_salt:?} does not match expected {expected_session_salt:?}" + ); + + let auth_key_len = ProtectionProfile::Aes128CmHmacSha1_80.auth_key_len(); + + let session_auth_tag = aes_256_cm_key_derivation( + LABEL_SRTP_AUTHENTICATION_TAG, + &master_key, + &master_salt, + 0, + auth_key_len, + )?; + assert_eq!( + session_auth_tag, expected_session_auth_tag, + "Session Auth Tag {session_auth_tag:?} does not match expected {expected_session_auth_tag:?}", + ); + + Ok(()) + } } diff --git a/srtp/src/protection_profile.rs b/srtp/src/protection_profile.rs index aad2dfa63..51b5e5b49 100644 --- a/srtp/src/protection_profile.rs +++ b/srtp/src/protection_profile.rs @@ -4,34 +4,54 @@ pub enum ProtectionProfile { #[default] Aes128CmHmacSha1_80 = 0x0001, + Aes128CmHmacSha1_32 = 0x0002, AeadAes128Gcm = 0x0007, + AeadAes256Gcm = 0x0008, } impl ProtectionProfile { pub fn key_len(&self) -> usize { match *self { - ProtectionProfile::Aes128CmHmacSha1_80 | ProtectionProfile::AeadAes128Gcm => 16, + ProtectionProfile::Aes128CmHmacSha1_32 + | ProtectionProfile::Aes128CmHmacSha1_80 + | ProtectionProfile::AeadAes128Gcm => 16, + ProtectionProfile::AeadAes256Gcm => 32, } } pub fn salt_len(&self) -> usize { match *self { - ProtectionProfile::Aes128CmHmacSha1_80 => 14, - ProtectionProfile::AeadAes128Gcm => 12, + ProtectionProfile::Aes128CmHmacSha1_32 | ProtectionProfile::Aes128CmHmacSha1_80 => 14, + ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => 12, } } - pub fn auth_tag_len(&self) -> usize { + pub fn rtp_auth_tag_len(&self) -> usize { match *self { - ProtectionProfile::Aes128CmHmacSha1_80 => 10, //CIPHER_AES_CM_HMAC_SHA1AUTH_TAG_LEN, - ProtectionProfile::AeadAes128Gcm => 16, //CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN, + ProtectionProfile::Aes128CmHmacSha1_80 => 10, + ProtectionProfile::Aes128CmHmacSha1_32 => 4, + ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => 0, + } + } + + pub fn rtcp_auth_tag_len(&self) -> usize { + match *self { + ProtectionProfile::Aes128CmHmacSha1_80 | ProtectionProfile::Aes128CmHmacSha1_32 => 10, + ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => 0, + } + } + + pub fn aead_auth_tag_len(&self) -> usize { + match *self { + ProtectionProfile::Aes128CmHmacSha1_80 | ProtectionProfile::Aes128CmHmacSha1_32 => 0, + ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => 16, } } pub fn auth_key_len(&self) -> usize { match *self { - ProtectionProfile::Aes128CmHmacSha1_80 => 20, - ProtectionProfile::AeadAes128Gcm => 0, + ProtectionProfile::Aes128CmHmacSha1_80 | ProtectionProfile::Aes128CmHmacSha1_32 => 20, + ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => 0, } } } diff --git a/srtp/src/session/session_rtcp_test.rs b/srtp/src/session/session_rtcp_test.rs index 23db4325a..79a097fd9 100644 --- a/srtp/src/session/session_rtcp_test.rs +++ b/srtp/src/session/session_rtcp_test.rs @@ -153,7 +153,7 @@ fn encrypt_srtcp( const PLI_PACKET_SIZE: usize = 8; async fn get_sender_ssrc(read_stream: &Arc) -> Result { - let auth_tag_size = ProtectionProfile::Aes128CmHmacSha1_80.auth_tag_len(); + let auth_tag_size = ProtectionProfile::Aes128CmHmacSha1_80.rtcp_auth_tag_len(); let mut read_buffer = BytesMut::with_capacity(PLI_PACKET_SIZE + auth_tag_size); read_buffer.resize(PLI_PACKET_SIZE + auth_tag_size, 0u8); diff --git a/webrtc/src/dtls_transport/mod.rs b/webrtc/src/dtls_transport/mod.rs index 7e07f6316..96f5c2a1f 100644 --- a/webrtc/src/dtls_transport/mod.rs +++ b/webrtc/src/dtls_transport/mod.rs @@ -44,7 +44,9 @@ pub mod dtls_transport_state; pub(crate) fn default_srtp_protection_profiles() -> Vec { vec![ SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm, + SrtpProtectionProfile::Srtp_Aead_Aes_256_Gcm, SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, + SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32, ] } @@ -414,9 +416,15 @@ impl RTCDtlsTransport { dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm => { srtp::protection_profile::ProtectionProfile::AeadAes128Gcm } + dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aead_Aes_256_Gcm => { + srtp::protection_profile::ProtectionProfile::AeadAes256Gcm + } dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80 => { srtp::protection_profile::ProtectionProfile::Aes128CmHmacSha1_80 } + dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_32 => { + srtp::protection_profile::ProtectionProfile::Aes128CmHmacSha1_32 + } _ => { if let Err(err) = dtls_conn.close().await { log::error!("{}", err);