Skip to content

Make SRTP AES_256_GCM actually work #677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 117 additions & 20 deletions srtp/src/cipher/cipher_aead_aes_gcm.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,54 @@
use std::marker::PhantomData;

use aead::consts::{U12, U16};
use aes::cipher::{BlockEncrypt, BlockSizeUser, Unsigned};
use aes_gcm::aead::generic_array::GenericArray;
use aes_gcm::aead::{Aead, Payload};
use aes_gcm::{Aes128Gcm, KeyInit, Nonce};
use aes_gcm::{AesGcm, KeyInit, Nonce};
use byteorder::{BigEndian, ByteOrder};
use bytes::{Bytes, BytesMut};
use util::marshal::*;

use super::Cipher;
use crate::error::{Error, Result};
use crate::key_derivation::*;
use crate::protection_profile::ProtectionProfile;

pub const CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN: usize = 16;

const RTCP_ENCRYPTION_FLAG: u8 = 0x80;

/// AEAD Cipher based on AES.
pub(crate) struct CipherAeadAesGcm {
srtp_cipher: aes_gcm::Aes128Gcm,
srtcp_cipher: aes_gcm::Aes128Gcm,
pub(crate) struct CipherAeadAesGcm<AES, NonceSize = U12>
where
NonceSize: Unsigned,
{
profile: ProtectionProfile,
srtp_cipher: aes_gcm::AesGcm<AES, NonceSize>,
srtcp_cipher: aes_gcm::AesGcm<AES, NonceSize>,
srtp_session_salt: Vec<u8>,
srtcp_session_salt: Vec<u8>,
_tag: PhantomData<AES>,
}

impl Cipher for CipherAeadAesGcm {
fn auth_tag_len(&self) -> usize {
CIPHER_AEAD_AES_GCM_AUTH_TAG_LEN
impl<AES, NS> Cipher for CipherAeadAesGcm<AES, NS>
where
NS: Unsigned,
AES: BlockEncrypt + KeyInit + BlockSizeUser<BlockSize = U16> + 'static,
AesGcm<AES, NS>: Aead,
{
fn rtp_auth_tag_len(&self) -> usize {
self.profile.rtp_auth_tag_len()
}

/// Get RTCP authenticated tag length.
fn rtcp_auth_tag_len(&self) -> usize {
self.profile.rtcp_auth_tag_len()
}

/// Get AEAD auth key length of the cipher.
fn aead_auth_tag_len(&self) -> usize {
self.profile.aead_auth_tag_len()
}

fn encrypt_rtp(
Expand All @@ -34,7 +59,7 @@ impl Cipher for CipherAeadAesGcm {
) -> Result<Bytes> {
// Grow the given buffer to fit the output.
let header_len = header.marshal_size();
let mut writer = BytesMut::with_capacity(payload.len() + self.auth_tag_len());
let mut writer = BytesMut::with_capacity(payload.len() + self.aead_auth_tag_len());

// Copy header unencrypted.
writer.extend_from_slice(&payload[..header_len]);
Expand All @@ -59,7 +84,7 @@ impl Cipher for CipherAeadAesGcm {
header: &rtp::header::Header,
roc: u32,
) -> Result<Bytes> {
if ciphertext.len() < self.auth_tag_len() {
if ciphertext.len() < self.aead_auth_tag_len() {
return Err(Error::ErrFailedToVerifyAuthTag);
}

Expand Down Expand Up @@ -101,7 +126,7 @@ impl Cipher for CipherAeadAesGcm {
}

fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result<Bytes> {
if encrypted.len() < self.auth_tag_len() + SRTCP_INDEX_SIZE {
if encrypted.len() < self.aead_auth_tag_len() + SRTCP_INDEX_SIZE {
return Err(Error::ErrFailedToVerifyAuthTag);
}

Expand Down Expand Up @@ -131,10 +156,31 @@ impl Cipher for CipherAeadAesGcm {
}
}

impl CipherAeadAesGcm {
impl<AES, NS> CipherAeadAesGcm<AES, NS>
where
NS: Unsigned,
AES: BlockEncrypt + KeyInit + BlockSizeUser<BlockSize = U16> + 'static,
AesGcm<AES, NS>: Aead,
{
/// Create a new AEAD instance.
pub(crate) fn new(master_key: &[u8], master_salt: &[u8]) -> Result<CipherAeadAesGcm> {
let srtp_session_key = aes_cm_key_derivation(
pub(crate) fn new(
profile: ProtectionProfile,
master_key: &[u8],
master_salt: &[u8],
) -> Result<CipherAeadAesGcm<AES>> {
assert_eq!(profile.aead_auth_tag_len(), AES::block_size());
assert_eq!(profile.key_len(), AES::key_size());
assert_eq!(profile.salt_len(), master_salt.len());

type Kdf = fn(u8, &[u8], &[u8], usize, usize) -> Result<Vec<u8>>;
let kdf: Kdf = match profile {
ProtectionProfile::AeadAes128Gcm => aes_cm_key_derivation,
// AES_256_GCM must use AES_256_CM_PRF as per https://datatracker.ietf.org/doc/html/rfc7714#section-11
ProtectionProfile::AeadAes256Gcm => aes_256_cm_key_derivation,
_ => unreachable!(),
};

let srtp_session_key = kdf(
LABEL_SRTP_ENCRYPTION,
master_key,
master_salt,
Expand All @@ -144,9 +190,9 @@ impl CipherAeadAesGcm {

let srtp_block = GenericArray::from_slice(&srtp_session_key);

let srtp_cipher = Aes128Gcm::new(srtp_block);
let srtp_cipher = AesGcm::<AES, U12>::new(srtp_block);

let srtcp_session_key = aes_cm_key_derivation(
let srtcp_session_key = kdf(
LABEL_SRTCP_ENCRYPTION,
master_key,
master_salt,
Expand All @@ -156,29 +202,31 @@ impl CipherAeadAesGcm {

let srtcp_block = GenericArray::from_slice(&srtcp_session_key);

let srtcp_cipher = Aes128Gcm::new(srtcp_block);
let srtcp_cipher = AesGcm::<AES, U12>::new(srtcp_block);

let srtp_session_salt = aes_cm_key_derivation(
let srtp_session_salt = kdf(
LABEL_SRTP_SALT,
master_key,
master_salt,
0,
master_key.len(),
master_salt.len(),
)?;

let srtcp_session_salt = aes_cm_key_derivation(
let srtcp_session_salt = kdf(
LABEL_SRTCP_SALT,
master_key,
master_salt,
0,
master_key.len(),
master_salt.len(),
)?;

Ok(CipherAeadAesGcm {
profile,
srtp_cipher,
srtcp_cipher,
srtp_session_salt,
srtcp_session_salt,
_tag: PhantomData,
})
}

Expand Down Expand Up @@ -245,3 +293,52 @@ impl CipherAeadAesGcm {
aad
}
}

#[cfg(test)]
mod tests {
use aes::{Aes128, Aes256};

use super::*;

#[test]
fn test_aead_aes_gcm_128() {
let profile = ProtectionProfile::AeadAes128Gcm;
let master_key = vec![0u8; profile.key_len()];
let master_salt = vec![0u8; 12];

let mut cipher =
CipherAeadAesGcm::<Aes128>::new(profile, &master_key, &master_salt).unwrap();

let header = rtp::header::Header {
ssrc: 0x12345678,
..Default::default()
};

let payload = vec![0u8; 100];
let encrypted = cipher.encrypt_rtp(&payload, &header, 0).unwrap();

let decrypted = cipher.decrypt_rtp(&encrypted, &header, 0).unwrap();
assert_eq!(&decrypted[..], &payload[..]);
}

#[test]
fn test_aead_aes_gcm_256() {
let profile = ProtectionProfile::AeadAes256Gcm;
let master_key = vec![0u8; profile.key_len()];
let master_salt = vec![0u8; 12];

let mut cipher =
CipherAeadAesGcm::<Aes256>::new(profile, &master_key, &master_salt).unwrap();

let header = rtp::header::Header {
ssrc: 0x12345678,
..Default::default()
};

let payload = vec![0u8; 100];
let encrypted = cipher.encrypt_rtp(&payload, &header, 0).unwrap();

let decrypted = cipher.decrypt_rtp(&encrypted, &header, 0).unwrap();
assert_eq!(&decrypted[..], &payload[..]);
}
}
57 changes: 35 additions & 22 deletions srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use util::marshal::*;
use super::{Cipher, CipherInner};
use crate::error::{Error, Result};
use crate::key_derivation::*;
use crate::protection_profile::ProtectionProfile;

type Aes128Ctr = ctr::Ctr128BE<aes::Aes128>;

Expand All @@ -18,8 +19,8 @@ pub(crate) struct CipherAesCmHmacSha1 {
}

impl CipherAesCmHmacSha1 {
pub fn new(master_key: &[u8], master_salt: &[u8]) -> Result<Self> {
let inner = CipherInner::new(master_key, master_salt)?;
pub fn new(profile: ProtectionProfile, master_key: &[u8], master_salt: &[u8]) -> Result<Self> {
let inner = CipherInner::new(profile, master_key, master_salt)?;

let srtp_session_key = aes_cm_key_derivation(
LABEL_SRTP_ENCRYPTION,
Expand All @@ -45,8 +46,19 @@ impl CipherAesCmHmacSha1 {
}

impl Cipher for CipherAesCmHmacSha1 {
fn auth_tag_len(&self) -> usize {
self.inner.auth_tag_len()
/// Get RTP authenticated tag length.
fn rtp_auth_tag_len(&self) -> usize {
self.inner.profile.rtp_auth_tag_len()
}

/// Get RTCP authenticated tag length.
fn rtcp_auth_tag_len(&self) -> usize {
self.inner.profile.rtcp_auth_tag_len()
}

/// Get AEAD auth key length of the cipher.
fn aead_auth_tag_len(&self) -> usize {
self.inner.profile.aead_auth_tag_len()
}

fn get_rtcp_index(&self, input: &[u8]) -> usize {
Expand All @@ -59,7 +71,7 @@ impl Cipher for CipherAesCmHmacSha1 {
header: &rtp::header::Header,
roc: u32,
) -> Result<Bytes> {
let mut writer = Vec::with_capacity(plaintext.len() + self.auth_tag_len());
let mut writer = Vec::with_capacity(plaintext.len() + self.rtp_auth_tag_len());

// Write the plaintext to the destination buffer.
writer.extend_from_slice(plaintext);
Expand All @@ -77,7 +89,7 @@ impl Cipher for CipherAesCmHmacSha1 {
stream.apply_keystream(&mut writer[header.marshal_size()..]);

// Generate the auth tag.
let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.auth_tag_len()];
let auth_tag = &self.inner.generate_srtp_auth_tag(&writer, roc)[..self.rtp_auth_tag_len()];
writer.extend(auth_tag);

Ok(Bytes::from(writer))
Expand All @@ -90,19 +102,19 @@ impl Cipher for CipherAesCmHmacSha1 {
roc: u32,
) -> Result<Bytes> {
let encrypted_len = encrypted.len();
if encrypted_len < self.auth_tag_len() {
return Err(Error::SrtpTooSmall(encrypted_len, self.auth_tag_len()));
if encrypted_len < self.rtp_auth_tag_len() {
return Err(Error::SrtpTooSmall(encrypted_len, self.rtp_auth_tag_len()));
}

let mut writer = Vec::with_capacity(encrypted_len - self.auth_tag_len());
let mut writer = Vec::with_capacity(encrypted_len - self.rtp_auth_tag_len());

// Split the auth tag and the cipher text into two parts.
let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..];
let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()];
let actual_tag = &encrypted[encrypted_len - self.rtp_auth_tag_len()..];
let cipher_text = &encrypted[..encrypted_len - self.rtp_auth_tag_len()];

// Generate the auth tag we expect to see from the ciphertext.
let expected_tag =
&self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.auth_tag_len()];
&self.inner.generate_srtp_auth_tag(cipher_text, roc)[..self.rtp_auth_tag_len()];

// See if the auth tag actually matches.
// We use a constant time comparison to prevent timing attacks.
Expand Down Expand Up @@ -132,7 +144,7 @@ impl Cipher for CipherAesCmHmacSha1 {

fn encrypt_rtcp(&mut self, decrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result<Bytes> {
let mut writer =
Vec::with_capacity(decrypted.len() + SRTCP_INDEX_SIZE + self.auth_tag_len());
Vec::with_capacity(decrypted.len() + SRTCP_INDEX_SIZE + self.rtcp_auth_tag_len());

// Write the decrypted to the destination buffer.
writer.extend_from_slice(decrypted);
Expand All @@ -155,22 +167,22 @@ impl Cipher for CipherAesCmHmacSha1 {
writer.put_u32(srtcp_index as u32 | (1u32 << 31));

// Generate the auth tag.
let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.auth_tag_len()];
let auth_tag = &self.inner.generate_srtcp_auth_tag(&writer)[..self.rtcp_auth_tag_len()];
writer.extend(auth_tag);

Ok(Bytes::from(writer))
}

fn decrypt_rtcp(&mut self, encrypted: &[u8], srtcp_index: usize, ssrc: u32) -> Result<Bytes> {
let encrypted_len = encrypted.len();
if encrypted_len < self.auth_tag_len() + SRTCP_INDEX_SIZE {
if encrypted_len < self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE {
return Err(Error::SrtcpTooSmall(
encrypted_len,
self.auth_tag_len() + SRTCP_INDEX_SIZE,
self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE,
));
}

let tail_offset = encrypted_len - (self.auth_tag_len() + SRTCP_INDEX_SIZE);
let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE);
if tail_offset < 8 {
return Err(Error::ErrTooShortRtcp);
}
Expand All @@ -185,18 +197,19 @@ impl Cipher for CipherAesCmHmacSha1 {
}

// Split the auth tag and the cipher text into two parts.
let actual_tag = &encrypted[encrypted_len - self.auth_tag_len()..];
if actual_tag.len() != self.auth_tag_len() {
let actual_tag = &encrypted[encrypted_len - self.rtcp_auth_tag_len()..];
if actual_tag.len() != self.rtcp_auth_tag_len() {
return Err(Error::RtcpInvalidLengthAuthTag(
actual_tag.len(),
self.auth_tag_len(),
self.rtcp_auth_tag_len(),
));
}

let cipher_text = &encrypted[..encrypted_len - self.auth_tag_len()];
let cipher_text = &encrypted[..encrypted_len - self.rtcp_auth_tag_len()];

// Generate the auth tag we expect to see from the ciphertext.
let expected_tag = &self.inner.generate_srtcp_auth_tag(cipher_text)[..self.auth_tag_len()];
let expected_tag =
&self.inner.generate_srtcp_auth_tag(cipher_text)[..self.rtcp_auth_tag_len()];

// See if the auth tag actually matches.
// We use a constant time comparison to prevent timing attacks.
Expand Down
Loading
Loading