# HG changeset patch # User alfadur # Date 1643671415 -10800 # Node ID 7d0f747afcb81744f67a79f868bc8937aa54ad38 # Parent ea459da15b30f89839d31559eb5c473f5b683e3a move server network to tokio diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/Cargo.toml --- a/rust/hedgewars-server/Cargo.toml Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/Cargo.toml Tue Feb 01 02:23:35 2022 +0300 @@ -1,31 +1,29 @@ [package] -edition = "2018" +edition = "2021" name = "hedgewars-server" version = "0.9.0" authors = [ "Andrey Korotaev " ] [features] -official-server = ["openssl", "mysql"] -tls-connections = ["openssl"] +official-server = [] default = [] [dependencies] -getopts = "0.2" -rand = "0.8" +base64 = "0.13" +bitflags = "1.3" +bytes = "1.1" chrono = "0.4" -mio = { version = "0.7", features = ["os-poll", "net"] } -slab = "0.4" -netbuf = "0.4" -nom = "6.2" env_logger = "0.8" +getopts = "0.2" log = "0.4" -base64 = "0.13" -bitflags = "1.2" +nom = "7.1" +rand = "0.8" serde = "1.0" serde_yaml = "0.8" serde_derive = "1.0" -openssl = { version = "0.10", optional = true } -mysql = { version = "15.0", optional = true } +slab = "0.4" +tokio = { version = "1.16", features = ["full"]} + hedgewars-network-protocol = { path = "../hedgewars-network-protocol" } [dev-dependencies] diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/handlers.rs --- a/rust/hedgewars-server/src/handlers.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/handlers.rs Tue Feb 01 02:23:35 2022 +0300 @@ -359,6 +359,7 @@ response.warn(ACCESS_DENIED); } } + #[allow(unused_variables)] HwProtocolMessage::Watch(id) => { #[cfg(feature = "official-server")] { @@ -386,13 +387,14 @@ response: &mut Response, addr: [u8; 4], is_local: bool, -) { +) -> bool { let ban_reason = Some(addr) .filter(|_| !is_local) .and_then(|a| state.anteroom.find_ip_ban(a)); if let Some(reason) = ban_reason { response.add(HwServerMessage::Bye(reason).send_self()); response.remove_client(client_id); + false } else { let mut salt = [0u8; 18]; thread_rng().fill_bytes(&mut salt); @@ -401,7 +403,11 @@ .anteroom .add_client(client_id, encode(&salt), is_local); - response.add(HwServerMessage::Connected(utils::SERVER_MESSAGE.to_owned(), utils::SERVER_VERSION).send_self()); + response.add( + HwServerMessage::Connected(utils::SERVER_MESSAGE.to_owned(), utils::SERVER_VERSION) + .send_self(), + ); + true } } diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/handlers/checker.rs --- a/rust/hedgewars-server/src/handlers/checker.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/handlers/checker.rs Tue Feb 01 02:23:35 2022 +0300 @@ -1,5 +1,4 @@ use log::*; -use mio; use crate::core::{server::HwServer, types::ClientId}; use hedgewars_network_protocol::messages::HwProtocolMessage; diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/handlers/inanteroom.rs --- a/rust/hedgewars-server/src/handlers/inanteroom.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/handlers/inanteroom.rs Tue Feb 01 02:23:35 2022 +0300 @@ -1,5 +1,3 @@ -use mio; - use super::strings::*; use crate::handlers::actions::ToPendingMessage; use crate::{ diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/main.rs --- a/rust/hedgewars-server/src/main.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/main.rs Tue Feb 01 02:23:35 2022 +0300 @@ -1,14 +1,12 @@ +#![forbid(unsafe_code)] #![allow(unused_imports)] +#![allow(dead_code)] +#![allow(unused_variables)] #![deny(bare_trait_objects)] use getopts::Options; use log::*; -use mio::{net::*, *}; -use std::{ - env, - str::FromStr as _, - time::{Duration, Instant}, -}; +use std::{env, net::SocketAddr, str::FromStr as _}; mod core; mod handlers; @@ -20,7 +18,8 @@ const PROGRAM_NAME: &'_ str = "Hedgewars Game Server"; -fn main() { +#[tokio::main] +async fn main() -> tokio::io::Result<()> { env_logger::init(); info!("Hedgewars game server, protocol {}", utils::SERVER_VERSION); @@ -34,91 +33,24 @@ Ok(m) => m, Err(e) => { println!("{}\n{}", e, opts.short_usage("")); - return; + return Ok(()); } }; if matches.opt_present("h") { println!("{}", opts.usage(PROGRAM_NAME)); - return; + return Ok(()); } let port = matches .opt_str("p") .and_then(|s| u16::from_str(&s).ok()) .unwrap_or(46631); - let address = format!("0.0.0.0:{}", port).parse().unwrap(); - - let listener = TcpListener::bind(address).unwrap(); - - let mut poll = Poll::new().unwrap(); - let mut hw_builder = NetworkLayerBuilder::default().with_listener(listener); - - #[cfg(feature = "tls-connections")] - { - let address = format!("0.0.0.0:{}", port + 1).parse().unwrap(); - hw_builder = hw_builder.with_secure_listener(TcpListener::bind(address).unwrap()); - } + let address: SocketAddr = format!("0.0.0.0:{}", port).parse().unwrap(); - let mut hw_network = hw_builder.build(&poll); - hw_network.register(&poll).unwrap(); - - let mut events = Events::with_capacity(1024); - - let mut time = Instant::now(); - - loop { - let timeout = if hw_network.has_pending_operations() { - Some(Duration::from_millis(1)) - } else { - None - }; - - poll.poll(&mut events, timeout).unwrap(); + let server = tokio::net::TcpListener::bind(address).await.unwrap(); - for event in events.iter() { - if event.is_readable() { - match event.token() { - token @ (utils::SERVER_TOKEN | utils::SECURE_SERVER_TOKEN) => { - match hw_network.accept_client(&poll, token) { - Ok(()) => (), - Err(e) => debug!("Error accepting client: {}", e), - } - } - #[cfg(feature = "official-server")] - utils::IO_TOKEN => match hw_network.handle_io_result(&poll) { - Ok(()) => (), - Err(e) => debug!("Error in IO task: {}", e), - }, - Token(token) => match hw_network.client_readable(&poll, token) { - Ok(()) => (), - Err(e) => debug!("Error reading from client socket {}: {}", token, e), - }, - } - } - if event.is_writable() { - match event.token() { - utils::SERVER_TOKEN | utils::SECURE_SERVER_TOKEN | utils::IO_TOKEN => { - unreachable!() - } - Token(token) => match hw_network.client_writable(&poll, token) { - Ok(()) => (), - Err(e) => debug!("Error writing to client socket {}: {}", token, e), - }, - } - } - } + let mut hw_network = NetworkLayerBuilder::default().with_listener(server).build(); - match hw_network.on_idle(&poll) { - Ok(()) => (), - Err(e) => debug!("Error in idle handler: {}", e), - }; - - if time.elapsed() > Duration::from_secs(1) { - time = Instant::now(); - match hw_network.handle_timeout(&mut poll) { - Ok(()) => (), - Err(e) => debug!("Error in timer event: {}", e), - } - } - } + hw_network.run().await; + Ok(()) } diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/protocol.rs --- a/rust/hedgewars-server/src/protocol.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/protocol.rs Tue Feb 01 02:23:35 2022 +0300 @@ -1,67 +1,75 @@ +use bytes::{Buf, BufMut, BytesMut}; +use log::*; +use std::{io, io::ErrorKind, marker::Unpin}; +use tokio::io::AsyncReadExt; + use hedgewars_network_protocol::{ messages::HwProtocolMessage, parser::{malformed_message, message}, }; -use log::*; -use netbuf; -use std::io::{Read, Result}; pub struct ProtocolDecoder { - buf: netbuf::Buf, + buffer: BytesMut, is_recovering: bool, } impl ProtocolDecoder { pub fn new() -> ProtocolDecoder { ProtocolDecoder { - buf: netbuf::Buf::new(), + buffer: BytesMut::with_capacity(1024), is_recovering: false, } } fn recover(&mut self) -> bool { - self.is_recovering = match malformed_message(&self.buf[..]) { + self.is_recovering = match malformed_message(&self.buffer[..]) { Ok((tail, ())) => { - let length = tail.len(); - self.buf.consume(self.buf.len() - length); + let remaining = tail.len(); + self.buffer.advance(self.buffer.len() - remaining); false } _ => { - self.buf.consume(self.buf.len()); + self.buffer.clear(); true } }; !self.is_recovering } - pub fn read_from(&mut self, stream: &mut R) -> Result { - let count = self.buf.read_from(stream)?; - if count > 0 && self.is_recovering { - self.recover(); - } - Ok(count) - } - - pub fn extract_messages(&mut self) -> Vec { - let mut messages = vec![]; - if !self.is_recovering { - while !self.buf.is_empty() { - match message(&self.buf[..]) { - Ok((tail, message)) => { - messages.push(message); - let length = tail.len(); - self.buf.consume(self.buf.len() - length); - } - Err(nom::Err::Incomplete(_)) => break, - Err(nom::Err::Failure(e) | nom::Err::Error(e)) => { - debug!("Invalid message: {:?}", e); - if !self.recover() || self.buf.is_empty() { - break; - } - } + fn extract_message(&mut self) -> Option { + if !self.is_recovering || self.recover() { + match message(&self.buffer[..]) { + Ok((tail, message)) => { + let remaining = tail.len(); + self.buffer.advance(self.buffer.len() - remaining); + return Some(message); + } + Err(nom::Err::Incomplete(_)) => {} + Err(nom::Err::Failure(e) | nom::Err::Error(e)) => { + debug!("Invalid message: {:?}", e); + self.recover(); } } } - messages + None + } + + pub async fn read_from( + &mut self, + stream: &mut R, + ) -> Option { + loop { + if !self.buffer.has_remaining() { + let count = stream.read_buf(&mut self.buffer).await.ok()?; + if count == 0 { + return None; + } + } + while !self.buffer.is_empty() { + if let Some(result) = self.extract_message() { + return Some(result); + } + } + } } } diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/server/demo.rs --- a/rust/hedgewars-server/src/server/demo.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/server/demo.rs Tue Feb 01 02:23:35 2022 +0300 @@ -170,18 +170,16 @@ u32::from_str(arg).unwrap_or_default(), )), "addteam" => { - if let parts = arg.splitn(3, ' ').collect::>() { - let color = parts.get(1).unwrap_or(&"1"); - let name = parts.get(2).unwrap_or(&"Unnamed"); - teams.push(TeamInfo { - color: (u32::from_str(color).unwrap_or(2113696) - / 2113696 - - 1) - as u8, - name: name.to_string(), - ..TeamInfo::default() - }) - }; + let parts = arg.splitn(3, ' ').collect::>(); + let color = parts.get(1).unwrap_or(&"1"); + let name = parts.get(2).unwrap_or(&"Unnamed"); + teams.push(TeamInfo { + color: (u32::from_str(color).unwrap_or(2113696) / 2113696 + - 1) + as u8, + name: name.to_string(), + ..TeamInfo::default() + }); } "fort" => teams .last_mut() @@ -193,21 +191,19 @@ .for_each(|t| t.grave = arg.to_string()), "addhh" => { hog_index = (hog_index + 1) % 8; - if let parts = arg.splitn(3, ' ').collect::>() { - let health = parts.get(1).unwrap_or(&"100"); - teams.last_mut().iter_mut().for_each(|t| { - if let Some(difficulty) = parts.get(0) { - t.difficulty = - u8::from_str(difficulty).unwrap_or(0); - } - if let Some(init_health) = parts.get(1) { - scheme_properties[2] = init_health.to_string(); - } - t.hedgehogs_number = (hog_index + 1) as u8; - t.hedgehogs[hog_index].name = - parts.get(2).unwrap_or(&"Unnamed").to_string() - }); - } + let parts = arg.splitn(3, ' ').collect::>(); + let health = parts.get(1).unwrap_or(&"100"); + teams.last_mut().iter_mut().for_each(|t| { + if let Some(difficulty) = parts.get(0) { + t.difficulty = u8::from_str(difficulty).unwrap_or(0); + } + if let Some(init_health) = parts.get(1) { + scheme_properties[2] = init_health.to_string(); + } + t.hedgehogs_number = (hog_index + 1) as u8; + t.hedgehogs[hog_index].name = + parts.get(2).unwrap_or(&"Unnamed").to_string(); + }); } "hat" => { teams diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/server/io.rs --- a/rust/hedgewars-server/src/server/io.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/server/io.rs Tue Feb 01 02:23:35 2022 +0300 @@ -10,7 +10,6 @@ server::database::Database, }; use log::*; -use mio::{Poll, Waker}; pub type RequestId = u32; diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/server/network.rs Tue Feb 01 02:23:35 2022 +0300 @@ -1,24 +1,22 @@ -extern crate slab; - +use bytes::{Buf, Bytes}; +use log::*; +use slab::Slab; use std::{ collections::HashSet, io, io::{Error, ErrorKind, Read, Write}, + iter::Iterator, mem::{replace, swap}, net::{IpAddr, Ipv4Addr, SocketAddr}, num::NonZeroU32, time::Duration, time::Instant, }; - -use log::*; -use mio::{ - event::Source, +use tokio::{ + io::AsyncReadExt, net::{TcpListener, TcpStream}, - Interest, Poll, Token, Waker, + sync::mpsc::{channel, Receiver, Sender}, }; -use netbuf; -use slab::Slab; use crate::{ core::{ @@ -30,404 +28,188 @@ protocol::ProtocolDecoder, utils, }; -use hedgewars_network_protocol::{messages::HwServerMessage::Redirect, messages::*}; - -#[cfg(feature = "official-server")] -use super::io::{IoThread, RequestId}; +use hedgewars_network_protocol::{ + messages::HwServerMessage::Redirect, messages::*, parser::server_message, +}; +use tokio::io::AsyncWriteExt; -#[cfg(feature = "tls-connections")] -use openssl::{ - error::ErrorStack, - ssl::{ - HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, - SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode, - }, -}; +enum ClientUpdateData { + Message(HwProtocolMessage), + Error(String), +} + +struct ClientUpdate { + client_id: ClientId, + data: ClientUpdateData, +} -const MAX_BYTES_PER_READ: usize = 2048; -const SEND_PING_TIMEOUT: Duration = Duration::from_secs(5); -const DROP_CLIENT_TIMEOUT: Duration = Duration::from_secs(5); -const MAX_TIMEOUT: usize = DROP_CLIENT_TIMEOUT.as_secs() as usize; -const PING_PROBES_COUNT: u8 = 2; +struct ClientUpdateSender { + client_id: ClientId, + sender: Sender, +} -#[derive(Hash, Eq, PartialEq, Copy, Clone)] -pub enum NetworkClientState { - Idle, - NeedsWrite, - NeedsRead, - Closed, - #[cfg(feature = "tls-connections")] - Connected, +impl ClientUpdateSender { + async fn send(&mut self, data: ClientUpdateData) -> bool { + self.sender + .send(ClientUpdate { + client_id: self.client_id, + data, + }) + .await + .is_ok() + } +} + +struct NetworkClient { + id: ClientId, + socket: TcpStream, + receiver: Receiver, + peer_addr: SocketAddr, + decoder: ProtocolDecoder, } -type NetworkResult = io::Result<(T, NetworkClientState)>; +impl NetworkClient { + fn new( + id: ClientId, + socket: TcpStream, + peer_addr: SocketAddr, + receiver: Receiver, + ) -> Self { + Self { + id, + socket, + peer_addr, + receiver, + decoder: ProtocolDecoder::new(), + } + } -pub enum ClientSocket { - Plain(TcpStream), - #[cfg(feature = "tls-connections")] - SslHandshake(Option>), - #[cfg(feature = "tls-connections")] - SslStream(SslStream), -} + async fn read(&mut self) -> Option { + self.decoder.read_from(&mut self.socket).await + } + + async fn write(&mut self, mut data: Bytes) -> bool { + !data.has_remaining() || matches!(self.socket.write_buf(&mut data).await, Ok(n) if n > 0) + } -impl ClientSocket { - fn inner_mut(&mut self) -> &mut TcpStream { - match self { - ClientSocket::Plain(stream) => stream, - #[cfg(feature = "tls-connections")] - ClientSocket::SslHandshake(Some(builder)) => builder.get_mut(), - #[cfg(feature = "tls-connections")] - ClientSocket::SslHandshake(None) => unreachable!(), - #[cfg(feature = "tls-connections")] - ClientSocket::SslStream(ssl_stream) => ssl_stream.get_mut(), + async fn run(mut self, sender: Sender) { + use ClientUpdateData::*; + let mut sender = ClientUpdateSender { + client_id: self.id, + sender, + }; + + loop { + tokio::select! { + server_message = self.receiver.recv() => { + match server_message { + Some(message) => if !self.write(message).await { + sender.send(Error("Connection reset by peer".to_string())).await; + break; + } + None => { + break; + } + } + } + client_message = self.decoder.read_from(&mut self.socket) => { + match client_message { + Some(message) => { + if !sender.send(Message(message)).await { + break; + } + } + None => { + sender.send(Error("Connection reset by peer".to_string())).await; + break; + } + } + } + } } } } -pub struct NetworkClient { - id: ClientId, - socket: ClientSocket, - peer_addr: SocketAddr, - decoder: ProtocolDecoder, - buf_out: netbuf::Buf, - pending_close: bool, - timeout: Timeout, - last_rx_time: Instant, +pub struct NetworkLayer { + listener: TcpListener, + server_state: ServerState, + clients: Slab>, } -impl NetworkClient { - pub fn new( - id: ClientId, - socket: ClientSocket, - peer_addr: SocketAddr, - timeout: Timeout, - ) -> NetworkClient { - NetworkClient { - id, - socket, - peer_addr, - decoder: ProtocolDecoder::new(), - buf_out: netbuf::Buf::new(), - pending_close: false, - timeout, - last_rx_time: Instant::now(), - } - } +impl NetworkLayer { + pub async fn run(&mut self) { + let (update_tx, mut update_rx) = channel(128); - #[cfg(feature = "tls-connections")] - fn handshake_impl( - &mut self, - handshake: MidHandshakeSslStream, - ) -> io::Result { - match handshake.handshake() { - Ok(stream) => { - self.socket = ClientSocket::SslStream(stream); - debug!( - "TLS handshake with {} ({}) completed", - self.id, self.peer_addr - ); - Ok(NetworkClientState::Connected) + loop { + tokio::select! { + Ok((stream, addr)) = self.listener.accept() => { + if let Some(client) = self.create_client(stream, addr).await { + tokio::spawn(client.run(update_tx.clone())); + } + } + client_message = update_rx.recv(), if !self.clients.is_empty() => { + use ClientUpdateData::*; + match client_message { + Some(ClientUpdate{ client_id, data: Message(message) } ) => { + self.handle_message(client_id, message).await; + } + Some(ClientUpdate{ client_id, .. } ) => { + let mut response = handlers::Response::new(client_id); + handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); + self.handle_response(response).await; + } + None => unreachable!() + } + } } - Err(HandshakeError::WouldBlock(new_handshake)) => { - self.socket = ClientSocket::SslHandshake(Some(new_handshake)); - Ok(NetworkClientState::Idle) - } - Err(HandshakeError::Failure(new_handshake)) => { - self.socket = ClientSocket::SslHandshake(Some(new_handshake)); - debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); - Err(Error::new(ErrorKind::Other, "Connection failure")) - } - Err(HandshakeError::SetupFailure(_)) => unreachable!(), } } - fn read_impl( - decoder: &mut ProtocolDecoder, - source: &mut R, - id: ClientId, - addr: &SocketAddr, - ) -> NetworkResult> { - let mut bytes_read = 0; - let result = loop { - match decoder.read_from(source) { - Ok(bytes) => { - debug!("Client {}: read {} bytes", id, bytes); - bytes_read += bytes; - if bytes == 0 { - let result = if bytes_read == 0 { - info!("EOF for client {} ({})", id, addr); - (Vec::new(), NetworkClientState::Closed) - } else { - (decoder.extract_messages(), NetworkClientState::NeedsRead) - }; - break Ok(result); - } else if bytes_read >= MAX_BYTES_PER_READ { - break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)); - } - } - Err(ref error) if error.kind() == ErrorKind::WouldBlock => { - let messages = if bytes_read == 0 { - Vec::new() - } else { - decoder.extract_messages() - }; - break Ok((messages, NetworkClientState::Idle)); - } - Err(error) => break Err(error), - } - }; - result - } + async fn create_client( + &mut self, + stream: TcpStream, + addr: SocketAddr, + ) -> Option { + let entry = self.clients.vacant_entry(); + let client_id = entry.key(); + let (tx, rx) = channel(16); + entry.insert(tx); + + let client = NetworkClient::new(client_id, stream, addr, rx); - pub fn read(&mut self) -> NetworkResult> { - let result = match self.socket { - ClientSocket::Plain(ref mut stream) => { - NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) - } - #[cfg(feature = "tls-connections")] - ClientSocket::SslHandshake(ref mut handshake_opt) => { - let handshake = std::mem::replace(handshake_opt, None).unwrap(); - Ok((Vec::new(), self.handshake_impl(handshake)?)) - } - #[cfg(feature = "tls-connections")] - ClientSocket::SslStream(ref mut stream) => { - NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) - } + info!("client {} ({}) added", client.id, client.peer_addr); + + let mut response = handlers::Response::new(client_id); + + let added = if let IpAddr::V4(addr) = client.peer_addr.ip() { + handlers::handle_client_accept( + &mut self.server_state, + client_id, + &mut response, + addr.octets(), + addr.is_loopback(), + ) + } else { + todo!("implement something") }; - if let Ok(_) = result { - self.last_rx_time = Instant::now(); - } - - result - } - - fn write_impl( - buf_out: &mut netbuf::Buf, - destination: &mut W, - close_on_empty: bool, - ) -> NetworkResult<()> { - let result = loop { - match buf_out.write_to(destination) { - Ok(bytes) if buf_out.is_empty() || bytes == 0 => { - let status = if buf_out.is_empty() && close_on_empty { - NetworkClientState::Closed - } else { - NetworkClientState::Idle - }; - break Ok(((), status)); - } - Ok(_) => (), - Err(ref error) - if error.kind() == ErrorKind::Interrupted - || error.kind() == ErrorKind::WouldBlock => - { - break Ok(((), NetworkClientState::NeedsWrite)); - } - Err(error) => break Err(error), - } - }; - result - } - - pub fn write(&mut self) -> NetworkResult<()> { - let result = match self.socket { - ClientSocket::Plain(ref mut stream) => { - NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close) - } - #[cfg(feature = "tls-connections")] - ClientSocket::SslHandshake(ref mut handshake_opt) => { - let handshake = std::mem::replace(handshake_opt, None).unwrap(); - Ok(((), self.handshake_impl(handshake)?)) - } - #[cfg(feature = "tls-connections")] - ClientSocket::SslStream(ref mut stream) => { - NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close) - } - }; - - self.socket.inner_mut().flush()?; - result - } + self.handle_response(response).await; - pub fn send_raw_msg(&mut self, msg: &[u8]) { - self.buf_out.write_all(msg).unwrap(); - } - - pub fn send_string(&mut self, msg: &str) { - self.send_raw_msg(&msg.as_bytes()); - } - - pub fn replace_timeout(&mut self, timeout: Timeout) -> Timeout { - replace(&mut self.timeout, timeout) - } - - pub fn has_pending_sends(&self) -> bool { - !self.buf_out.is_empty() - } -} - -#[cfg(feature = "tls-connections")] -struct ServerSsl { - listener: TcpListener, - context: SslContext, -} - -#[cfg(feature = "official-server")] -pub struct IoLayer { - next_request_id: RequestId, - request_queue: Vec<(RequestId, ClientId)>, - io_thread: IoThread, -} - -#[cfg(feature = "official-server")] -impl IoLayer { - fn new(waker: Waker) -> Self { - Self { - next_request_id: 0, - request_queue: vec![], - io_thread: IoThread::new(waker), - } - } - - fn send(&mut self, client_id: ClientId, task: IoTask) { - let request_id = self.next_request_id; - self.next_request_id += 1; - self.request_queue.push((request_id, client_id)); - self.io_thread.send(request_id, task); - } - - fn try_recv(&mut self) -> Option<(ClientId, IoResult)> { - let (request_id, result) = self.io_thread.try_recv()?; - if let Some(index) = self - .request_queue - .iter() - .position(|(id, _)| *id == request_id) - { - let (_, client_id) = self.request_queue.swap_remove(index); - Some((client_id, result)) + if added { + Some(client) } else { None } } - fn cancel(&mut self, client_id: ClientId) { - let mut index = 0; - while index < self.request_queue.len() { - if self.request_queue[index].1 == client_id { - self.request_queue.swap_remove(index); - } else { - index += 1; - } - } - } -} - -enum TimeoutEvent { - SendPing { probes_count: u8 }, - DropClient, -} - -struct TimerData(TimeoutEvent, ClientId); -type NetworkTimeoutEvents = TimedEvents; - -pub struct NetworkLayer { - listener: TcpListener, - server_state: ServerState, - clients: Slab, - pending: HashSet<(ClientId, NetworkClientState)>, - pending_cache: Vec<(ClientId, NetworkClientState)>, - #[cfg(feature = "tls-connections")] - ssl: ServerSsl, - #[cfg(feature = "official-server")] - io: IoLayer, - timeout_events: NetworkTimeoutEvents, -} - -fn register_read(poll: &Poll, source: &mut S, token: mio::Token) -> io::Result<()> { - poll.registry().register(source, token, Interest::READABLE) -} - -fn create_ping_timeout( - timeout_events: &mut NetworkTimeoutEvents, - probes_count: u8, - client_id: ClientId, -) -> Timeout { - timeout_events.set_timeout( - NonZeroU32::new(SEND_PING_TIMEOUT.as_secs() as u32).unwrap(), - TimerData(TimeoutEvent::SendPing { probes_count }, client_id), - ) -} - -fn create_drop_timeout(timeout_events: &mut NetworkTimeoutEvents, client_id: ClientId) -> Timeout { - timeout_events.set_timeout( - NonZeroU32::new(DROP_CLIENT_TIMEOUT.as_secs() as u32).unwrap(), - TimerData(TimeoutEvent::DropClient, client_id), - ) -} - -impl NetworkLayer { - pub fn register(&mut self, poll: &Poll) -> io::Result<()> { - register_read(poll, &mut self.listener, utils::SERVER_TOKEN)?; - #[cfg(feature = "tls-connections")] - register_read(poll, &mut self.ssl.listener, utils::SECURE_SERVER_TOKEN)?; - - Ok(()) + async fn handle_message(&mut self, client_id: ClientId, message: HwProtocolMessage) { + debug!("Handling message {:?} for client {}", message, client_id); + let mut response = handlers::Response::new(client_id); + handlers::handle(&mut self.server_state, client_id, &mut response, message); + self.handle_response(response).await; } - fn deregister_client(&mut self, poll: &Poll, id: ClientId, is_error: bool) { - if let Some(ref mut client) = self.clients.get_mut(id) { - poll.registry() - .deregister(client.socket.inner_mut()) - .expect("could not deregister socket"); - if client.has_pending_sends() && !is_error { - info!( - "client {} ({}) pending removal", - client.id, client.peer_addr - ); - client.pending_close = true; - poll.registry() - .register(client.socket.inner_mut(), Token(id), Interest::WRITABLE) - .unwrap_or_else(|_| { - self.clients.remove(id); - }); - } else { - info!("client {} ({}) removed", client.id, client.peer_addr); - self.clients.remove(id); - } - #[cfg(feature = "official-server")] - self.io.cancel(id); - } - } - - fn register_client( - &mut self, - poll: &Poll, - mut client_socket: ClientSocket, - addr: SocketAddr, - ) -> io::Result { - let entry = self.clients.vacant_entry(); - let client_id = entry.key(); - - poll.registry().register( - client_socket.inner_mut(), - Token(client_id), - Interest::READABLE | Interest::WRITABLE, - )?; - - let client = NetworkClient::new( - client_id, - client_socket, - addr, - create_ping_timeout(&mut self.timeout_events, PING_PROBES_COUNT - 1, client_id), - ); - info!("client {} ({}) added", client.id, client.peer_addr); - entry.insert(client); - - Ok(client_id) - } - - fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) { + async fn handle_response(&mut self, mut response: handlers::Response) { if response.is_empty() { return; } @@ -436,263 +218,38 @@ let output = response.extract_messages(&mut self.server_state.server); for (clients, message) in output { debug!("Message {:?} to {:?}", message, clients); - let msg_string = message.to_raw_protocol(); - for client_id in clients { - if let Some(client) = self.clients.get_mut(client_id) { - client.send_string(&msg_string); - self.pending - .insert((client_id, NetworkClientState::NeedsWrite)); - } - } + Self::send_message(&mut self.clients, message, clients.iter().cloned()).await; } for client_id in response.extract_removed_clients() { - self.deregister_client(poll, client_id, false); - } - - #[cfg(feature = "official-server")] - { - let client_id = response.client_id(); - for task in response.extract_io_tasks() { - self.io.send(client_id, task); - } - } - } - - pub fn handle_timeout(&mut self, poll: &mut Poll) -> io::Result<()> { - for TimerData(event, client_id) in self.timeout_events.poll(Instant::now()) { - if let Some(client) = self.clients.get_mut(client_id) { - if client.last_rx_time.elapsed() > SEND_PING_TIMEOUT { - match event { - TimeoutEvent::SendPing { probes_count } => { - client.send_string(&HwServerMessage::Ping.to_raw_protocol()); - client.write()?; - let timeout = if probes_count != 0 { - create_ping_timeout( - &mut self.timeout_events, - probes_count - 1, - client_id, - ) - } else { - create_drop_timeout(&mut self.timeout_events, client_id) - }; - client.replace_timeout(timeout); - } - TimeoutEvent::DropClient => { - client.send_string( - &HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol(), - ); - let _res = client.write(); - - self.operation_failed( - poll, - client_id, - &ErrorKind::TimedOut.into(), - "No ping response", - )?; - } - } - } else { - client.replace_timeout(create_ping_timeout( - &mut self.timeout_events, - PING_PROBES_COUNT - 1, - client_id, - )); - } + if self.clients.contains(client_id) { + self.clients.remove(client_id); } - } - Ok(()) - } - - #[cfg(feature = "official-server")] - pub fn handle_io_result(&mut self, poll: &Poll) -> io::Result<()> { - while let Some((client_id, result)) = self.io.try_recv() { - debug!("Handling io result {:?} for client {}", result, client_id); - let mut response = handlers::Response::new(client_id); - handlers::handle_io_result(&mut self.server_state, client_id, &mut response, result); - self.handle_response(response, poll); - } - Ok(()) - } - - fn create_client_socket(&self, socket: TcpStream) -> io::Result { - Ok(ClientSocket::Plain(socket)) - } - - #[cfg(feature = "tls-connections")] - fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result { - let ssl = Ssl::new(&self.ssl.context).unwrap(); - let mut builder = SslStreamBuilder::new(ssl, socket); - builder.set_accept_state(); - match builder.handshake() { - Ok(stream) => Ok(ClientSocket::SslStream(stream)), - Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))), - Err(e) => { - debug!("OpenSSL handshake failed: {}", e); - Err(Error::new(ErrorKind::Other, "Connection failure")) - } - } - } - - fn init_client(&mut self, poll: &Poll, client_id: ClientId) { - let mut response = handlers::Response::new(client_id); - - if let ClientSocket::Plain(_) = self.clients[client_id].socket { - #[cfg(feature = "tls-connections")] - response.add(Redirect(self.ssl.listener.local_addr().unwrap().port()).send_self()) - } - - if let IpAddr::V4(addr) = self.clients[client_id].peer_addr.ip() { - handlers::handle_client_accept( - &mut self.server_state, - client_id, - &mut response, - addr.octets(), - addr.is_loopback(), - ); - self.handle_response(response, poll); - } else { - todo!("implement something") + info!("Client {} removed", client_id); } } - pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> { - match server_token { - utils::SERVER_TOKEN => { - let (client_socket, addr) = self.listener.accept()?; - info!("Connected(plaintext): {}", addr); - let client_id = - self.register_client(poll, self.create_client_socket(client_socket)?, addr)?; - self.init_client(poll, client_id); - } - #[cfg(feature = "tls-connections")] - utils::SECURE_SERVER_TOKEN => { - let (client_socket, addr) = self.ssl.listener.accept()?; - info!("Connected(TLS): {}", addr); - self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr)?; - } - _ => unreachable!(), - } - - Ok(()) - } - - fn operation_failed( - &mut self, - poll: &Poll, - client_id: ClientId, - error: &Error, - msg: &str, - ) -> io::Result<()> { - let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.peer_addr - } else { - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) - }; - debug!("{}({}): {}", msg, addr, error); - self.client_error(poll, client_id) - } - - pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { - let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.read() - } else { - warn!("invalid readable client: {}", client_id); - Ok((Vec::new(), NetworkClientState::Idle)) - }; - - let mut response = handlers::Response::new(client_id); - - match messages { - Ok((messages, state)) => { - for message in messages { - debug!("Handling message {:?} for client {}", message, client_id); - handlers::handle(&mut self.server_state, client_id, &mut response, message); + async fn send_message( + clients: &mut Slab>, + message: HwServerMessage, + to_clients: I, + ) where + I: Iterator, + { + let msg_string = message.to_raw_protocol(); + let bytes = Bytes::copy_from_slice(msg_string.as_bytes()); + for client_id in to_clients { + if let Some(client) = clients.get_mut(client_id) { + if !client.send(bytes.clone()).await.is_ok() { + clients.remove(client_id); } - match state { - NetworkClientState::NeedsRead => { - self.pending.insert((client_id, state)); - } - NetworkClientState::Closed => self.client_error(&poll, client_id)?, - #[cfg(feature = "tls-connections")] - NetworkClientState::Connected => self.init_client(poll, client_id), - _ => {} - }; - } - Err(e) => self.operation_failed( - poll, - client_id, - &e, - "Error while reading from client socket", - )?, - } - - self.handle_response(response, poll); - - Ok(()) - } - - pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { - let result = if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.write() - } else { - warn!("invalid writable client: {}", client_id); - Ok(((), NetworkClientState::Idle)) - }; - - match result { - Ok(((), state)) if state == NetworkClientState::NeedsWrite => { - self.pending.insert((client_id, state)); - } - Ok(((), state)) if state == NetworkClientState::Closed => { - self.deregister_client(poll, client_id, false); - } - Ok(_) => (), - Err(e) => { - self.operation_failed(poll, client_id, &e, "Error while writing to client socket")? } } - - Ok(()) - } - - pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { - let pending_close = self.clients[client_id].pending_close; - self.deregister_client(poll, client_id, true); - - if !pending_close { - let mut response = handlers::Response::new(client_id); - handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); - self.handle_response(response, poll); - } - - Ok(()) - } - - pub fn has_pending_operations(&self) -> bool { - !self.pending.is_empty() || !self.timeout_events.is_empty() - } - - pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> { - if self.has_pending_operations() { - let mut cache = replace(&mut self.pending_cache, Vec::new()); - cache.extend(self.pending.drain()); - for (id, state) in cache.drain(..) { - match state { - NetworkClientState::NeedsRead => self.client_readable(poll, id)?, - NetworkClientState::NeedsWrite => self.client_writable(poll, id)?, - _ => {} - } - } - swap(&mut cache, &mut self.pending_cache); - } - Ok(()) } } pub struct NetworkLayerBuilder { listener: Option, - secure_listener: Option, clients_capacity: usize, rooms_capacity: usize, } @@ -703,7 +260,6 @@ clients_capacity: 1024, rooms_capacity: 512, listener: None, - secure_listener: None, } } } @@ -716,59 +272,15 @@ } } - pub fn with_secure_listener(self, listener: TcpListener) -> Self { - Self { - secure_listener: Some(listener), - ..self - } - } - - #[cfg(feature = "tls-connections")] - fn create_ssl_context(listener: TcpListener) -> ServerSsl { - let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); - builder.set_verify(SslVerifyMode::NONE); - builder.set_read_ahead(true); - builder - .set_certificate_file("ssl/cert.pem", SslFiletype::PEM) - .expect("Cannot find certificate file"); - builder - .set_private_key_file("ssl/key.pem", SslFiletype::PEM) - .expect("Cannot find private key file"); - builder.set_options(SslOptions::NO_COMPRESSION); - builder.set_options(SslOptions::NO_TLSV1); - builder.set_options(SslOptions::NO_TLSV1_1); - builder.set_cipher_list("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384").unwrap(); - ServerSsl { - listener, - context: builder.build(), - } - } - - pub fn build(self, poll: &Poll) -> NetworkLayer { + pub fn build(self) -> NetworkLayer { let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity); let clients = Slab::with_capacity(self.clients_capacity); - let pending = HashSet::with_capacity(2 * self.clients_capacity); - let pending_cache = Vec::with_capacity(2 * self.clients_capacity); - let timeout_events = NetworkTimeoutEvents::new(); - - #[cfg(feature = "official-server")] - let waker = Waker::new(poll.registry(), utils::IO_TOKEN) - .expect("Unable to create a waker for the IO thread"); NetworkLayer { listener: self.listener.expect("No listener provided"), server_state, clients, - pending, - pending_cache, - #[cfg(feature = "tls-connections")] - ssl: Self::create_ssl_context( - self.secure_listener.expect("No secure listener provided"), - ), - #[cfg(feature = "official-server")] - io: IoLayer::new(waker), - timeout_events, } } } diff -r ea459da15b30 -r 7d0f747afcb8 rust/hedgewars-server/src/utils.rs --- a/rust/hedgewars-server/src/utils.rs Mon Jan 31 18:24:49 2022 +0300 +++ b/rust/hedgewars-server/src/utils.rs Tue Feb 01 02:23:35 2022 +0300 @@ -1,12 +1,8 @@ use base64::encode; -use mio; use std::iter::Iterator; pub const SERVER_VERSION: u32 = 3; pub const SERVER_MESSAGE: &str = &"Hedgewars server https://www.hedgewars.org/"; -pub const SERVER_TOKEN: mio::Token = mio::Token(1_000_000_000); -pub const SECURE_SERVER_TOKEN: mio::Token = mio::Token(1_000_000_001); -pub const IO_TOKEN: mio::Token = mio::Token(1_000_000_003); pub fn is_name_illegal(name: &str) -> bool { name.len() > 40