# HG changeset patch # User alfadur # Date 1556112106 -10800 # Node ID 8ddb5842fe0b7d76caa1af1c52ab9b806643b76e # Parent f56936207a659a434ef780ec09dd884ae5b020c0 allow running plaintext and tls servers in parallel diff -r f56936207a65 -r 8ddb5842fe0b rust/hedgewars-server/src/main.rs --- a/rust/hedgewars-server/src/main.rs Tue Apr 23 15:54:06 2019 +0200 +++ b/rust/hedgewars-server/src/main.rs Wed Apr 24 16:21:46 2019 +0300 @@ -1,25 +1,24 @@ #![allow(unused_imports)] #![deny(bare_trait_objects)] -extern crate getopts; use getopts::Options; use log::*; -use mio::net::*; -use mio::*; -use std::env; +use mio::{net::*, *}; +use std::{env, str::FromStr as _, time::Duration}; mod protocol; mod server; mod utils; -use crate::server::network::NetworkLayer; -use std::time::Duration; +use crate::server::network::{NetworkLayer, NetworkLayerBuilder}; const PROGRAM_NAME: &'_ str = "Hedgewars Game Server"; fn main() { env_logger::init(); + info!("Hedgewars game server, protocol {}", utils::SERVER_VERSION); + let args: Vec = env::args().collect(); let mut opts = Options::new(); @@ -36,23 +35,26 @@ println!("{}", opts.usage(PROGRAM_NAME)); return; } - info!("Hedgewars game server, protocol {}", utils::SERVER_VERSION); - let address; - if matches.opt_present("p") { - match matches.opt_str("p") { - Some(x) => address = format!("0.0.0.0:{}", x).parse().unwrap(), - None => address = "0.0.0.0:46631".parse().unwrap(), - } - } else { - address = "0.0.0.0:46631".parse().unwrap(); - } + let port = matches + .opt_str("p") + .and_then(|s| u16::from_str(&s).ok()) + .unwrap_or(46631); + let address = format!("0.0.0.0:{}", port).parse().unwrap(); let listener = TcpListener::bind(&address).unwrap(); let poll = Poll::new().unwrap(); - let mut hw_network = NetworkLayer::new(listener, 1024, 512); - hw_network.register_server(&poll).unwrap(); + let mut hw_builder = NetworkLayerBuilder::default().with_listener(listener); + + #[cfg(feature = "tls-connections")] + { + let address = format!("0.0.0.0:{}", port + 1).parse().unwrap(); + hw_builder = hw_builder.with_secure_listener(TcpListener::bind(&address).unwrap()); + } + + let mut hw_network = hw_builder.build(); + hw_network.register(&poll).unwrap(); let mut events = Events::with_capacity(1024); @@ -67,7 +69,9 @@ for event in events.iter() { if event.readiness() & Ready::readable() == Ready::readable() { match event.token() { - utils::SERVER_TOKEN => hw_network.accept_client(&poll).unwrap(), + token @ utils::SERVER_TOKEN | token @ utils::SECURE_SERVER_TOKEN => { + hw_network.accept_client(&poll, token).unwrap() + } utils::TIMER_TOKEN => hw_network.handle_timeout(&poll).unwrap(), #[cfg(feature = "official-server")] utils::IO_TOKEN => hw_network.handle_io_result(), diff -r f56936207a65 -r 8ddb5842fe0b rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Tue Apr 23 15:54:06 2019 +0200 +++ b/rust/hedgewars-server/src/server/network.rs Wed Apr 24 16:21:46 2019 +0300 @@ -11,7 +11,7 @@ use log::*; use mio::{ net::{TcpListener, TcpStream}, - Poll, PollOpt, Ready, Token, + Evented, Poll, PollOpt, Ready, Token, }; use mio_extras::timer; use netbuf; @@ -48,32 +48,29 @@ NeedsWrite, NeedsRead, Closed, + #[cfg(feature = "tls-connections")] + Connected, } type NetworkResult = io::Result<(T, NetworkClientState)>; -#[cfg(not(feature = "tls-connections"))] pub enum ClientSocket { Plain(TcpStream), -} - -#[cfg(feature = "tls-connections")] -pub enum ClientSocket { + #[cfg(feature = "tls-connections")] SslHandshake(Option>), + #[cfg(feature = "tls-connections")] SslStream(SslStream), } impl ClientSocket { fn inner(&self) -> &TcpStream { - #[cfg(not(feature = "tls-connections"))] match self { ClientSocket::Plain(stream) => stream, - } - - #[cfg(feature = "tls-connections")] - match self { + #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), + #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(None) => unreachable!(), + #[cfg(feature = "tls-connections")] ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), } } @@ -117,7 +114,7 @@ "TLS handshake with {} ({}) completed", self.id, self.peer_addr ); - Ok(NetworkClientState::Idle) + Ok(NetworkClientState::Connected) } Err(HandshakeError::WouldBlock(new_handshake)) => { self.socket = ClientSocket::SslHandshake(Some(new_handshake)); @@ -171,19 +168,16 @@ } pub fn read(&mut self) -> NetworkResult> { - #[cfg(not(feature = "tls-connections"))] match self.socket { ClientSocket::Plain(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) } - } - - #[cfg(feature = "tls-connections")] - match self.socket { + #[cfg(feature = "tls-connections")] ClientSocket::SslHandshake(ref mut handshake_opt) => { let handshake = std::mem::replace(handshake_opt, None).unwrap(); Ok((Vec::new(), self.handshake_impl(handshake)?)) } + #[cfg(feature = "tls-connections")] ClientSocket::SslStream(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) } @@ -210,25 +204,18 @@ } pub fn write(&mut self) -> NetworkResult<()> { - let result = { - #[cfg(not(feature = "tls-connections"))] - match self.socket { - ClientSocket::Plain(ref mut stream) => { - NetworkClient::write_impl(&mut self.buf_out, stream) - } + let result = match self.socket { + ClientSocket::Plain(ref mut stream) => { + NetworkClient::write_impl(&mut self.buf_out, stream) } - #[cfg(feature = "tls-connections")] - { - match self.socket { - ClientSocket::SslHandshake(ref mut handshake_opt) => { - let handshake = std::mem::replace(handshake_opt, None).unwrap(); - Ok(((), self.handshake_impl(handshake)?)) - } - ClientSocket::SslStream(ref mut stream) => { - NetworkClient::write_impl(&mut self.buf_out, stream) - } - } + ClientSocket::SslHandshake(ref mut handshake_opt) => { + let handshake = std::mem::replace(handshake_opt, None).unwrap(); + Ok(((), self.handshake_impl(handshake)?)) + } + #[cfg(feature = "tls-connections")] + ClientSocket::SslStream(ref mut stream) => { + NetworkClient::write_impl(&mut self.buf_out, stream) } }; @@ -251,6 +238,7 @@ #[cfg(feature = "tls-connections")] struct ServerSsl { + listener: TcpListener, context: SslContext, } @@ -324,6 +312,10 @@ timer: timer::Timer, } +fn register_read(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> { + poll.register(evented, token, Ready::readable(), PollOpt::edge()) +} + fn create_ping_timeout( timer: &mut timer::Timer, probes_count: u8, @@ -343,29 +335,8 @@ } impl NetworkLayer { - pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { - let server = HWServer::new(clients_limit, rooms_limit); - let clients = Slab::with_capacity(clients_limit); - let pending = HashSet::with_capacity(2 * clients_limit); - let pending_cache = Vec::with_capacity(2 * clients_limit); - let timer = timer::Builder::default().build(); - - NetworkLayer { - listener, - server, - clients, - pending, - pending_cache, - #[cfg(feature = "tls-connections")] - ssl: NetworkLayer::create_ssl_context(), - #[cfg(feature = "official-server")] - io: IoLayer::new(), - timer, - } - } - #[cfg(feature = "tls-connections")] - fn create_ssl_context() -> ServerSsl { + fn create_ssl_context(listener: TcpListener) -> ServerSsl { let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); builder.set_read_ahead(true); @@ -378,24 +349,16 @@ builder.set_options(SslOptions::NO_COMPRESSION); builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); ServerSsl { + listener, context: builder.build(), } } - pub fn register_server(&self, poll: &Poll) -> io::Result<()> { - poll.register( - &self.listener, - utils::SERVER_TOKEN, - Ready::readable(), - PollOpt::edge(), - )?; - - poll.register( - &self.timer, - utils::TIMER_TOKEN, - Ready::readable(), - PollOpt::edge(), - )?; + pub fn register(&self, poll: &Poll) -> io::Result<()> { + register_read(poll, &self.listener, utils::SERVER_TOKEN)?; + #[cfg(feature = "tls-connections")] + register_read(poll, &self.listener, utils::SECURE_SERVER_TOKEN)?; + register_read(poll, &self.timer, utils::TIMER_TOKEN)?; #[cfg(feature = "official-server")] self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?; @@ -448,6 +411,10 @@ } fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) { + if response.is_empty() { + return; + } + debug!("{} pending server messages", response.len()); let output = response.extract_messages(&mut self.server); for (clients, message) in output { @@ -512,41 +479,41 @@ } fn create_client_socket(&self, socket: TcpStream) -> io::Result { - #[cfg(not(feature = "tls-connections"))] - { - Ok(ClientSocket::Plain(socket)) - } + Ok(ClientSocket::Plain(socket)) + } - #[cfg(feature = "tls-connections")] - { - let ssl = Ssl::new(&self.ssl.context).unwrap(); - let mut builder = SslStreamBuilder::new(ssl, socket); - builder.set_accept_state(); - match builder.handshake() { - Ok(stream) => Ok(ClientSocket::SslStream(stream)), - Err(HandshakeError::WouldBlock(stream)) => { - Ok(ClientSocket::SslHandshake(Some(stream))) - } - Err(e) => { - debug!("OpenSSL handshake failed: {}", e); - Err(Error::new(ErrorKind::Other, "Connection failure")) - } + #[cfg(feature = "tls-connections")] + fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result { + let ssl = Ssl::new(&self.ssl.context).unwrap(); + let mut builder = SslStreamBuilder::new(ssl, socket); + builder.set_accept_state(); + match builder.handshake() { + Ok(stream) => Ok(ClientSocket::SslStream(stream)), + Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))), + Err(e) => { + debug!("OpenSSL handshake failed: {}", e); + Err(Error::new(ErrorKind::Other, "Connection failure")) } } } - pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { + pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> { let (client_socket, addr) = self.listener.accept()?; info!("Connected: {}", addr); - let client_id = self.register_client(poll, self.create_client_socket(client_socket)?, addr); - - let mut response = handlers::Response::new(client_id); - - handlers::handle_client_accept(&mut self.server, client_id, &mut response); - - if !response.is_empty() { - self.handle_response(response, poll); + match server_token { + utils::SERVER_TOKEN => { + let client_id = + self.register_client(poll, self.create_client_socket(client_socket)?, addr); + let mut response = handlers::Response::new(client_id); + handlers::handle_client_accept(&mut self.server, client_id, &mut response); + self.handle_response(response, poll); + } + #[cfg(feature = "tls-connections")] + utils::SECURE_SERVER_TOKEN => { + self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr); + } + _ => unreachable!(), } Ok(()) @@ -595,6 +562,12 @@ self.pending.insert((client_id, state)); } NetworkClientState::Closed => self.client_error(&poll, client_id)?, + #[cfg(feature = "tls-connections")] + NetworkClientState::Connected => { + let mut response = handlers::Response::new(client_id); + handlers::handle_client_accept(&mut self.server, client_id, &mut response); + self.handle_response(response, poll); + } _ => {} }; } @@ -606,9 +579,7 @@ )?, } - if !response.is_empty() { - self.handle_response(response, poll); - } + self.handle_response(response, poll); Ok(()) } @@ -663,3 +634,60 @@ Ok(()) } } + +pub struct NetworkLayerBuilder { + listener: Option, + secure_listener: Option, + clients_capacity: usize, + rooms_capacity: usize, +} + +impl Default for NetworkLayerBuilder { + fn default() -> Self { + Self { + clients_capacity: 1024, + rooms_capacity: 512, + listener: None, + secure_listener: None, + } + } +} + +impl NetworkLayerBuilder { + pub fn with_listener(self, listener: TcpListener) -> Self { + Self { + listener: Some(listener), + ..self + } + } + + pub fn with_secure_listener(self, listener: TcpListener) -> Self { + Self { + secure_listener: Some(listener), + ..self + } + } + + pub fn build(self) -> NetworkLayer { + let server = HWServer::new(self.clients_capacity, self.rooms_capacity); + let clients = Slab::with_capacity(self.clients_capacity); + let pending = HashSet::with_capacity(2 * self.clients_capacity); + let pending_cache = Vec::with_capacity(2 * self.clients_capacity); + let timer = timer::Builder::default().build(); + + NetworkLayer { + listener: self.listener.expect("No listener provided"), + server, + clients, + pending, + pending_cache, + #[cfg(feature = "tls-connections")] + ssl: NetworkLayer::create_ssl_context( + self.secure_listener.expect("No secure listener provided"), + ), + #[cfg(feature = "official-server")] + io: IoLayer::new(), + timer, + } + } +} diff -r f56936207a65 -r 8ddb5842fe0b rust/hedgewars-server/src/utils.rs --- a/rust/hedgewars-server/src/utils.rs Tue Apr 23 15:54:06 2019 +0200 +++ b/rust/hedgewars-server/src/utils.rs Wed Apr 24 16:21:46 2019 +0300 @@ -4,8 +4,9 @@ pub const SERVER_VERSION: u32 = 3; pub const SERVER_TOKEN: mio::Token = mio::Token(1_000_000_000); -pub const TIMER_TOKEN: mio::Token = mio::Token(1_000_000_001); -pub const IO_TOKEN: mio::Token = mio::Token(1_000_000_002); +pub const SECURE_SERVER_TOKEN: mio::Token = mio::Token(1_000_000_001); +pub const TIMER_TOKEN: mio::Token = mio::Token(1_000_000_002); +pub const IO_TOKEN: mio::Token = mio::Token(1_000_000_003); pub fn is_name_illegal(name: &str) -> bool { name.len() > 40