--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/rust/hedgewars-server/src/server/network.rs Mon Dec 10 22:44:46 2018 +0100
@@ -0,0 +1,444 @@
+extern crate slab;
+
+use std::{
+ io, io::{Error, ErrorKind, Read, Write},
+ net::{SocketAddr, IpAddr, Ipv4Addr},
+ collections::HashSet,
+ mem::{swap, replace}
+};
+
+use mio::{
+ net::{TcpStream, TcpListener},
+ Poll, PollOpt, Ready, Token
+};
+use netbuf;
+use slab::Slab;
+use log::*;
+
+use crate::{
+ utils,
+ protocol::{ProtocolDecoder, messages::*}
+};
+use super::{
+ io::FileServerIO,
+ core::{HWServer},
+ coretypes::ClientId
+};
+#[cfg(feature = "tls-connections")]
+use openssl::{
+ ssl::{
+ SslMethod, SslContext, Ssl, SslContextBuilder,
+ SslVerifyMode, SslFiletype, SslOptions,
+ SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
+ },
+ error::ErrorStack
+};
+
+const MAX_BYTES_PER_READ: usize = 2048;
+
+#[derive(Hash, Eq, PartialEq, Copy, Clone)]
+pub enum NetworkClientState {
+ Idle,
+ NeedsWrite,
+ NeedsRead,
+ Closed,
+}
+
+type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
+
+#[cfg(not(feature = "tls-connections"))]
+pub enum ClientSocket {
+ Plain(TcpStream)
+}
+
+#[cfg(feature = "tls-connections")]
+pub enum ClientSocket {
+ SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
+ SslStream(SslStream<TcpStream>)
+}
+
+impl ClientSocket {
+ fn inner(&self) -> &TcpStream {
+ #[cfg(not(feature = "tls-connections"))]
+ match self {
+ ClientSocket::Plain(stream) => stream,
+ }
+
+ #[cfg(feature = "tls-connections")]
+ match self {
+ ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+ ClientSocket::SslHandshake(None) => unreachable!(),
+ ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
+ }
+ }
+}
+
+pub struct NetworkClient {
+ id: ClientId,
+ socket: ClientSocket,
+ peer_addr: SocketAddr,
+ decoder: ProtocolDecoder,
+ buf_out: netbuf::Buf
+}
+
+impl NetworkClient {
+ pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
+ NetworkClient {
+ id, socket, peer_addr,
+ decoder: ProtocolDecoder::new(),
+ buf_out: netbuf::Buf::new()
+ }
+ }
+
+ #[cfg(feature = "tls-connections")]
+ fn handshake_impl(&mut self, handshake: MidHandshakeSslStream<TcpStream>) -> io::Result<NetworkClientState> {
+ match handshake.handshake() {
+ Ok(stream) => {
+ self.socket = ClientSocket::SslStream(stream);
+ debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
+ Ok(NetworkClientState::Idle)
+ }
+ Err(HandshakeError::WouldBlock(new_handshake)) => {
+ self.socket = ClientSocket::SslHandshake(Some(new_handshake));
+ Ok(NetworkClientState::Idle)
+ }
+ Err(HandshakeError::Failure(new_handshake)) => {
+ self.socket = ClientSocket::SslHandshake(Some(new_handshake));
+ debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
+ Err(Error::new(ErrorKind::Other, "Connection failure"))
+ }
+ Err(HandshakeError::SetupFailure(_)) => unreachable!()
+ }
+ }
+
+ fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
+ id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
+ let mut bytes_read = 0;
+ let result = loop {
+ match decoder.read_from(source) {
+ Ok(bytes) => {
+ debug!("Client {}: read {} bytes", id, bytes);
+ bytes_read += bytes;
+ if bytes == 0 {
+ let result = if bytes_read == 0 {
+ info!("EOF for client {} ({})", id, addr);
+ (Vec::new(), NetworkClientState::Closed)
+ } else {
+ (decoder.extract_messages(), NetworkClientState::NeedsRead)
+ };
+ break Ok(result);
+ }
+ else if bytes_read >= MAX_BYTES_PER_READ {
+ break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
+ }
+ }
+ Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
+ let messages = if bytes_read == 0 {
+ Vec::new()
+ } else {
+ decoder.extract_messages()
+ };
+ break Ok((messages, NetworkClientState::Idle));
+ }
+ Err(error) =>
+ break Err(error)
+ }
+ };
+ decoder.sweep();
+ result
+ }
+
+ pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
+ #[cfg(not(feature = "tls-connections"))]
+ match self.socket {
+ ClientSocket::Plain(ref mut stream) =>
+ NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
+ }
+
+ #[cfg(feature = "tls-connections")]
+ match self.socket {
+ ClientSocket::SslHandshake(ref mut handshake_opt) => {
+ let handshake = std::mem::replace(handshake_opt, None).unwrap();
+ Ok((Vec::new(), self.handshake_impl(handshake)?))
+ },
+ ClientSocket::SslStream(ref mut stream) =>
+ NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+ }
+ }
+
+ fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
+ let result = loop {
+ match buf_out.write_to(destination) {
+ Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
+ break Ok(((), NetworkClientState::Idle)),
+ Ok(_) => (),
+ Err(ref error) if error.kind() == ErrorKind::Interrupted
+ || error.kind() == ErrorKind::WouldBlock => {
+ break Ok(((), NetworkClientState::NeedsWrite));
+ },
+ Err(error) =>
+ break Err(error)
+ }
+ };
+ result
+ }
+
+ pub fn write(&mut self) -> NetworkResult<()> {
+ let result = {
+ #[cfg(not(feature = "tls-connections"))]
+ match self.socket {
+ ClientSocket::Plain(ref mut stream) =>
+ NetworkClient::write_impl(&mut self.buf_out, stream)
+ }
+
+ #[cfg(feature = "tls-connections")] {
+ match self.socket {
+ ClientSocket::SslHandshake(ref mut handshake_opt) => {
+ let handshake = std::mem::replace(handshake_opt, None).unwrap();
+ Ok(((), self.handshake_impl(handshake)?))
+ }
+ ClientSocket::SslStream(ref mut stream) =>
+ NetworkClient::write_impl(&mut self.buf_out, stream)
+ }
+ }
+ };
+
+ self.socket.inner().flush()?;
+ result
+ }
+
+ pub fn send_raw_msg(&mut self, msg: &[u8]) {
+ self.buf_out.write_all(msg).unwrap();
+ }
+
+ pub fn send_string(&mut self, msg: &str) {
+ self.send_raw_msg(&msg.as_bytes());
+ }
+
+ pub fn send_msg(&mut self, msg: &HWServerMessage) {
+ self.send_string(&msg.to_raw_protocol());
+ }
+}
+
+#[cfg(feature = "tls-connections")]
+struct ServerSsl {
+ context: SslContext
+}
+
+pub struct NetworkLayer {
+ listener: TcpListener,
+ server: HWServer,
+ clients: Slab<NetworkClient>,
+ pending: HashSet<(ClientId, NetworkClientState)>,
+ pending_cache: Vec<(ClientId, NetworkClientState)>,
+ #[cfg(feature = "tls-connections")]
+ ssl: ServerSsl
+}
+
+impl NetworkLayer {
+ pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
+ let server = HWServer::new(clients_limit, rooms_limit, Box::new(FileServerIO::new()));
+ let clients = Slab::with_capacity(clients_limit);
+ let pending = HashSet::with_capacity(2 * clients_limit);
+ let pending_cache = Vec::with_capacity(2 * clients_limit);
+
+ NetworkLayer {
+ listener, server, clients, pending, pending_cache,
+ #[cfg(feature = "tls-connections")]
+ ssl: NetworkLayer::create_ssl_context()
+ }
+ }
+
+ #[cfg(feature = "tls-connections")]
+ fn create_ssl_context() -> ServerSsl {
+ let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
+ builder.set_verify(SslVerifyMode::NONE);
+ builder.set_read_ahead(true);
+ builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
+ builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
+ builder.set_options(SslOptions::NO_COMPRESSION);
+ builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
+ ServerSsl { context: builder.build() }
+ }
+
+ pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
+ poll.register(&self.listener, utils::SERVER, Ready::readable(),
+ PollOpt::edge())
+ }
+
+ fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
+ let mut client_exists = false;
+ if let Some(ref client) = self.clients.get(id) {
+ poll.deregister(client.socket.inner())
+ .expect("could not deregister socket");
+ info!("client {} ({}) removed", client.id, client.peer_addr);
+ client_exists = true;
+ }
+ if client_exists {
+ self.clients.remove(id);
+ }
+ }
+
+ fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
+ poll.register(client_socket.inner(), Token(id),
+ Ready::readable() | Ready::writable(),
+ PollOpt::edge())
+ .expect("could not register socket with event loop");
+
+ let entry = self.clients.vacant_entry();
+ let client = NetworkClient::new(id, client_socket, addr);
+ info!("client {} ({}) added", client.id, client.peer_addr);
+ entry.insert(client);
+ }
+
+ fn flush_server_messages(&mut self) {
+ debug!("{} pending server messages", self.server.output.len());
+ for (clients, message) in self.server.output.drain(..) {
+ debug!("Message {:?} to {:?}", message, clients);
+ let msg_string = message.to_raw_protocol();
+ for client_id in clients {
+ if let Some(client) = self.clients.get_mut(client_id) {
+ client.send_string(&msg_string);
+ self.pending.insert((client_id, NetworkClientState::NeedsWrite));
+ }
+ }
+ }
+ }
+
+ fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
+ #[cfg(not(feature = "tls-connections"))] {
+ Ok(ClientSocket::Plain(socket))
+ }
+
+ #[cfg(feature = "tls-connections")] {
+ let ssl = Ssl::new(&self.ssl.context).unwrap();
+ let mut builder = SslStreamBuilder::new(ssl, socket);
+ builder.set_accept_state();
+ match builder.handshake() {
+ Ok(stream) =>
+ Ok(ClientSocket::SslStream(stream)),
+ Err(HandshakeError::WouldBlock(stream)) =>
+ Ok(ClientSocket::SslHandshake(Some(stream))),
+ Err(e) => {
+ debug!("OpenSSL handshake failed: {}", e);
+ Err(Error::new(ErrorKind::Other, "Connection failure"))
+ }
+ }
+ }
+ }
+
+ pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
+ let (client_socket, addr) = self.listener.accept()?;
+ info!("Connected: {}", addr);
+
+ let client_id = self.server.add_client();
+ self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
+ self.flush_server_messages();
+
+ Ok(())
+ }
+
+ fn operation_failed(&mut self, poll: &Poll, client_id: ClientId, error: &Error, msg: &str) -> io::Result<()> {
+ let addr = if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ client.peer_addr
+ } else {
+ SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
+ };
+ debug!("{}({}): {}", msg, addr, error);
+ self.client_error(poll, client_id)
+ }
+
+ pub fn client_readable(&mut self, poll: &Poll,
+ client_id: ClientId) -> io::Result<()> {
+ let messages =
+ if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ client.read()
+ } else {
+ warn!("invalid readable client: {}", client_id);
+ Ok((Vec::new(), NetworkClientState::Idle))
+ };
+
+ match messages {
+ Ok((messages, state)) => {
+ for message in messages {
+ self.server.handle_msg(client_id, message);
+ }
+ match state {
+ NetworkClientState::NeedsRead => {
+ self.pending.insert((client_id, state));
+ },
+ NetworkClientState::Closed =>
+ self.client_error(&poll, client_id)?,
+ _ => {}
+ };
+ }
+ Err(e) => self.operation_failed(
+ poll, client_id, &e,
+ "Error while reading from client socket")?
+ }
+
+ self.flush_server_messages();
+
+ if !self.server.removed_clients.is_empty() {
+ let ids: Vec<_> = self.server.removed_clients.drain(..).collect();
+ for client_id in ids {
+ self.deregister_client(poll, client_id);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn client_writable(&mut self, poll: &Poll,
+ client_id: ClientId) -> io::Result<()> {
+ let result =
+ if let Some(ref mut client) = self.clients.get_mut(client_id) {
+ client.write()
+ } else {
+ warn!("invalid writable client: {}", client_id);
+ Ok(((), NetworkClientState::Idle))
+ };
+
+ match result {
+ Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
+ self.pending.insert((client_id, state));
+ },
+ Ok(_) => {}
+ Err(e) => self.operation_failed(
+ poll, client_id, &e,
+ "Error while writing to client socket")?
+ }
+
+ Ok(())
+ }
+
+ pub fn client_error(&mut self, poll: &Poll,
+ client_id: ClientId) -> io::Result<()> {
+ self.deregister_client(poll, client_id);
+ self.server.client_lost(client_id);
+
+ Ok(())
+ }
+
+ pub fn has_pending_operations(&self) -> bool {
+ !self.pending.is_empty()
+ }
+
+ pub fn on_idle(&mut self, poll: &Poll) -> io::Result<()> {
+ if self.has_pending_operations() {
+ let mut cache = replace(&mut self.pending_cache, Vec::new());
+ cache.extend(self.pending.drain());
+ for (id, state) in cache.drain(..) {
+ match state {
+ NetworkClientState::NeedsRead =>
+ self.client_readable(poll, id)?,
+ NetworkClientState::NeedsWrite =>
+ self.client_writable(poll, id)?,
+ _ => {}
+ }
+ }
+ swap(&mut cache, &mut self.pending_cache);
+ }
+ Ok(())
+ }
+}