diff -r 7d0f747afcb8 -r a4d505a32879 rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Tue Feb 01 02:23:35 2022 +0300 +++ b/rust/hedgewars-server/src/server/network.rs Tue Feb 01 20:58:35 2022 +0300 @@ -2,15 +2,9 @@ use log::*; use slab::Slab; use std::{ - collections::HashSet, - io, - io::{Error, ErrorKind, Read, Write}, iter::Iterator, - mem::{replace, swap}, - net::{IpAddr, Ipv4Addr, SocketAddr}, - num::NonZeroU32, + net::{IpAddr, SocketAddr}, time::Duration, - time::Instant, }; use tokio::{ io::AsyncReadExt, @@ -25,7 +19,7 @@ }, handlers, handlers::{IoResult, IoTask, ServerState}, - protocol::ProtocolDecoder, + protocol::{self, ProtocolDecoder, ProtocolError}, utils, }; use hedgewars_network_protocol::{ @@ -33,6 +27,8 @@ }; use tokio::io::AsyncWriteExt; +const PING_TIMEOUT: Duration = Duration::from_secs(15); + enum ClientUpdateData { Message(HwProtocolMessage), Error(String), @@ -80,16 +76,28 @@ socket, peer_addr, receiver, - decoder: ProtocolDecoder::new(), + decoder: ProtocolDecoder::new(PING_TIMEOUT), } } - async fn read(&mut self) -> Option { - self.decoder.read_from(&mut self.socket).await + async fn read( + socket: &mut TcpStream, + decoder: &mut ProtocolDecoder, + ) -> protocol::Result { + let result = decoder.read_from(socket).await; + if matches!(result, Err(ProtocolError::Timeout)) { + if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await { + decoder.read_from(socket).await + } else { + Err(ProtocolError::Eof) + } + } else { + result + } } - async fn write(&mut self, mut data: Bytes) -> bool { - !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0) + async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool { + !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0) } async fn run(mut self, sender: Sender) { @@ -103,7 +111,7 @@ tokio::select! { server_message = self.receiver.recv() => { match server_message { - Some(message) => if !self.write(message).await { + Some(message) => if !Self::write(&mut self.socket, message).await { sender.send(Error("Connection reset by peer".to_string())).await; break; } @@ -112,15 +120,18 @@ } } } - client_message = self.decoder.read_from(&mut self.socket) => { + client_message = Self::read(&mut self.socket, &mut self.decoder) => { match client_message { - Some(message) => { + Ok(message) => { if !sender.send(Message(message)).await { break; } } - None => { - sender.send(Error("Connection reset by peer".to_string())).await; + Err(e) => { + sender.send(Error(format!("{}", e))).await; + if matches!(e, ProtocolError::Timeout) { + Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await; + } break; } } @@ -153,8 +164,10 @@ Some(ClientUpdate{ client_id, data: Message(message) } ) => { self.handle_message(client_id, message).await; } - Some(ClientUpdate{ client_id, .. } ) => { + Some(ClientUpdate{ client_id, data: Error(e) } ) => { let mut response = handlers::Response::new(client_id); + info!("Client {} error: {:?}", client_id, e); + response.remove_client(client_id); handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); self.handle_response(response).await; }