diff --git a/ownserver_server/src/client.rs b/ownserver_server/src/client.rs index 1c43305..5f03dae 100644 --- a/ownserver_server/src/client.rs +++ b/ownserver_server/src/client.rs @@ -17,7 +17,7 @@ pub struct Client { pub client_id: ClientId, pub endpoints: Endpoints, - ws_tx: SplitSink, + ws_tx: Option>, ct_self: CancellationToken, ct_child: CancellationToken, state: ClientState, @@ -54,10 +54,10 @@ where break } result = stream.next() => { - let message = match result { + let mut bytes = match result { // handle protocol message Some(Ok(msg)) if (msg.is_binary() || msg.is_text()) && !msg.as_bytes().is_empty() => { - msg.into_bytes() + BytesMut::from(msg.as_bytes()) } // handle close with reason Some(Ok(msg)) if msg.is_close() && !msg.as_bytes().is_empty() => { @@ -70,7 +70,6 @@ where } }; - let mut bytes = BytesMut::from(&message[..]); counter!("ownserver_server.client.control_packet.received_bytes", "client_id" => client_id.to_string()).increment(bytes.len() as u64); let packet = match ControlPacketV2Codec::new().decode(&mut bytes) { @@ -143,7 +142,7 @@ where Self { client_id, endpoints, - ws_tx: sink, + ws_tx: Some(sink), ct_self: token, ct_child: CancellationToken::new(), state: ClientState::Connected, @@ -152,29 +151,43 @@ where } pub async fn send_to_client(&mut self, packet: ControlPacketV2) -> Result<(), SendToClientError> { - if !matches!(self.state, ClientState::Connected) { - return Err(SendToClientError::ClientNotConnected(self.client_id)) - } - let mut codec = ControlPacketV2Codec::new(); - let mut bytes = BytesMut::new(); - if let Err(e) = codec.encode(packet, &mut bytes) { - tracing::warn!(cid = %self.client_id, error = ?e, "failed to encode message"); - return Err(SendToClientError::EncodeError(e)) - } + match self.ws_tx { + None => { + // ClientState should be Connected + Err(SendToClientError::ClientNotConnected(self.client_id)) + } + Some(ref mut ws_tx) => { + if !matches!(self.state, ClientState::Connected) { + return Err(SendToClientError::ClientNotConnected(self.client_id)) + } - if let Err(e) = self.ws_tx.send(Message::binary(bytes.to_vec())).await { - tracing::warn!(cid = %self.client_id, error = ?e, "client disconnected: aborting"); - self.set_wait_reconnect(); - return Err(SendToClientError::WriteError(e)) + let mut codec = ControlPacketV2Codec::new(); + let mut bytes = BytesMut::new(); + if let Err(e) = codec.encode(packet, &mut bytes) { + tracing::warn!(cid = %self.client_id, error = ?e, "failed to encode message"); + return Err(SendToClientError::EncodeError(e)) + } + + let bytes_len = bytes.len(); + if let Err(e) = ws_tx.send(Message::binary(bytes)).await { + tracing::warn!(cid = %self.client_id, error = ?e, "client disconnected: aborting"); + self.set_wait_reconnect(); + return Err(SendToClientError::WriteError(e)) + } + counter!("ownserver_server.client.control_packet.sent_bytes", "client_id" => self.client_id.to_string()).increment(bytes_len as u64); + Ok(()) + } } - counter!("ownserver_server.client.control_packet.sent_bytes", "client_id" => self.client_id.to_string()).increment(bytes.len() as u64); - Ok(()) } pub fn set_wait_reconnect(&mut self) { if let ClientState::Connected = self.state { self.state = ClientState::WaitReconnect { expires: Utc::now() + self.reconnect_window }; tracing::debug!(cid = %self.client_id, "set client state: {:?}", self.state); + // dropping read half of websocket stream + self.ct_self.cancel(); + // dropping write half of websocket stream + self.ws_tx = None; } }