--- a/rust/hedgewars-server/src/server/network.rs Sun Dec 16 00:09:20 2018 +0100
+++ b/rust/hedgewars-server/src/server/network.rs Sun Dec 16 00:12:29 2018 +0100
@@ -1,37 +1,33 @@
extern crate slab;
use std::{
- io, io::{Error, ErrorKind, Read, Write},
- net::{SocketAddr, IpAddr, Ipv4Addr},
collections::HashSet,
- mem::{swap, replace}
+ io,
+ io::{Error, ErrorKind, Read, Write},
+ mem::{replace, swap},
+ net::{IpAddr, Ipv4Addr, SocketAddr},
};
+use log::*;
use mio::{
- net::{TcpStream, TcpListener},
- Poll, PollOpt, Ready, Token
+ net::{TcpListener, TcpStream},
+ Poll, PollOpt, Ready, Token,
};
use netbuf;
use slab::Slab;
-use log::*;
+use super::{core::HWServer, coretypes::ClientId, io::FileServerIO};
use crate::{
+ protocol::{messages::*, ProtocolDecoder},
utils,
- protocol::{ProtocolDecoder, messages::*}
-};
-use super::{
- io::FileServerIO,
- core::{HWServer},
- coretypes::ClientId
};
#[cfg(feature = "tls-connections")]
use openssl::{
+ error::ErrorStack,
ssl::{
- SslMethod, SslContext, Ssl, SslContextBuilder,
- SslVerifyMode, SslFiletype, SslOptions,
- SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
+ HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslContextBuilder, SslFiletype,
+ SslMethod, SslOptions, SslStream, SslStreamBuilder, SslVerifyMode,
},
- error::ErrorStack
};
const MAX_BYTES_PER_READ: usize = 2048;
@@ -48,13 +44,13 @@
#[cfg(not(feature = "tls-connections"))]
pub enum ClientSocket {
- Plain(TcpStream)
+ Plain(TcpStream),
}
#[cfg(feature = "tls-connections")]
pub enum ClientSocket {
SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
- SslStream(SslStream<TcpStream>)
+ SslStream(SslStream<TcpStream>),
}
impl ClientSocket {
@@ -68,7 +64,7 @@
match self {
ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
ClientSocket::SslHandshake(None) => unreachable!(),
- ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
+ ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
}
}
}
@@ -78,24 +74,32 @@
socket: ClientSocket,
peer_addr: SocketAddr,
decoder: ProtocolDecoder,
- buf_out: netbuf::Buf
+ buf_out: netbuf::Buf,
}
impl NetworkClient {
pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
NetworkClient {
- id, socket, peer_addr,
+ id,
+ socket,
+ peer_addr,
decoder: ProtocolDecoder::new(),
- buf_out: netbuf::Buf::new()
+ buf_out: netbuf::Buf::new(),
}
}
#[cfg(feature = "tls-connections")]
- fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> {
+ fn handshake_impl(
+ &mut self,
+ handshake: MidHandshakeSslStream<TcpStream>,
+ ) -> io::Result<NetworkClientState> {
match handshake.handshake() {
Ok(stream) => {
self.socket = ClientSocket::SslStream(stream);
- debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
+ debug!(
+ "TLS handshake with {} ({}) completed",
+ self.id, self.peer_addr
+ );
Ok(NetworkClientState::Idle)
}
Err(HandshakeError::WouldBlock(new_handshake)) => {
@@ -107,12 +111,16 @@
debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
Err(Error::new(ErrorKind::Other, "Connection failure"))
}
- Err(HandshakeError::SetupFailure(_)) => unreachable!()
+ Err(HandshakeError::SetupFailure(_)) => unreachable!(),
}
}
- fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
- id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
+ fn read_impl<R: Read>(
+ decoder: &mut ProtocolDecoder,
+ source: &mut R,
+ id: ClientId,
+ addr: &SocketAddr,
+ ) -> NetworkResult<Vec<HWProtocolMessage>> {
let mut bytes_read = 0;
let result = loop {
match decoder.read_from(source) {
@@ -127,21 +135,19 @@
(decoder.extract_messages(), NetworkClientState::NeedsRead)
};
break Ok(result);
- }
- else if bytes_read >= MAX_BYTES_PER_READ {
- break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
+ } else if bytes_read >= MAX_BYTES_PER_READ {
+ break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead));
}
}
Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
- let messages = if bytes_read == 0 {
+ let messages = if bytes_read == 0 {
Vec::new()
} else {
decoder.extract_messages()
};
break Ok((messages, NetworkClientState::Idle));
}
- Err(error) =>
- break Err(error)
+ Err(error) => break Err(error),
}
};
decoder.sweep();
@@ -151,8 +157,9 @@
pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
#[cfg(not(feature = "tls-connections"))]
match self.socket {
- ClientSocket::Plain(ref mut stream) =>
- NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
+ ClientSocket::Plain(ref mut stream) => {
+ NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+ }
}
#[cfg(feature = "tls-connections")]
@@ -160,24 +167,27 @@
ClientSocket::SslHandshake(ref mut handshake_opt) => {
let handshake = std::mem::replace(handshake_opt, None).unwrap();
Ok((Vec::new(), self.handshake_impl(handshake)?))
- },
- ClientSocket::SslStream(ref mut stream) =>
+ }
+ ClientSocket::SslStream(ref mut stream) => {
NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+ }
}
}
fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
let result = loop {
match buf_out.write_to(destination) {
- Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
- break Ok(((), NetworkClientState::Idle)),
+ Ok(bytes) if buf_out.is_empty() || bytes == 0 => {
+ break Ok(((), NetworkClientState::Idle))
+ }
Ok(_) => (),
- Err(ref error) if error.kind() == ErrorKind::Interrupted
- || error.kind() == ErrorKind::WouldBlock => {
+ Err(ref error)
+ if error.kind() == ErrorKind::Interrupted
+ || error.kind() == ErrorKind::WouldBlock =>
+ {
break Ok(((), NetworkClientState::NeedsWrite));
- },
- Err(error) =>
- break Err(error)
+ }
+ Err(error) => break Err(error),
}
};
result
@@ -187,18 +197,21 @@
let result = {
#[cfg(not(feature = "tls-connections"))]
match self.socket {
- ClientSocket::Plain(ref mut stream) =>
+ ClientSocket::Plain(ref mut stream) => {
NetworkClient::write_impl(&mut self.buf_out, stream)
+ }
}
- #[cfg(feature = "tls-connections")] {
+ #[cfg(feature = "tls-connections")]
+ {
match self.socket {
ClientSocket::SslHandshake(ref mut handshake_opt) => {
let handshake = std::mem::replace(handshake_opt, None).unwrap();
Ok(((), self.handshake_impl(handshake)?))
}
- ClientSocket::SslStream(ref mut stream) =>
+ ClientSocket::SslStream(ref mut stream) => {
NetworkClient::write_impl(&mut self.buf_out, stream)
+ }
}
}
};
@@ -222,7 +235,7 @@
#[cfg(feature = "tls-connections")]
struct ServerSsl {
- context: SslContext
+ context: SslContext,
}
pub struct NetworkLayer {
@@ -232,7 +245,7 @@
pending: HashSet<(ClientId, NetworkClientState)>,
pending_cache: Vec<(ClientId, NetworkClientState)>,
#[cfg(feature = "tls-connections")]
- ssl: ServerSsl
+ ssl: ServerSsl,
}
impl NetworkLayer {
@@ -243,9 +256,13 @@
let pending_cache = Vec::with_capacity(2 * clients_limit);
NetworkLayer {
- listener, server, clients, pending, pending_cache,
+ listener,
+ server,
+ clients,
+ pending,
+ pending_cache,
#[cfg(feature = "tls-connections")]
- ssl: NetworkLayer::create_ssl_context()
+ ssl: NetworkLayer::create_ssl_context(),
}
}
@@ -254,16 +271,26 @@
let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE);
builder.set_read_ahead(true);
- builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
- builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
+ builder
+ .set_certificate_file("ssl/cert.pem", SslFiletype::PEM)
+ .unwrap();
+ builder
+ .set_private_key_file("ssl/key.pem", SslFiletype::PEM)
+ .unwrap();
builder.set_options(SslOptions::NO_COMPRESSION);
builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
- ServerSsl { context: builder.build() }
+ ServerSsl {
+ context: builder.build(),
+ }
}
pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
- poll.register(&self.listener, utils::SERVER, Ready::readable(),
- PollOpt::edge())
+ poll.register(
+ &self.listener,
+ utils::SERVER,
+ Ready::readable(),
+ PollOpt::edge(),
+ )
}
fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
@@ -279,11 +306,20 @@
}
}
- fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
- poll.register(client_socket.inner(), Token(id),
- Ready::readable() | Ready::writable(),
- PollOpt::edge())
- .expect("could not register socket with event loop");
+ fn register_client(
+ &mut self,
+ poll: &Poll,
+ id: ClientId,
+ client_socket: ClientSocket,
+ addr: SocketAddr,
+ ) {
+ poll.register(
+ client_socket.inner(),
+ Token(id),
+ Ready::readable() | Ready::writable(),
+ PollOpt::edge(),
+ )
+ .expect("could not register socket with event loop");
let entry = self.clients.vacant_entry();
let client = NetworkClient::new(id, client_socket, addr);
@@ -299,26 +335,29 @@
for client_id in clients {
if let Some(client) = self.clients.get_mut(client_id) {
client.send_string(&msg_string);
- self.pending.insert((client_id, NetworkClientState::NeedsWrite));
+ self.pending
+ .insert((client_id, NetworkClientState::NeedsWrite));
}
}
}
}
fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
- #[cfg(not(feature = "tls-connections"))] {
+ #[cfg(not(feature = "tls-connections"))]
+ {
Ok(ClientSocket::Plain(socket))
}
- #[cfg(feature = "tls-connections")] {
+ #[cfg(feature = "tls-connections")]
+ {
let ssl = Ssl::new(&self.ssl.context).unwrap();
let mut builder = SslStreamBuilder::new(ssl, socket);
builder.set_accept_state();
match builder.handshake() {
- Ok(stream) =>
- Ok(ClientSocket::SslStream(stream)),
- Err(HandshakeError::WouldBlock(stream)) =>
- Ok(ClientSocket::SslHandshake(Some(stream))),
+ Ok(stream) => Ok(ClientSocket::SslStream(stream)),
+ Err(HandshakeError::WouldBlock(stream)) => {
+ Ok(ClientSocket::SslHandshake(Some(stream)))
+ }
Err(e) => {
debug!("OpenSSL handshake failed: {}", e);
Err(Error::new(ErrorKind::Other, "Connection failure"))
@@ -332,13 +371,24 @@
info!("Connected: {}", addr);
let client_id = self.server.add_client();
- self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
+ self.register_client(
+ poll,
+ client_id,
+ self.create_client_socket(client_socket)?,
+ addr,
+ );
self.flush_server_messages();
Ok(())
}
- fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> {
+ fn operation_failed(
+ &mut self,
+ poll: &Poll,
+ client_id: ClientId,
+ error: &Error,
+ msg: &str,
+ ) -> io::Result<()> {
let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
client.peer_addr
} else {
@@ -348,15 +398,13 @@
self.client_error(poll, client_id)
}
- pub fn client_readable(&mut self, poll: &Poll,
- client_id: ClientId) -> io::Result<()> {
- let messages =
- if let Some(ref mut client) = self.clients.get_mut(client_id) {
- client.read()
- } else {
- warn!("invalid readable client: {}", client_id);
- Ok((Vec::new(), NetworkClientState::Idle))
- };
+ pub fn client_readable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
+ let messages = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ client.read()
+ } else {
+ warn!("invalid readable client: {}", client_id);
+ Ok((Vec::new(), NetworkClientState::Idle))
+ };
match messages {
Ok((messages, state)) => {
@@ -366,15 +414,17 @@
match state {
NetworkClientState::NeedsRead => {
self.pending.insert((client_id, state));
- },
- NetworkClientState::Closed =>
- self.client_error(&poll, client_id)?,
+ }
+ NetworkClientState::Closed => self.client_error(&poll, client_id)?,
_ => {}
};
}
Err(e) => self.operation_failed(
- poll, client_id, &e,
- "Error while reading from client socket")?
+ poll,
+ client_id,
+ &e,
+ "Error while reading from client socket",
+ )?,
}
self.flush_server_messages();
@@ -389,31 +439,28 @@
Ok(())
}
- pub fn client_writable(&mut self, poll: &Poll,
- client_id: ClientId) -> io::Result<()> {
- let result =
- if let Some(ref mut client) = self.clients.get_mut(client_id) {
- client.write()
- } else {
- warn!("invalid writable client: {}", client_id);
- Ok(((), NetworkClientState::Idle))
- };
+ pub fn client_writable(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
+ let result = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ client.write()
+ } else {
+ warn!("invalid writable client: {}", client_id);
+ Ok(((), NetworkClientState::Idle))
+ };
match result {
Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
self.pending.insert((client_id, state));
- },
+ }
Ok(_) => {}
- Err(e) => self.operation_failed(
- poll, client_id, &e,
- "Error while writing to client socket")?
+ Err(e) => {
+ self.operation_failed(poll, client_id, &e, "Error while writing to client socket")?
+ }
}
Ok(())
}
- pub fn client_error(&mut self, poll: &Poll,
- client_id: ClientId) -> io::Result<()> {
+ pub fn client_error(&mut self, poll: &Poll, client_id: ClientId) -> io::Result<()> {
self.deregister_client(poll, client_id);
self.server.client_lost(client_id);
@@ -430,10 +477,8 @@
cache.extend(self.pending.drain());
for (id, state) in cache.drain(..) {
match state {
- NetworkClientState::NeedsRead =>
- self.client_readable(poll, id)?,
- NetworkClientState::NeedsWrite =>
- self.client_writable(poll, id)?,
+ NetworkClientState::NeedsRead => self.client_readable(poll, id)?,
+ NetworkClientState::NeedsWrite => self.client_writable(poll, id)?,
_ => {}
}
}