--- a/rust/hedgewars-server/src/main.rs Sat Apr 13 00:37:35 2019 +0300
+++ b/rust/hedgewars-server/src/main.rs Mon Apr 15 21:22:51 2019 +0300
@@ -68,6 +68,7 @@
if event.readiness() & Ready::readable() == Ready::readable() {
match event.token() {
utils::SERVER_TOKEN => hw_network.accept_client(&poll).unwrap(),
+ utils::TIMER_TOKEN => hw_network.handle_timeout(&poll).unwrap(),
#[cfg(feature = "official-server")]
utils::IO_TOKEN => hw_network.handle_io_result(),
Token(tok) => hw_network.client_readable(&poll, tok).unwrap(),
@@ -75,8 +76,7 @@
}
if event.readiness() & Ready::writable() == Ready::writable() {
match event.token() {
- utils::SERVER_TOKEN => unreachable!(),
- utils::IO_TOKEN => unreachable!(),
+ utils::SERVER_TOKEN | utils::TIMER_TOKEN | utils::IO_TOKEN => unreachable!(),
Token(tok) => hw_network.client_writable(&poll, tok).unwrap(),
}
}
--- a/rust/hedgewars-server/src/server/network.rs Sat Apr 13 00:37:35 2019 +0300
+++ b/rust/hedgewars-server/src/server/network.rs Mon Apr 15 21:22:51 2019 +0300
@@ -13,6 +13,7 @@
net::{TcpListener, TcpStream},
Poll, PollOpt, Ready, Token,
};
+use mio_extras::timer;
use netbuf;
use slab::Slab;
@@ -34,8 +35,11 @@
SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
},
};
+use std::time::Duration;
const MAX_BYTES_PER_READ: usize = 2048;
+const SEND_PING_TIMEOUT: Duration = Duration::from_secs(30);
+const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub enum NetworkClientState {
@@ -80,16 +84,23 @@
peer_addr: SocketAddr,
decoder: ProtocolDecoder,
buf_out: netbuf::Buf,
+ timeout: timer::Timeout,
}
impl NetworkClient {
- pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
+ pub fn new(
+ id: ClientId,
+ socket: ClientSocket,
+ peer_addr: SocketAddr,
+ timeout: timer::Timeout,
+ ) -> NetworkClient {
NetworkClient {
id,
socket,
peer_addr,
decoder: ProtocolDecoder::new(),
buf_out: netbuf::Buf::new(),
+ timeout,
}
}
@@ -231,6 +242,10 @@
pub fn send_string(&mut self, msg: &str) {
self.send_raw_msg(&msg.as_bytes());
}
+
+ pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
+ replace(&mut self.timeout, timeout)
+ }
}
#[cfg(feature = "tls-connections")]
@@ -288,6 +303,13 @@
}
}
+enum TimeoutEvent {
+ SendPing,
+ DropClient,
+}
+
+struct TimerData(TimeoutEvent, ClientId);
+
pub struct NetworkLayer {
listener: TcpListener,
server: HWServer,
@@ -298,6 +320,21 @@
ssl: ServerSsl,
#[cfg(feature = "official-server")]
io: IoLayer,
+ timer: timer::Timer<TimerData>,
+}
+
+fn create_ping_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
+ timer.set_timeout(
+ SEND_PING_TIMEOUT,
+ TimerData(TimeoutEvent::SendPing, client_id),
+ )
+}
+
+fn create_drop_timeout(timer: &mut timer::Timer<TimerData>, client_id: ClientId) -> timer::Timeout {
+ timer.set_timeout(
+ DROP_CLIENT_TIMEOUT,
+ TimerData(TimeoutEvent::DropClient, client_id),
+ )
}
impl NetworkLayer {
@@ -306,6 +343,7 @@
let clients = Slab::with_capacity(clients_limit);
let pending = HashSet::with_capacity(2 * clients_limit);
let pending_cache = Vec::with_capacity(2 * clients_limit);
+ let timer = timer::Builder::default().build();
NetworkLayer {
listener,
@@ -317,6 +355,7 @@
ssl: NetworkLayer::create_ssl_context(),
#[cfg(feature = "official-server")]
io: IoLayer::new(),
+ timer,
}
}
@@ -346,6 +385,13 @@
PollOpt::edge(),
)?;
+ poll.register(
+ &self.timer,
+ utils::TIMER_TOKEN,
+ Ready::readable(),
+ PollOpt::edge(),
+ )?;
+
#[cfg(feature = "official-server")]
self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
@@ -384,7 +430,12 @@
)
.expect("could not register socket with event loop");
- let client = NetworkClient::new(client_id, client_socket, addr);
+ let client = NetworkClient::new(
+ client_id,
+ client_socket,
+ addr,
+ create_ping_timeout(&mut self.timer, client_id),
+ );
info!("client {} ({}) added", client.id, client.peer_addr);
entry.insert(client);
@@ -419,6 +470,29 @@
}
}
+ pub fn handle_timeout(&mut self, poll: &Poll) -> io::Result<()> {
+ while let Some(TimerData(event, client_id)) = self.timer.poll() {
+ match event {
+ TimeoutEvent::SendPing => {
+ if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ client.send_string(&HWServerMessage::Ping.to_raw_protocol());
+ client.write()?;
+ client.replace_timeout(create_drop_timeout(&mut self.timer, client_id));
+ }
+ }
+ TimeoutEvent::DropClient => {
+ self.operation_failed(
+ poll,
+ client_id,
+ &ErrorKind::TimedOut.into(),
+ "No ping response",
+ )?;
+ }
+ }
+ }
+ Ok(())
+ }
+
#[cfg(feature = "official-server")]
pub fn handle_io_result(&mut self) {
if let Some((client_id, result)) = self.io.try_recv() {
@@ -486,6 +560,8 @@
pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ let timeout = client.replace_timeout(create_ping_timeout(&mut self.timer, client_id));
+ self.timer.cancel_timeout(&timeout);
client.read()
} else {
warn!("invalid readable client: {}", client_id);