diff --git a/Cargo.toml b/Cargo.toml index b6c54d2c..5b80550c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ rust-version = "1.87" # MSRV normal = [ "ntp-proto", "rustls-platform-verifier", "rustls-pemfile2", "rustls", "serde", "tokio-rustls", "toml", "tracing", "tracing-subscriber" ] [dependencies] -tokio = { version = "1.32", features = ["rt-multi-thread", "io-util", "fs", "net", "macros", "time" ] } +tokio = { version = "1.32", features = ["rt-multi-thread", "io-util", "fs", "net", "macros", "time", "sync" ] } toml = { version = ">=0.6.0,<0.9.0", default-features = false, features = ["parse"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.0", default-features = false, features = ["std", "fmt", "ansi"] } diff --git a/src/config.rs b/src/config.rs index db44b7d8..65d605a3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -105,12 +105,19 @@ struct BareNtsPoolKeConfig { listen: SocketAddr, /// Which upstream servers to use. key_exchange_servers: Box<[KeyExchangeServer]>, + /// Maximum amount of parallel connections (incoming) + #[serde(default = "default_max_connections")] + max_connections: usize, } fn default_nts_ke_timeout() -> u64 { 1000 } +fn default_max_connections() -> usize { + 100 +} + #[derive(Clone)] pub struct NtsPoolKeConfig { pub server_tls: TlsAcceptor, @@ -118,6 +125,7 @@ pub struct NtsPoolKeConfig { pub listen: SocketAddr, pub key_exchange_servers: Box<[KeyExchangeServer]>, pub key_exchange_timeout: Duration, + pub max_connections: usize, } fn load_certificates( @@ -192,6 +200,7 @@ impl<'de> Deserialize<'de> for NtsPoolKeConfig { listen: bare.listen, key_exchange_servers: bare.key_exchange_servers, key_exchange_timeout: std::time::Duration::from_millis(bare.key_exchange_timeout), + max_connections: bare.max_connections, }) } } diff --git a/src/pool_ke.rs b/src/pool_ke.rs index 6cd20e3e..1e88ee29 100644 --- a/src/pool_ke.rs +++ b/src/pool_ke.rs @@ -70,10 +70,16 @@ impl NtsPoolKe { async fn serve(self: Arc) -> std::io::Result<()> { let listener = TcpListener::bind(self.config.listen).await?; + let connectionpermits = Arc::new(tokio::sync::Semaphore::new(self.config.max_connections)); info!("listening on '{:?}'", listener.local_addr()); loop { + let permit = connectionpermits + .clone() + .acquire_owned() + .await + .expect("Semaphore shouldn't be closed"); let (client_stream, source_address) = listener.accept().await?; let self_clone = self.clone(); @@ -88,6 +94,7 @@ impl NtsPoolKe { Ok(Err(err)) => ::tracing::debug!(?err, ?source_address, "NTS Pool KE failed"), Ok(Ok(())) => ::tracing::debug!(?source_address, "NTS Pool KE completed"), } + drop(permit); }); } }