diff --git a/crates/hyperion/src/egress/channel.rs b/crates/hyperion/src/egress/channel.rs index 50aba3f1..db27d3d3 100644 --- a/crates/hyperion/src/egress/channel.rs +++ b/crates/hyperion/src/egress/channel.rs @@ -1,5 +1,5 @@ use bevy::{ecs::world::OnDespawn, prelude::*}; -use hyperion_proto::{ServerToProxyMessage, UpdateChannelPosition, UpdateChannelPositions}; +use hyperion_proto::UpdateChannelPosition; use hyperion_utils::EntityExt; use tracing::error; use valence_bytes::CowBytes; @@ -7,7 +7,10 @@ use valence_protocol::{ByteAngle, RawBytes, VarInt, packets::play}; use crate::{ egress::metadata::show_all, - net::{Channel, ChannelId, Compose, ConnectionId}, + net::{ + Channel, ChannelId, Compose, ConnectionId, + intermediate::{IntermediateServerToProxyMessage, UpdateChannelPositions}, + }, simulation::{ Pitch, Position, RequestSubscribeChannelPackets, Uuid, Velocity, Yaw, entity_kind::EntityKind, @@ -47,7 +50,7 @@ fn update_channel_positions( compose .io_buf() - .add_proxy_message(&ServerToProxyMessage::UpdateChannelPositions( + .add_proxy_message(&IntermediateServerToProxyMessage::UpdateChannelPositions( UpdateChannelPositions { updates: &updates }, )); } diff --git a/crates/hyperion/src/egress/mod.rs b/crates/hyperion/src/egress/mod.rs index 2ac48c15..898d556f 100644 --- a/crates/hyperion/src/egress/mod.rs +++ b/crates/hyperion/src/egress/mod.rs @@ -1,11 +1,13 @@ use bevy::prelude::*; -use hyperion_proto::{ServerToProxyMessage, UpdatePlayerPositions}; use tracing::error; use valence_protocol::{VarInt, packets::play::PlayerActionResponseS2c}; use crate::{ Blocks, - net::{Compose, ConnectionId}, + net::{ + Compose, ConnectionId, + intermediate::{IntermediateServerToProxyMessage, UpdatePlayerPositions}, + }, simulation::Position, }; mod channel; @@ -29,14 +31,14 @@ fn send_chunk_positions( let mut stream = Vec::with_capacity(count); let mut positions = Vec::with_capacity(count); - for (io, pos) in query.iter() { - stream.push(io.inner()); + for (&io, pos) in query.iter() { + stream.push(io); positions.push(hyperion_proto::ChunkPosition::from(pos.to_chunk())); } let packet = UpdatePlayerPositions { stream, positions }; - let chunk_positions = ServerToProxyMessage::UpdatePlayerPositions(packet); + let chunk_positions = IntermediateServerToProxyMessage::UpdatePlayerPositions(packet); compose.io_buf().add_proxy_message(&chunk_positions); } diff --git a/crates/hyperion/src/lib.rs b/crates/hyperion/src/lib.rs index 6a3cd3c6..a3ed2a21 100644 --- a/crates/hyperion/src/lib.rs +++ b/crates/hyperion/src/lib.rs @@ -247,21 +247,21 @@ impl Plugin for HyperionCore { let global = Global::new(shared.clone()); - let mut compose = Compose::new(shared.compression_level, global, IoBuf::default()); - app.add_plugins(CommandChannelPlugin); if let Some(address) = app.world().get_resource::() { let crypto = app.world().resource::(); let command_channel = app.world().resource::(); - let egress_comm = - init_proxy_comms(&runtime, command_channel.clone(), address.0, crypto.clone()); - compose.io_buf_mut().add_egress_comm(egress_comm); + init_proxy_comms(&runtime, command_channel.clone(), address.0, crypto.clone()); } else { warn!("Endpoint was not set while loading HyperionCore"); } - app.insert_resource(compose); + app.insert_resource(Compose::new( + shared.compression_level, + global, + IoBuf::default(), + )); app.insert_resource(runtime); app.insert_resource(CraftingRegistry::default()); app.insert_resource(StreamLookup::default()); diff --git a/crates/hyperion/src/net/intermediate.rs b/crates/hyperion/src/net/intermediate.rs new file mode 100644 index 00000000..d1a21047 --- /dev/null +++ b/crates/hyperion/src/net/intermediate.rs @@ -0,0 +1,206 @@ +use hyperion_proto::{ChunkPosition, ServerToProxyMessage, UpdateChannelPosition}; + +use crate::net::{ConnectionId, ProxyId}; + +#[derive(Clone, PartialEq)] +pub struct UpdatePlayerPositions { + pub stream: Vec, + pub positions: Vec, +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct AddChannel<'a> { + pub channel_id: u32, + + pub unsubscribe_packets: &'a [u8], +} + +#[derive(Clone, PartialEq)] +pub struct UpdateChannelPositions<'a> { + pub updates: &'a [UpdateChannelPosition], +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct RemoveChannel { + pub channel_id: u32, +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct SubscribeChannelPackets<'a> { + pub channel_id: u32, + pub exclude: Option, + + pub data: &'a [u8], +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct SetReceiveBroadcasts { + pub stream: ConnectionId, +} + +#[derive(Clone, PartialEq, Eq)] +pub struct BroadcastGlobal<'a> { + pub exclude: Option, + + pub data: &'a [u8], +} + +#[derive(Clone, PartialEq)] +pub struct BroadcastLocal<'a> { + pub center: ChunkPosition, + pub exclude: Option, + + pub data: &'a [u8], +} + +#[derive(Clone, PartialEq, Eq)] +pub struct BroadcastChannel<'a> { + pub channel_id: u32, + pub exclude: Option, + + pub data: &'a [u8], +} + +#[derive(Clone, PartialEq, Eq)] +pub struct Unicast<'a> { + pub stream: ConnectionId, + + pub data: &'a [u8], +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub struct Shutdown { + pub stream: ConnectionId, +} + +#[derive(Clone, PartialEq)] +pub enum IntermediateServerToProxyMessage<'a> { + UpdatePlayerPositions(UpdatePlayerPositions), + AddChannel(AddChannel<'a>), + UpdateChannelPositions(UpdateChannelPositions<'a>), + RemoveChannel(RemoveChannel), + SubscribeChannelPackets(SubscribeChannelPackets<'a>), + BroadcastGlobal(BroadcastGlobal<'a>), + BroadcastLocal(BroadcastLocal<'a>), + BroadcastChannel(BroadcastChannel<'a>), + Unicast(Unicast<'a>), + SetReceiveBroadcasts(SetReceiveBroadcasts), + Shutdown(Shutdown), +} + +impl IntermediateServerToProxyMessage<'_> { + /// Whether the result of [`IntermediateServerToProxyMessage::transform_for_proxy`] will be + /// affected by the proxy id provided + #[must_use] + pub const fn affected_by_proxy(&self) -> bool { + match self { + Self::UpdatePlayerPositions(_) + | Self::SubscribeChannelPackets(_) + | Self::BroadcastGlobal(_) + | Self::BroadcastLocal(_) + | Self::BroadcastChannel(_) + | Self::Unicast(_) + | Self::SetReceiveBroadcasts(_) + | Self::Shutdown(_) => true, + Self::AddChannel(_) | Self::UpdateChannelPositions(_) | Self::RemoveChannel(_) => false, + } + } + + /// Transforms an intermediate message to a message suitable for sending to a particular proxy. + /// Returns `None` if this message should not be sent to the proxy. + #[must_use] + pub fn transform_for_proxy(&self, proxy_id: ProxyId) -> Option> { + let filter_map_connection_id = + |id: ConnectionId| (id.proxy_id() == proxy_id).then(|| id.inner()); + match self { + Self::UpdatePlayerPositions(message) => { + Some(ServerToProxyMessage::UpdatePlayerPositions( + hyperion_proto::UpdatePlayerPositions { + stream: message + .stream + .iter() + .copied() + .filter_map(filter_map_connection_id) + .collect::>(), + positions: message.positions.clone(), + }, + )) + } + Self::AddChannel(message) => Some(ServerToProxyMessage::AddChannel( + hyperion_proto::AddChannel { + channel_id: message.channel_id, + unsubscribe_packets: message.unsubscribe_packets, + }, + )), + Self::UpdateChannelPositions(message) => { + Some(ServerToProxyMessage::UpdateChannelPositions( + hyperion_proto::UpdateChannelPositions { + updates: message.updates, + }, + )) + } + Self::RemoveChannel(message) => Some(ServerToProxyMessage::RemoveChannel( + hyperion_proto::RemoveChannel { + channel_id: message.channel_id, + }, + )), + Self::SubscribeChannelPackets(message) => { + Some(ServerToProxyMessage::SubscribeChannelPackets( + hyperion_proto::SubscribeChannelPackets { + channel_id: message.channel_id, + exclude: message + .exclude + .and_then(filter_map_connection_id) + .unwrap_or_default(), + data: message.data, + }, + )) + } + Self::BroadcastGlobal(message) => Some(ServerToProxyMessage::BroadcastGlobal( + hyperion_proto::BroadcastGlobal { + exclude: message + .exclude + .and_then(filter_map_connection_id) + .unwrap_or_default(), + data: message.data, + }, + )), + Self::BroadcastLocal(message) => Some(ServerToProxyMessage::BroadcastLocal( + hyperion_proto::BroadcastLocal { + center: message.center, + exclude: message + .exclude + .and_then(filter_map_connection_id) + .unwrap_or_default(), + data: message.data, + }, + )), + Self::BroadcastChannel(message) => Some(ServerToProxyMessage::BroadcastChannel( + hyperion_proto::BroadcastChannel { + channel_id: message.channel_id, + exclude: message + .exclude + .and_then(filter_map_connection_id) + .unwrap_or_default(), + data: message.data, + }, + )), + Self::Unicast(message) => { + Some(ServerToProxyMessage::Unicast(hyperion_proto::Unicast { + stream: filter_map_connection_id(message.stream)?, + data: message.data, + })) + } + Self::SetReceiveBroadcasts(message) => Some( + ServerToProxyMessage::SetReceiveBroadcasts(hyperion_proto::SetReceiveBroadcasts { + stream: filter_map_connection_id(message.stream)?, + }), + ), + Self::Shutdown(message) => Some(ServerToProxyMessage::SetReceiveBroadcasts( + hyperion_proto::SetReceiveBroadcasts { + stream: filter_map_connection_id(message.stream)?, + }, + )), + } + } +} diff --git a/crates/hyperion/src/net/mod.rs b/crates/hyperion/src/net/mod.rs index 2943b78e..7742c503 100644 --- a/crates/hyperion/src/net/mod.rs +++ b/crates/hyperion/src/net/mod.rs @@ -13,17 +13,23 @@ use glam::I16Vec2; use hyperion_proto::{ChunkPosition, ServerToProxyMessage}; use hyperion_utils::EntityExt; use libdeflater::CompressionLvl; +use rustc_hash::FxHashMap; use thread_local::ThreadLocal; +use tracing::error; use crate::{ Global, PacketBundle, Scratch, - net::encoder::{PacketEncoder, append_packet_without_compression}, + net::{ + encoder::{PacketEncoder, append_packet_without_compression}, + intermediate::IntermediateServerToProxyMessage, + }, simulation::EgressComm, }; pub mod agnostic; pub mod decoder; pub mod encoder; +pub mod intermediate; pub mod packets; pub mod proxy; @@ -37,6 +43,33 @@ pub const MAX_PACKET_SIZE: usize = valence_protocol::MAX_PACKET_SIZE as usize; /// targets. pub const MINECRAFT_VERSION: &str = "1.20.1"; +/// A unique identifier for a proxy to game server connection +#[derive(Component, Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ProxyId { + /// The underlying unique identifier for the proxy connection. + /// This value is guaranteed to be unique among all active connections. + proxy_id: u64, +} + +impl ProxyId { + /// Creates a new proxy ID with the specified proxy identifier. + /// + /// This is an internal API used by the proxy management system. + #[must_use] + pub const fn new(proxy_id: u64) -> Self { + Self { proxy_id } + } + + /// Returns the underlying proxy identifier. + /// + /// This method is primarily used by internal networking code to interact + /// with the proxy layer. Most application code should not need this. + #[must_use] + pub const fn inner(self) -> u64 { + self.proxy_id + } +} + /// A unique identifier for a client connection /// /// Each `ConnectionId` represents an active network connection between the server and a client, @@ -49,11 +82,14 @@ pub const MINECRAFT_VERSION: &str = "1.20.1"; /// /// Note: Connection IDs are managed internally by the networking system and should be obtained /// through the appropriate connection establishment handlers rather than created directly. -#[derive(Component, Copy, Clone, Debug)] +#[derive(Component, Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct ConnectionId { /// The underlying unique identifier for this connection. /// This value is guaranteed to be unique among all active connections. stream_id: u64, + + /// The proxy which this player connection is connected to + proxy_id: ProxyId, } impl ConnectionId { @@ -63,8 +99,20 @@ impl ConnectionId { /// External code should obtain connection IDs through the appropriate /// connection handlers. #[must_use] - pub const fn new(stream_id: u64) -> Self { - Self { stream_id } + pub const fn new(stream_id: u64, proxy_id: ProxyId) -> Self { + Self { + stream_id, + proxy_id, + } + } + + /// Returns the proxy which this player connection is connected to. + /// + /// This method is primarily used by internal networking code. + /// Most application code should not need this. + #[must_use] + pub const fn proxy_id(self) -> ProxyId { + self.proxy_id } /// Returns the underlying stream identifier. @@ -165,7 +213,7 @@ impl<'a> DataBundle<'a> { self.compose .io_buf - .broadcast_local_raw(&self.data, center, 0); + .broadcast_local_raw(&self.data, center, None); Ok(()) } @@ -177,7 +225,7 @@ impl<'a> DataBundle<'a> { self.compose .io_buf - .broadcast_channel_raw(&self.data, channel, 0); + .broadcast_channel_raw(&self.data, channel, None); Ok(()) } @@ -216,7 +264,7 @@ impl Compose { Broadcast { packet, compose: self, - exclude: 0, + exclude: None, } } @@ -241,7 +289,7 @@ impl Compose { BroadcastLocal { packet, compose: self, - exclude: 0, + exclude: None, center: ChunkPosition { x: center.x, z: center.y, @@ -261,7 +309,7 @@ impl Compose { BroadcastChannel { packet, compose: self, - exclude: 0, + exclude: None, channel, } } @@ -328,7 +376,7 @@ pub struct IoBuf { // broadcast_buffer: ThreadLocal>, temp_buffer: ThreadLocal>, idx: ThreadLocal>, - egress_comms: Vec, + egress_comms: FxHashMap, } impl IoBuf { @@ -339,8 +387,16 @@ impl IoBuf { result } - pub(crate) fn add_egress_comm(&mut self, egress_comm: EgressComm) { - self.egress_comms.push(egress_comm); + pub(crate) fn add_proxy(&mut self, proxy_id: ProxyId, egress_comm: EgressComm) { + let already_exists = self.egress_comms.insert(proxy_id, egress_comm).is_some(); + + if already_exists { + error!("added multiple proxies with the same proxy id {proxy_id:?}"); + } + } + + pub(crate) fn remove_proxy(&mut self, proxy_id: ProxyId) -> Option { + self.egress_comms.remove(&proxy_id) } } @@ -349,7 +405,7 @@ impl IoBuf { pub struct Broadcast<'a, P> { packet: P, compose: &'a Compose, - exclude: u64, + exclude: Option, } /// A unicast builder @@ -394,7 +450,6 @@ impl

Broadcast<'_, P> { /// Exclude a certain player from the broadcast. This can only be called once. pub fn exclude(self, exclude: impl Into>) -> Self { let exclude = exclude.into(); - let exclude = exclude.map(|id| id.stream_id).unwrap_or_default(); Broadcast { packet: self.packet, compose: self.compose, @@ -409,7 +464,7 @@ pub struct BroadcastLocal<'a, P> { packet: P, compose: &'a Compose, center: ChunkPosition, - exclude: u64, + exclude: Option, } impl

BroadcastLocal<'_, P> { @@ -433,7 +488,6 @@ impl

BroadcastLocal<'_, P> { /// Exclude a certain player from the broadcast. This can only be called once. pub fn exclude(self, exclude: impl Into>) -> Self { let exclude = exclude.into(); - let exclude = exclude.map(|id| id.stream_id).unwrap_or_default(); BroadcastLocal { packet: self.packet, compose: self.compose, @@ -448,7 +502,7 @@ impl

BroadcastLocal<'_, P> { pub struct BroadcastChannel<'a, P> { packet: P, compose: &'a Compose, - exclude: u64, + exclude: Option, channel: ChannelId, } @@ -473,7 +527,6 @@ impl

BroadcastChannel<'_, P> { /// Exclude a certain player from the broadcast. This can only be called once. pub fn exclude(self, exclude: impl Into>) -> Self { let exclude = exclude.into(); - let exclude = exclude.map(|id| id.stream_id).unwrap_or_default(); Self { exclude, ..self } } } @@ -532,7 +585,7 @@ impl IoBuf { Ok(()) } - pub(crate) fn add_proxy_message(&self, message: &ServerToProxyMessage<'_>) { + pub(crate) fn encode_proxy_message(message: &ServerToProxyMessage<'_>) -> Bytes { let mut buffer = Vec::::new(); buffer.write_u64::(0x00).unwrap(); @@ -542,28 +595,61 @@ impl IoBuf { let packet_len = u64::try_from(buffer.len() - size_of::()).unwrap(); buffer[0..8].copy_from_slice(&packet_len.to_be_bytes()); - let buffer = Bytes::from_owner(buffer); + Bytes::from_owner(buffer) + } + + pub(crate) fn add_proxy_message(&self, message: &IntermediateServerToProxyMessage<'_>) { + if message.affected_by_proxy() { + // Encode the message for each proxy before sending it + for (&proxy_id, egress_comm) in &self.egress_comms { + let Some(message) = message.transform_for_proxy(proxy_id) else { + continue; + }; - for egress_comm in &self.egress_comms { - egress_comm.tx.send(buffer.clone()).unwrap(); + egress_comm + .tx + .send(Self::encode_proxy_message(&message)) + .unwrap(); + } + } else { + // Encode the message once and then send it to each proxy. This uses a placeholder + // proxy id. + let Some(message) = message.transform_for_proxy(ProxyId::new(0)) else { + return; + }; + + let buffer = Self::encode_proxy_message(&message); + for egress_comm in self.egress_comms.values() { + egress_comm.tx.send(buffer.clone()).unwrap(); + } } } - fn broadcast_local_raw(&self, data: &[u8], center: impl Into, exclude: u64) { + fn broadcast_local_raw( + &self, + data: &[u8], + center: impl Into, + exclude: Option, + ) { let center = center.into(); - self.add_proxy_message(&ServerToProxyMessage::BroadcastLocal( - hyperion_proto::BroadcastLocal { - data, + self.add_proxy_message(&IntermediateServerToProxyMessage::BroadcastLocal( + intermediate::BroadcastLocal { center, exclude, + data, }, )); } - fn broadcast_channel_raw(&self, data: &[u8], channel: ChannelId, exclude: u64) { - self.add_proxy_message(&ServerToProxyMessage::BroadcastChannel( - hyperion_proto::BroadcastChannel { + fn broadcast_channel_raw( + &self, + data: &[u8], + channel: ChannelId, + exclude: Option, + ) { + self.add_proxy_message(&IntermediateServerToProxyMessage::BroadcastChannel( + intermediate::BroadcastChannel { channel_id: channel.inner(), data, exclude, @@ -571,36 +657,27 @@ impl IoBuf { )); } - pub(crate) fn broadcast_raw(&self, data: &[u8], exclude: u64) { - self.add_proxy_message(&ServerToProxyMessage::BroadcastGlobal( - hyperion_proto::BroadcastGlobal { - data, - // todo: Right now, we are using `to_vec`. - // We want to probably allow encoding without allocation in the future. - // Fortunately, `to_vec` will not require any allocation if the buffer is empty. - exclude, - }, + pub(crate) fn broadcast_raw(&self, data: &[u8], exclude: Option) { + self.add_proxy_message(&IntermediateServerToProxyMessage::BroadcastGlobal( + intermediate::BroadcastGlobal { exclude, data }, )); } pub(crate) fn unicast_raw(&self, data: &[u8], stream: ConnectionId) { - self.add_proxy_message(&ServerToProxyMessage::Unicast(hyperion_proto::Unicast { - data, - stream: stream.stream_id, - })); + self.add_proxy_message(&IntermediateServerToProxyMessage::Unicast( + intermediate::Unicast { stream, data }, + )); } pub(crate) fn set_receive_broadcasts(&self, stream: ConnectionId) { - self.add_proxy_message(&ServerToProxyMessage::SetReceiveBroadcasts( - hyperion_proto::SetReceiveBroadcasts { - stream: stream.stream_id, - }, + self.add_proxy_message(&IntermediateServerToProxyMessage::SetReceiveBroadcasts( + intermediate::SetReceiveBroadcasts { stream }, )); } pub(crate) fn add_channel(&self, channel: ChannelId, unsubscribe_packets: &[u8]) { - self.add_proxy_message(&ServerToProxyMessage::AddChannel( - hyperion_proto::AddChannel { + self.add_proxy_message(&IntermediateServerToProxyMessage::AddChannel( + intermediate::AddChannel { channel_id: channel.inner(), unsubscribe_packets, }, @@ -613,26 +690,26 @@ impl IoBuf { packets: &[u8], exclude: Option, ) { - self.add_proxy_message(&ServerToProxyMessage::SubscribeChannelPackets( - hyperion_proto::SubscribeChannelPackets { + self.add_proxy_message(&IntermediateServerToProxyMessage::SubscribeChannelPackets( + intermediate::SubscribeChannelPackets { channel_id: channel.inner(), - exclude: exclude.map_or(0, |connection_id| connection_id.stream_id), + exclude, data: packets, }, )); } pub(crate) fn remove_channel(&self, channel: ChannelId) { - self.add_proxy_message(&ServerToProxyMessage::RemoveChannel( - hyperion_proto::RemoveChannel { + self.add_proxy_message(&IntermediateServerToProxyMessage::RemoveChannel( + intermediate::RemoveChannel { channel_id: channel.inner(), }, )); } pub fn shutdown(&self, stream: ConnectionId) { - self.add_proxy_message(&ServerToProxyMessage::Shutdown(hyperion_proto::Shutdown { - stream: stream.stream_id, - })); + self.add_proxy_message(&IntermediateServerToProxyMessage::Shutdown( + intermediate::Shutdown { stream }, + )); } } diff --git a/crates/hyperion/src/net/proxy.rs b/crates/hyperion/src/net/proxy.rs index a2cdfe1c..18570532 100644 --- a/crates/hyperion/src/net/proxy.rs +++ b/crates/hyperion/src/net/proxy.rs @@ -1,6 +1,13 @@ //! Communication to a proxy which forwards packets to the players. -use std::{net::SocketAddr, process::Command, sync::Arc}; +use std::{ + net::SocketAddr, + process::Command, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, +}; use bevy::prelude::*; use hyperion_proto::ArchivedProxyToServerMessage; @@ -13,11 +20,12 @@ use rustls::{ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use tokio_rustls::TlsAcceptor; use tracing::{error, info, warn}; +use valence_protocol::{VarInt, packets::play}; use crate::{ ConnectionId, Crypto, PacketDecoder, command_channel::CommandChannel, - net::Compose, + net::{Channel, ChannelId, Compose, IoBuf, ProxyId}, runtime::AsyncRuntime, simulation::{EgressComm, RequestSubscribeChannelPackets, StreamLookup, packet_state}, }; @@ -44,10 +52,15 @@ fn get_pid_from_port(port: u16) -> Result, std::io::Error> { Ok(pid) } -async fn handle_proxy_messages(read: impl AsyncRead + Unpin, command_channel: CommandChannel) { +async fn handle_proxy_messages( + read: impl AsyncRead + Unpin, + command_channel: CommandChannel, + proxy_id: ProxyId, +) { let mut reader = ProxyReader::new(read); let mut player_packet_sender: FxHashMap = FxHashMap::default(); + // Process packets loop { let buffer = match reader.next_server_packet_buffer().await { Ok(message) => message, @@ -65,7 +78,7 @@ async fn handle_proxy_messages(read: impl AsyncRead + Unpin, command_channel: Co error!("closing proxy connection due to an unexpected error: {err:?}"); } } - return; + break; } }; @@ -86,7 +99,7 @@ async fn handle_proxy_messages(read: impl AsyncRead + Unpin, command_channel: Co command_channel.push(move |world: &mut World| { let player = world .spawn(( - ConnectionId::new(stream), + ConnectionId::new(stream, proxy_id), packet_state::Handshake(()), PacketDecoder::default(), receiver, @@ -124,7 +137,7 @@ async fn handle_proxy_messages(read: impl AsyncRead + Unpin, command_channel: Co error!( "PlayerPackets: no player with stream id exists in player_packet_sender" ); - return; + continue; }; if let Err(e) = sender.send(&message.data) { @@ -145,7 +158,9 @@ async fn handle_proxy_messages(read: impl AsyncRead + Unpin, command_channel: Co let compose = world .get_resource::() .expect("Compose resource should exist"); - compose.io_buf().shutdown(ConnectionId::new(stream)); + compose + .io_buf() + .shutdown(ConnectionId::new(stream, proxy_id)); }); } } @@ -184,14 +199,22 @@ async fn handle_proxy_messages(read: impl AsyncRead + Unpin, command_channel: Co } } } + + // Disconnect all players that were connected through this proxy + command_channel.push(move |world: &mut World| { + let mut query = world.query::<(Entity, &ConnectionId)>(); + let players_to_remove = query + .iter(world) + .filter(|(_, connection_id)| connection_id.proxy_id() == proxy_id) + .map(|(entity, _)| entity) + .collect::>(); + for player in players_to_remove { + world.despawn(player); + } + }); } -async fn inner( - socket: SocketAddr, - crypto: Crypto, - mut server_to_proxy: tokio::sync::mpsc::UnboundedReceiver, - command_channel: CommandChannel, -) { +async fn inner(socket: SocketAddr, crypto: Crypto, command_channel: CommandChannel) { let listener = match tokio::net::TcpListener::bind(socket).await { Ok(listener) => listener, Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => { @@ -240,6 +263,8 @@ async fn inner( tokio::spawn( async move { + let next_proxy_id = Arc::new(AtomicU64::new(0)); + loop { let (socket, _) = listener.accept().await.unwrap(); @@ -254,58 +279,105 @@ async fn inner( }; let command_channel = command_channel.clone(); + let next_proxy_id = next_proxy_id.clone(); + let stream = acceptor.accept(socket); - let stream = match acceptor.accept(socket).await { - Ok(stream) => stream, - Err(e) => { - error!( - "failed to accept proxy connection from {addr}: tls accept failed: {e}" - ); - continue; - } - }; + tokio::spawn(async move { + let stream = match stream.await { + Ok(stream) => stream, + Err(e) => { + error!( + "failed to accept proxy connection from {addr}: tls accept \ + failed: {e}" + ); + return; + } + }; - info!("Proxy connection established on {addr}"); + info!("Proxy connection established on {addr}"); - let (read, mut write) = tokio::io::split(stream); + let (read, mut write) = tokio::io::split(stream); - let proxy_writer_task = tokio::spawn(async move { - while let Some(bytes) = server_to_proxy.recv().await { - if write.write_all(&bytes).await.is_err() { - error!("error writing to proxy"); - return server_to_proxy; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let egress_comm = EgressComm::from(tx.clone()); + let proxy_id = ProxyId::new(next_proxy_id.fetch_add(1, Ordering::Relaxed)); + + command_channel.push(move |world: &mut World| { + let mut compose = world.resource_mut::(); + compose.io_buf_mut().add_proxy(proxy_id, egress_comm); + }); + + let command_channel_clone = command_channel.clone(); + tokio::spawn(async move { + // Send the bytes from the channel to the proxy + while let Some(bytes) = rx.recv().await { + if write.write_all(&bytes).await.is_err() { + error!("error writing to proxy"); + break; + } } - } - warn!("proxy shut down"); + warn!("proxy shut down"); - server_to_proxy - }); + command_channel_clone.push(move |world: &mut World| { + // Remove this channel from the compose egress comms list + let mut compose = world.resource_mut::(); + let removed = compose.io_buf_mut().remove_proxy(proxy_id).is_some(); + if !removed { + error!("failed to remove proxy from compose egress comms"); + } - let command_channel = command_channel.clone(); - tokio::spawn(handle_proxy_messages(read, command_channel)); + // Explicitly close this receiver. This ensures that the channel isn't + // closed before this, which would lead to an error on the sender side + // of Compose. + rx.close(); + }); + }); + + command_channel.push(move |world: &mut World| { + // Let the proxy know about all packet channels that exist at the moment + + let mut query = world.query_filtered::>(); + let compose = world.resource::(); + for channel in query.iter(world) { + let packet = play::EntitiesDestroyS2c { + entity_ids: vec![VarInt(channel.minecraft_id())].into(), + }; + + let packet_buf = + compose.io_buf().encode_packet(&packet, compose).unwrap(); + + tx.send(IoBuf::encode_proxy_message( + &hyperion_proto::ServerToProxyMessage::AddChannel( + hyperion_proto::AddChannel { + channel_id: ChannelId::from(channel).inner(), + unsubscribe_packets: &packet_buf, + }, + ), + )) + .unwrap(); + } + }); - // todo: handle player disconnects on proxy shut down - // Ideally, we should design for there being multiple proxies, - // and all proxies should store all the players on them. - // Then we can disconnect all those players related to that proxy. - server_to_proxy = proxy_writer_task.await.unwrap(); + tokio::spawn(handle_proxy_messages( + read, + command_channel.clone(), + proxy_id, + )); + }); } }, // .instrument(info_span!("proxy reader")), ); } /// Initializes proxy communications. -#[must_use] pub fn init_proxy_comms( runtime: &AsyncRuntime, command_channel: CommandChannel, socket: SocketAddr, crypto: Crypto, -) -> EgressComm { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - runtime.spawn(inner(socket, crypto, rx, command_channel)); - EgressComm::from(tx) +) { + runtime.spawn(inner(socket, crypto, command_channel)); } #[derive(Debug)] diff --git a/crates/hyperion/src/simulation/mod.rs b/crates/hyperion/src/simulation/mod.rs index 30f3b8d5..387ef825 100644 --- a/crates/hyperion/src/simulation/mod.rs +++ b/crates/hyperion/src/simulation/mod.rs @@ -60,7 +60,7 @@ pub struct PlayerUuidLookup { } /// Communicates with the proxy server. -#[derive(Resource, Deref, DerefMut, From)] +#[derive(Clone, Deref, DerefMut, From)] pub struct EgressComm { pub(crate) tx: tokio::sync::mpsc::UnboundedSender, }