diff -r 6843c4551cde -r 06672690d71b rust/hedgewars-server/src/server/network.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/rust/hedgewars-server/src/server/network.rs Mon Dec 10 22:44:46 2018 +0100 @@ -0,0 +1,444 @@ +extern crate slab; + +use std::{ + io, io::{Error, ErrorKind, Read, Write}, + net::{SocketAddr, IpAddr, Ipv4Addr}, + collections::HashSet, + mem::{swap, replace} +}; + +use mio::{ + net::{TcpStream, TcpListener}, + Poll, PollOpt, Ready, Token +}; +use netbuf; +use slab::Slab; +use log::*; + +use crate::{ + utils, + protocol::{ProtocolDecoder, messages::*} +}; +use super::{ + io::FileServerIO, + core::{HWServer}, + coretypes::ClientId +}; +#[cfg(feature = "tls-connections")] +use openssl::{ + ssl::{ + SslMethod, SslContext, Ssl, SslContextBuilder, + SslVerifyMode, SslFiletype, SslOptions, + SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream + }, + error::ErrorStack +}; + +const MAX_BYTES_PER_READ: usize = 2048; + +#[derive(Hash, Eq, PartialEq, Copy, Clone)] +pub enum NetworkClientState { + Idle, + NeedsWrite, + NeedsRead, + Closed, +} + +type NetworkResult = io::Result<(T, NetworkClientState)>; + +#[cfg(not(feature = "tls-connections"))] +pub enum ClientSocket { + Plain(TcpStream) +} + +#[cfg(feature = "tls-connections")] +pub enum ClientSocket { + SslHandshake(Option>), + SslStream(SslStream) +} + +impl ClientSocket { + fn inner(&self) -> &TcpStream { + #[cfg(not(feature = "tls-connections"))] + match self { + ClientSocket::Plain(stream) => stream, + } + + #[cfg(feature = "tls-connections")] + match self { + ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), + ClientSocket::SslHandshake(None) => unreachable!(), + ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref() + } + } +} + +pub struct NetworkClient { + id: ClientId, + socket: ClientSocket, + peer_addr: SocketAddr, + decoder: ProtocolDecoder, + buf_out: netbuf::Buf +} + +impl NetworkClient { + pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { + NetworkClient { + id, socket, peer_addr, + decoder: ProtocolDecoder::new(), + buf_out: netbuf::Buf::new() + } + } + + #[cfg(feature = "tls-connections")] + fn handshake_impl(&mut self, handshake: MidHandshakeSslStream) -> io::Result { + match handshake.handshake() { + Ok(stream) => { + self.socket = ClientSocket::SslStream(stream); + debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr); + Ok(NetworkClientState::Idle) + } + Err(HandshakeError::WouldBlock(new_handshake)) => { + self.socket = ClientSocket::SslHandshake(Some(new_handshake)); + Ok(NetworkClientState::Idle) + } + Err(HandshakeError::Failure(new_handshake)) => { + self.socket = ClientSocket::SslHandshake(Some(new_handshake)); + debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); + Err(Error::new(ErrorKind::Other, "Connection failure")) + } + Err(HandshakeError::SetupFailure(_)) => unreachable!() + } + } + + fn read_impl(decoder: &mut ProtocolDecoder, source: &mut R, + id: ClientId, addr: &SocketAddr) -> NetworkResult> { + let mut bytes_read = 0; + let result = loop { + match decoder.read_from(source) { + Ok(bytes) => { + debug!("Client {}: read {} bytes", id, bytes); + bytes_read += bytes; + if bytes == 0 { + let result = if bytes_read == 0 { + info!("EOF for client {} ({})", id, addr); + (Vec::new(), NetworkClientState::Closed) + } else { + (decoder.extract_messages(), NetworkClientState::NeedsRead) + }; + break Ok(result); + } + else if bytes_read >= MAX_BYTES_PER_READ { + break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)) + } + } + Err(ref error) if error.kind() == ErrorKind::WouldBlock => { + let messages = if bytes_read == 0 { + Vec::new() + } else { + decoder.extract_messages() + }; + break Ok((messages, NetworkClientState::Idle)); + } + Err(error) => + break Err(error) + } + }; + decoder.sweep(); + result + } + + pub fn read(&mut self) -> NetworkResult> { + #[cfg(not(feature = "tls-connections"))] + match self.socket { + ClientSocket::Plain(ref mut stream) => + NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr), + } + + #[cfg(feature = "tls-connections")] + match self.socket { + ClientSocket::SslHandshake(ref mut handshake_opt) => { + let handshake = std::mem::replace(handshake_opt, None).unwrap(); + Ok((Vec::new(), self.handshake_impl(handshake)?)) + }, + ClientSocket::SslStream(ref mut stream) => + NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) + } + } + + fn write_impl(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { + let result = loop { + match buf_out.write_to(destination) { + Ok(bytes) if buf_out.is_empty() || bytes == 0 => + break Ok(((), NetworkClientState::Idle)), + Ok(_) => (), + Err(ref error) if error.kind() == ErrorKind::Interrupted + || error.kind() == ErrorKind::WouldBlock => { + break Ok(((), NetworkClientState::NeedsWrite)); + }, + Err(error) => + break Err(error) + } + }; + result + } + + pub fn write(&mut self) -> NetworkResult<()> { + let result = { + #[cfg(not(feature = "tls-connections"))] + match self.socket { + ClientSocket::Plain(ref mut stream) => + NetworkClient::write_impl(&mut self.buf_out, stream) + } + + #[cfg(feature = "tls-connections")] { + match self.socket { + ClientSocket::SslHandshake(ref mut handshake_opt) => { + let handshake = std::mem::replace(handshake_opt, None).unwrap(); + Ok(((), self.handshake_impl(handshake)?)) + } + ClientSocket::SslStream(ref mut stream) => + NetworkClient::write_impl(&mut self.buf_out, stream) + } + } + }; + + self.socket.inner().flush()?; + result + } + + pub fn send_raw_msg(&mut self, msg: &[u8]) { + self.buf_out.write_all(msg).unwrap(); + } + + pub fn send_string(&mut self, msg: &str) { + self.send_raw_msg(&msg.as_bytes()); + } + + pub fn send_msg(&mut self, msg: &HWServerMessage) { + self.send_string(&msg.to_raw_protocol()); + } +} + +#[cfg(feature = "tls-connections")] +struct ServerSsl { + context: SslContext +} + +pub struct NetworkLayer { + listener: TcpListener, + server: HWServer, + clients: Slab, + pending: HashSet<(ClientId, NetworkClientState)>, + pending_cache: Vec<(ClientId, NetworkClientState)>, + #[cfg(feature = "tls-connections")] + ssl: ServerSsl +} + +impl NetworkLayer { + pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { + let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new())); + let clients = Slab::with_capacity(clients_limit); + let pending = HashSet::with_capacity(2 * clients_limit); + let pending_cache = Vec::with_capacity(2 * clients_limit); + + NetworkLayer { + listener, server, clients, pending, pending_cache, + #[cfg(feature = "tls-connections")] + ssl: NetworkLayer::create_ssl_context() + } + } + + #[cfg(feature = "tls-connections")] + fn create_ssl_context() -> ServerSsl { + let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + builder.set_read_ahead(true); + builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap(); + builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap(); + builder.set_options(SslOptions::NO_COMPRESSION); + builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); + ServerSsl { context: builder.build() } + } + + pub fn register_server(&self, poll: &Poll) -> io::Result<()> { + poll.register(&self.listener, utils::SERVER, Ready::readable(), + PollOpt::edge()) + } + + fn deregister_client(&mut self, poll: &Poll, id: ClientId) { + let mut client_exists = false; + if let Some(ref client) = self.clients.get(id) { + poll.deregister(client.socket.inner()) + .expect("could not deregister socket"); + info!("client {} ({}) removed", client.id, client.peer_addr); + client_exists = true; + } + if client_exists { + self.clients.remove(id); + } + } + + fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) { + poll.register(client_socket.inner(), Token(id), + Ready::readable() | Ready::writable(), + PollOpt::edge()) + .expect("could not register socket with event loop"); + + let entry = self.clients.vacant_entry(); + let client = NetworkClient::new(id, client_socket, addr); + info!("client {} ({}) added", client.id, client.peer_addr); + entry.insert(client); + } + + fn flush_server_messages(&mut self) { + debug!("{} pending server messages", self.server.output.len()); + for (clients, message) in self.server.output.drain(..) { + debug!("Message {:?} to {:?}", message, clients); + let msg_string = message.to_raw_protocol(); + for client_id in clients { + if let Some(client) = self.clients.get_mut(client_id) { + client.send_string(&msg_string); + self.pending.insert((client_id, NetworkClientState::NeedsWrite)); + } + } + } + } + + fn create_client_socket(&self, socket: TcpStream) -> io::Result { + #[cfg(not(feature = "tls-connections"))] { + Ok(ClientSocket::Plain(socket)) + } + + #[cfg(feature = "tls-connections")] { + let ssl = Ssl::new(&self.ssl.context).unwrap(); + let mut builder = SslStreamBuilder::new(ssl, socket); + builder.set_accept_state(); + match builder.handshake() { + Ok(stream) => + Ok(ClientSocket::SslStream(stream)), + Err(HandshakeError::WouldBlock(stream)) => + Ok(ClientSocket::SslHandshake(Some(stream))), + Err(e) => { + debug!("OpenSSL handshake failed: {}", e); + Err(Error::new(ErrorKind::Other, "Connection failure")) + } + } + } + } + + pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { + let (client_socket, addr) = self.listener.accept()?; + info!("Connected: {}", addr); + + let client_id = self.server.add_client(); + self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr); + self.flush_server_messages(); + + Ok(()) + } + + fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> { + let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) { + client.peer_addr + } else { + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) + }; + debug!("{}({}): {}", msg, addr, error); + self.client_error(poll, client_id) + } + + pub fn client_readable(&mut self, poll: &Poll, + client_id: ClientId) -> io::Result<()> { + let messages = + if let Some(ref mut client) = self.clients.get_mut(client_id) { + client.read() + } else { + warn!("invalid readable client: {}", client_id); + Ok((Vec::new(), NetworkClientState::Idle)) + }; + + match messages { + Ok((messages, state)) => { + for message in messages { + self.server.handle_msg(client_id, message); + } + match state { + NetworkClientState::NeedsRead => { + self.pending.insert((client_id, state)); + }, + NetworkClientState::Closed => + self.client_error(&poll, client_id)?, + _ => {} + }; + } + Err(e) => self.operation_failed( + poll, client_id, &e, + "Error while reading from client socket")? + } + + self.flush_server_messages(); + + if !self.server.removed_clients.is_empty() { + let ids: Vec<_> = self.server.removed_clients.drain(..).collect(); + for client_id in ids { + self.deregister_client(poll, client_id); + } + } + + Ok(()) + } + + pub fn client_writable(&mut self, poll: &Poll, + client_id: ClientId) -> io::Result<()> { + let result = + if let Some(ref mut client) = self.clients.get_mut(client_id) { + client.write() + } else { + warn!("invalid writable client: {}", client_id); + Ok(((), NetworkClientState::Idle)) + }; + + match result { + Ok(((), state)) if state == NetworkClientState::NeedsWrite => { + self.pending.insert((client_id, state)); + }, + Ok(_) => {} + Err(e) => self.operation_failed( + poll, client_id, &e, + "Error while writing to client socket")? + } + + Ok(()) + } + + pub fn client_error(&mut self, poll: &Poll, + client_id: ClientId) -> io::Result<()> { + self.deregister_client(poll, client_id); + self.server.client_lost(client_id); + + Ok(()) + } + + pub fn has_pending_operations(&self) -> bool { + !self.pending.is_empty() + } + + pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> { + if self.has_pending_operations() { + let mut cache = replace(&mut self.pending_cache, Vec::new()); + cache.extend(self.pending.drain()); + for (id, state) in cache.drain(..) { + match state { + NetworkClientState::NeedsRead => + self.client_readable(poll, id)?, + NetworkClientState::NeedsWrite => + self.client_writable(poll, id)?, + _ => {} + } + } + swap(&mut cache, &mut self.pending_cache); + } + Ok(()) + } +}