diff -r a077aac9df01 -r 98ef2913ec73 rust/hedgewars-server/src/server/network.rs --- a/rust/hedgewars-server/src/server/network.rs Sun Dec 16 00:09:20 2018 +0100 +++ b/rust/hedgewars-server/src/server/network.rs Sun Dec 16 00:12:29 2018 +0100 @@ -1,37 +1,33 @@ extern crate slab; use std::{ - io, io::{Error, ErrorKind, Read, Write}, - net::{SocketAddr, IpAddr, Ipv4Addr}, collections::HashSet, - mem::{swap, replace} + io, + io::{Error, ErrorKind, Read, Write}, + mem::{replace, swap}, + net::{IpAddr, Ipv4Addr, SocketAddr}, }; +use log::*; use mio::{ - net::{TcpStream, TcpListener}, - Poll, PollOpt, Ready, Token + net::{TcpListener, TcpStream}, + Poll, PollOpt, Ready, Token, }; use netbuf; use slab::Slab; -use log::*; +use super::{core::HWServer, coretypes::ClientId, io::FileServerIO}; use crate::{ + protocol::{messages::*, ProtocolDecoder}, utils, - protocol::{ProtocolDecoder, messages::*} -}; -use super::{ - io::FileServerIO, - core::{HWServer}, - coretypes::ClientId }; #[cfg(feature = "tls-connections")] use openssl::{ + error::ErrorStack, ssl::{ - SslMethod, SslContext, Ssl, SslContextBuilder, - SslVerifyMode, SslFiletype, SslOptions, - SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream + HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype, + SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode, }, - error::ErrorStack }; const MAX_BYTES_PER_READ: usize = 2048; @@ -48,13 +44,13 @@ #[cfg(not(feature = "tls-connections"))] pub enum ClientSocket { - Plain(TcpStream) + Plain(TcpStream), } #[cfg(feature = "tls-connections")] pub enum ClientSocket { SslHandshake(Option>), - SslStream(SslStream) + SslStream(SslStream), } impl ClientSocket { @@ -68,7 +64,7 @@ match self { ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), ClientSocket::SslHandshake(None) => unreachable!(), - ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref() + ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), } } } @@ -78,24 +74,32 @@ socket: ClientSocket, peer_addr: SocketAddr, decoder: ProtocolDecoder, - buf_out: netbuf::Buf + buf_out: netbuf::Buf, } impl NetworkClient { pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { NetworkClient { - id, socket, peer_addr, + id, + socket, + peer_addr, decoder: ProtocolDecoder::new(), - buf_out: netbuf::Buf::new() + buf_out: netbuf::Buf::new(), } } #[cfg(feature = "tls-connections")] - fn handshake_impl(&mut self, handshake: MidHandshakeSslStream) -> io::Result { + fn handshake_impl( + &mut self, + handshake: MidHandshakeSslStream, + ) -> io::Result { match handshake.handshake() { Ok(stream) => { self.socket = ClientSocket::SslStream(stream); - debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr); + debug!( + "TLS handshake with {} ({}) completed", + self.id, self.peer_addr + ); Ok(NetworkClientState::Idle) } Err(HandshakeError::WouldBlock(new_handshake)) => { @@ -107,12 +111,16 @@ debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); Err(Error::new(ErrorKind::Other, "Connection failure")) } - Err(HandshakeError::SetupFailure(_)) => unreachable!() + Err(HandshakeError::SetupFailure(_)) => unreachable!(), } } - fn read_impl(decoder: &mut ProtocolDecoder, source: &mut R, - id: ClientId, addr: &SocketAddr) -> NetworkResult> { + fn read_impl( + decoder: &mut ProtocolDecoder, + source: &mut R, + id: ClientId, + addr: &SocketAddr, + ) -> NetworkResult> { let mut bytes_read = 0; let result = loop { match decoder.read_from(source) { @@ -127,21 +135,19 @@ (decoder.extract_messages(), NetworkClientState::NeedsRead) }; break Ok(result); - } - else if bytes_read >= MAX_BYTES_PER_READ { - break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)) + } else if bytes_read >= MAX_BYTES_PER_READ { + break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)); } } Err(ref error) if error.kind() == ErrorKind::WouldBlock => { - let messages = if bytes_read == 0 { + let messages = if bytes_read == 0 { Vec::new() } else { decoder.extract_messages() }; break Ok((messages, NetworkClientState::Idle)); } - Err(error) => - break Err(error) + Err(error) => break Err(error), } }; decoder.sweep(); @@ -151,8 +157,9 @@ pub fn read(&mut self) -> NetworkResult> { #[cfg(not(feature = "tls-connections"))] match self.socket { - ClientSocket::Plain(ref mut stream) => - NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr), + ClientSocket::Plain(ref mut stream) => { + NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) + } } #[cfg(feature = "tls-connections")] @@ -160,24 +167,27 @@ ClientSocket::SslHandshake(ref mut handshake_opt) => { let handshake = std::mem::replace(handshake_opt, None).unwrap(); Ok((Vec::new(), self.handshake_impl(handshake)?)) - }, - ClientSocket::SslStream(ref mut stream) => + } + ClientSocket::SslStream(ref mut stream) => { NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) + } } } fn write_impl(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { let result = loop { match buf_out.write_to(destination) { - Ok(bytes) if buf_out.is_empty() || bytes == 0 => - break Ok(((), NetworkClientState::Idle)), + Ok(bytes) if buf_out.is_empty() || bytes == 0 => { + break Ok(((), NetworkClientState::Idle)) + } Ok(_) => (), - Err(ref error) if error.kind() == ErrorKind::Interrupted - || error.kind() == ErrorKind::WouldBlock => { + Err(ref error) + if error.kind() == ErrorKind::Interrupted + || error.kind() == ErrorKind::WouldBlock => + { break Ok(((), NetworkClientState::NeedsWrite)); - }, - Err(error) => - break Err(error) + } + Err(error) => break Err(error), } }; result @@ -187,18 +197,21 @@ let result = { #[cfg(not(feature = "tls-connections"))] match self.socket { - ClientSocket::Plain(ref mut stream) => + ClientSocket::Plain(ref mut stream) => { NetworkClient::write_impl(&mut self.buf_out, stream) + } } - #[cfg(feature = "tls-connections")] { + #[cfg(feature = "tls-connections")] + { match self.socket { ClientSocket::SslHandshake(ref mut handshake_opt) => { let handshake = std::mem::replace(handshake_opt, None).unwrap(); Ok(((), self.handshake_impl(handshake)?)) } - ClientSocket::SslStream(ref mut stream) => + ClientSocket::SslStream(ref mut stream) => { NetworkClient::write_impl(&mut self.buf_out, stream) + } } } }; @@ -222,7 +235,7 @@ #[cfg(feature = "tls-connections")] struct ServerSsl { - context: SslContext + context: SslContext, } pub struct NetworkLayer { @@ -232,7 +245,7 @@ pending: HashSet<(ClientId, NetworkClientState)>, pending_cache: Vec<(ClientId, NetworkClientState)>, #[cfg(feature = "tls-connections")] - ssl: ServerSsl + ssl: ServerSsl, } impl NetworkLayer { @@ -243,9 +256,13 @@ let pending_cache = Vec::with_capacity(2 * clients_limit); NetworkLayer { - listener, server, clients, pending, pending_cache, + listener, + server, + clients, + pending, + pending_cache, #[cfg(feature = "tls-connections")] - ssl: NetworkLayer::create_ssl_context() + ssl: NetworkLayer::create_ssl_context(), } } @@ -254,16 +271,26 @@ let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); builder.set_verify(SslVerifyMode::NONE); builder.set_read_ahead(true); - builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap(); - builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap(); + builder + .set_certificate_file("ssl/cert.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_private_key_file("ssl/key.pem", SslFiletype::PEM) + .unwrap(); builder.set_options(SslOptions::NO_COMPRESSION); builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); - ServerSsl { context: builder.build() } + ServerSsl { + context: builder.build(), + } } pub fn register_server(&self, poll: &Poll) -> io::Result<()> { - poll.register(&self.listener, utils::SERVER, Ready::readable(), - PollOpt::edge()) + poll.register( + &self.listener, + utils::SERVER, + Ready::readable(), + PollOpt::edge(), + ) } fn deregister_client(&mut self, poll: &Poll, id: ClientId) { @@ -279,11 +306,20 @@ } } - fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) { - poll.register(client_socket.inner(), Token(id), - Ready::readable() | Ready::writable(), - PollOpt::edge()) - .expect("could not register socket with event loop"); + fn register_client( + &mut self, + poll: &Poll, + id: ClientId, + client_socket: ClientSocket, + addr: SocketAddr, + ) { + poll.register( + client_socket.inner(), + Token(id), + Ready::readable() | Ready::writable(), + PollOpt::edge(), + ) + .expect("could not register socket with event loop"); let entry = self.clients.vacant_entry(); let client = NetworkClient::new(id, client_socket, addr); @@ -299,26 +335,29 @@ for client_id in clients { if let Some(client) = self.clients.get_mut(client_id) { client.send_string(&msg_string); - self.pending.insert((client_id, NetworkClientState::NeedsWrite)); + self.pending + .insert((client_id, NetworkClientState::NeedsWrite)); } } } } fn create_client_socket(&self, socket: TcpStream) -> io::Result { - #[cfg(not(feature = "tls-connections"))] { + #[cfg(not(feature = "tls-connections"))] + { Ok(ClientSocket::Plain(socket)) } - #[cfg(feature = "tls-connections")] { + #[cfg(feature = "tls-connections")] + { let ssl = Ssl::new(&self.ssl.context).unwrap(); let mut builder = SslStreamBuilder::new(ssl, socket); builder.set_accept_state(); match builder.handshake() { - Ok(stream) => - Ok(ClientSocket::SslStream(stream)), - Err(HandshakeError::WouldBlock(stream)) => - Ok(ClientSocket::SslHandshake(Some(stream))), + Ok(stream) => Ok(ClientSocket::SslStream(stream)), + Err(HandshakeError::WouldBlock(stream)) => { + Ok(ClientSocket::SslHandshake(Some(stream))) + } Err(e) => { debug!("OpenSSL handshake failed: {}", e); Err(Error::new(ErrorKind::Other, "Connection failure")) @@ -332,13 +371,24 @@ info!("Connected: {}", addr); let client_id = self.server.add_client(); - self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr); + self.register_client( + poll, + client_id, + self.create_client_socket(client_socket)?, + addr, + ); self.flush_server_messages(); Ok(()) } - fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> { + fn operation_failed( + &mut self, + poll: &Poll, + client_id: ClientId, + error: &Error, + msg: &str, + ) -> io::Result<()> { let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { client.peer_addr } else { @@ -348,15 +398,13 @@ self.client_error(poll, client_id) } - pub fn client_readable(&mut self, poll: &Poll, - client_id: ClientId) -> io::Result<()> { - let messages = - if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.read() - } else { - warn!("invalid readable client: {}", client_id); - Ok((Vec::new(), NetworkClientState::Idle)) - }; + pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { + let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) { + client.read() + } else { + warn!("invalid readable client: {}", client_id); + Ok((Vec::new(), NetworkClientState::Idle)) + }; match messages { Ok((messages, state)) => { @@ -366,15 +414,17 @@ match state { NetworkClientState::NeedsRead => { self.pending.insert((client_id, state)); - }, - NetworkClientState::Closed => - self.client_error(&poll, client_id)?, + } + NetworkClientState::Closed => self.client_error(&poll, client_id)?, _ => {} }; } Err(e) => self.operation_failed( - poll, client_id, &e, - "Error while reading from client socket")? + poll, + client_id, + &e, + "Error while reading from client socket", + )?, } self.flush_server_messages(); @@ -389,31 +439,28 @@ Ok(()) } - pub fn client_writable(&mut self, poll: &Poll, - client_id: ClientId) -> io::Result<()> { - let result = - if let Some(ref mut client) = self.clients.get_mut(client_id) { - client.write() - } else { - warn!("invalid writable client: {}", client_id); - Ok(((), NetworkClientState::Idle)) - }; + pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { + let result = if let Some(ref mut client) = self.clients.get_mut(client_id) { + client.write() + } else { + warn!("invalid writable client: {}", client_id); + Ok(((), NetworkClientState::Idle)) + }; match result { Ok(((), state)) if state == NetworkClientState::NeedsWrite => { self.pending.insert((client_id, state)); - }, + } Ok(_) => {} - Err(e) => self.operation_failed( - poll, client_id, &e, - "Error while writing to client socket")? + Err(e) => { + self.operation_failed(poll, client_id, &e, "Error while writing to client socket")? + } } Ok(()) } - pub fn client_error(&mut self, poll: &Poll, - client_id: ClientId) -> io::Result<()> { + pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> { self.deregister_client(poll, client_id); self.server.client_lost(client_id); @@ -430,10 +477,8 @@ cache.extend(self.pending.drain()); for (id, state) in cache.drain(..) { match state { - NetworkClientState::NeedsRead => - self.client_readable(poll, id)?, - NetworkClientState::NeedsWrite => - self.client_writable(poll, id)?, + NetworkClientState::NeedsRead => self.client_readable(poll, id)?, + NetworkClientState::NeedsWrite => self.client_writable(poll, id)?, _ => {} } }