Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl<'de> Deserialize<'de> for NtsPoolKeConfig {
pub struct KeyExchangeServer {
pub domain: String,
pub server_name: ServerName<'static>,
pub port: u16,
pub connection_address: (String, u16),
}

impl<'de> Deserialize<'de> for KeyExchangeServer {
Expand All @@ -236,7 +236,7 @@ impl<'de> Deserialize<'de> for KeyExchangeServer {
Ok(KeyExchangeServer {
domain: bare.domain.to_string(),
server_name,
port: bare.port,
connection_address: (bare.domain.to_string(), bare.port),
})
}
}
Expand Down Expand Up @@ -281,12 +281,12 @@ mod tests {
KeyExchangeServer {
domain: String::from("foo.bar"),
server_name: ServerName::try_from("foo.bar").unwrap(),
port: 1234
connection_address: (String::from("foo.bar"), 1234),
},
KeyExchangeServer {
domain: String::from("bar.foo"),
server_name: ServerName::try_from("bar.foo").unwrap(),
port: 4321
connection_address: (String::from("bar.foo"), 4321),
},
]
.as_slice()
Expand Down Expand Up @@ -320,12 +320,12 @@ mod tests {
KeyExchangeServer {
domain: String::from("foo.bar"),
server_name: ServerName::try_from("foo.bar").unwrap(),
port: 1234
connection_address: (String::from("foo.bar"), 1234),
},
KeyExchangeServer {
domain: String::from("bar.foo"),
server_name: ServerName::try_from("bar.foo").unwrap(),
port: 4321
connection_address: (String::from("bar.foo"), 4321),
},
]
.as_slice()
Expand Down
74 changes: 74 additions & 0 deletions src/nts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,80 @@ impl FixedKeyRequest {

Ok(())
}

#[cfg(test)]
pub async fn parse(mut reader: impl AsyncRead + Unpin) -> Result<Self, NtsError> {
let mut c2s = None;
let mut s2c = None;
let mut algorithm = None;
let mut protocol = None;

loop {
let record = NtsRecord::parse(&mut reader).await?;

match record {
NtsRecord::EndOfMessage => break,
NtsRecord::FixedKeyRequest {
c2s: c2s_rem,
s2c: s2c_rem,
} => {
if c2s.is_some() || s2c.is_some() {
return Err(NtsError::Invalid);
}

c2s = Some(c2s_rem);
s2c = Some(s2c_rem);
}
NtsRecord::AeadAlgorithm { algorithm_ids } => {
if algorithm.is_some() || algorithm_ids.len() != 1 {
return Err(NtsError::Invalid);
}

algorithm = Some(algorithm_ids[0]);
}
NtsRecord::NextProtocol { protocol_ids } => {
if protocol.is_some() || protocol_ids.len() != 1 {
return Err(NtsError::Invalid);
}

protocol = Some(protocol_ids[0]);
}
// Error
NtsRecord::Error { errorcode } => return Err(NtsError::Error(errorcode)),
// Warning
NtsRecord::Warning { warningcode } => match warningcode {
WarningCode::Unknown(code) => return Err(NtsError::UnknownWarning(code)),
},
// Unknown critical
NtsRecord::Unknown { critical: true, .. } => {
return Err(NtsError::UnrecognizedCriticalRecord)
}
// Ignored
NtsRecord::KeepAlive
| NtsRecord::Unknown { .. }
| NtsRecord::Server { .. }
| NtsRecord::Port { .. } => {}
// Not allowed
NtsRecord::NewCookie { .. }
| NtsRecord::SupportedNextProtocolList { .. }
| NtsRecord::SupportedAlgorithmList { .. }
| NtsRecord::NtpServerDeny { .. } => return Err(NtsError::Invalid),
}
}

if let (Some(algorithm), Some(protocol), Some(c2s), Some(s2c)) =
(algorithm, protocol, c2s, s2c)
{
Ok(Self {
c2s,
s2c,
protocol,
algorithm,
})
} else {
Err(NtsError::Invalid)
}
}
}

pub struct KeyExchangeResponse {
Expand Down
Loading