add back server TLS support
authoralfadur
Fri, 24 Mar 2023 03:26:08 +0300
changeset 15966 c5c53ebb2d91
parent 15965 cd3d16905e0e
child 15967 e514ceb5e7d6
add back server TLS support
rust/hedgewars-server/Cargo.toml
rust/hedgewars-server/src/server/network.rs
--- a/rust/hedgewars-server/Cargo.toml	Thu Mar 23 23:41:26 2023 +0300
+++ b/rust/hedgewars-server/Cargo.toml	Fri Mar 24 03:26:08 2023 +0300
@@ -5,7 +5,8 @@
 authors = [ "Andrey Korotaev <a.korotaev@hedgewars.org>" ]
 
 [features]
-official-server = ["mysql_async", "sha1"]
+tls-connections = ["tokio-native-tls"]
+official-server = ["mysql_async", "sha1", "tls-connections"]
 default = []
 
 [dependencies]
@@ -25,6 +26,7 @@
 sha1 = { version = "0.10.0", optional = true }
 slab = "0.4"
 tokio = { version = "1.16", features = ["full"]}
+tokio-native-tls = { version = "0.3", optional = true }
 
 hedgewars-network-protocol = { path = "../hedgewars-network-protocol" }
 
--- a/rust/hedgewars-server/src/server/network.rs	Thu Mar 23 23:41:26 2023 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Fri Mar 24 03:26:08 2023 +0300
@@ -1,16 +1,21 @@
 use bytes::{Buf, Bytes};
 use log::*;
 use slab::Slab;
+use std::io::Error;
+use std::pin::Pin;
+use std::task::{Context, Poll};
 use std::{
     iter::Iterator,
     net::{IpAddr, SocketAddr},
     time::Duration,
 };
 use tokio::{
-    io::AsyncReadExt,
+    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
     net::{TcpListener, TcpStream},
     sync::mpsc::{channel, Receiver, Sender},
 };
+#[cfg(feature = "tls-connections")]
+use tokio_native_tls::{TlsAcceptor, TlsStream};
 
 use crate::{
     core::{
@@ -25,7 +30,6 @@
 use hedgewars_network_protocol::{
     messages::HwServerMessage::Redirect, messages::*, parser::server_message,
 };
-use tokio::io::AsyncWriteExt;
 
 const PING_TIMEOUT: Duration = Duration::from_secs(15);
 
@@ -56,9 +60,65 @@
     }
 }
 
+enum ClientStream {
+    Tcp(TcpStream),
+    #[cfg(feature = "tls-connections")]
+    Tls(TlsStream<TcpStream>),
+}
+
+impl Unpin for ClientStream {}
+
+impl AsyncRead for ClientStream {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        use ClientStream::*;
+        match Pin::into_inner(self) {
+            Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
+            #[cfg(feature = "tls-connections")]
+            Tls(stream) => Pin::new(stream).poll_read(cx, buf),
+        }
+    }
+}
+
+impl AsyncWrite for ClientStream {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, Error>> {
+        use ClientStream::*;
+        match Pin::into_inner(self) {
+            Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
+            #[cfg(feature = "tls-connections")]
+            Tls(stream) => Pin::new(stream).poll_write(cx, buf),
+        }
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        use ClientStream::*;
+        match Pin::into_inner(self) {
+            Tcp(stream) => Pin::new(stream).poll_flush(cx),
+            #[cfg(feature = "tls-connections")]
+            Tls(stream) => Pin::new(stream).poll_flush(cx),
+        }
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        use ClientStream::*;
+        match Pin::into_inner(self) {
+            Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
+            #[cfg(feature = "tls-connections")]
+            Tls(stream) => Pin::new(stream).poll_shutdown(cx),
+        }
+    }
+}
+
 struct NetworkClient {
     id: ClientId,
-    socket: TcpStream,
+    stream: ClientStream,
     receiver: Receiver<Bytes>,
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
@@ -67,27 +127,27 @@
 impl NetworkClient {
     fn new(
         id: ClientId,
-        socket: TcpStream,
+        stream: ClientStream,
         peer_addr: SocketAddr,
         receiver: Receiver<Bytes>,
     ) -> Self {
         Self {
             id,
-            socket,
+            stream,
             peer_addr,
             receiver,
             decoder: ProtocolDecoder::new(PING_TIMEOUT),
         }
     }
 
-    async fn read(
-        socket: &mut TcpStream,
+    async fn read<T: AsyncRead + AsyncWrite + Unpin>(
+        stream: &mut T,
         decoder: &mut ProtocolDecoder,
     ) -> protocol::Result<HwProtocolMessage> {
-        let result = decoder.read_from(socket).await;
+        let result = decoder.read_from(stream).await;
         if matches!(result, Err(ProtocolError::Timeout)) {
-            if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
-                decoder.read_from(socket).await
+            if Self::write(stream, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
+                decoder.read_from(stream).await
             } else {
                 Err(ProtocolError::Eof)
             }
@@ -96,8 +156,8 @@
         }
     }
 
-    async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool {
-        !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0)
+    async fn write<T: AsyncWrite + Unpin>(stream: &mut T, mut data: Bytes) -> bool {
+        !data.has_remaining() || matches!(stream.write_buf(&mut data).await, Ok(n) if n > 0)
     }
 
     async fn run(mut self, sender: Sender<ClientUpdate>) {
@@ -111,7 +171,7 @@
             tokio::select! {
                 server_message = self.receiver.recv() => {
                     match server_message {
-                        Some(message) => if !Self::write(&mut self.socket, message).await {
+                        Some(message) => if !Self::write(&mut self.stream, message).await {
                             sender.send(Error("Connection reset by peer".to_string())).await;
                             break;
                         }
@@ -120,7 +180,7 @@
                         }
                     }
                 }
-                client_message = Self::read(&mut self.socket, &mut self.decoder) => {
+                client_message = Self::read(&mut self.stream, &mut self.decoder) => {
                      match client_message {
                         Ok(message) => {
                             if !sender.send(Message(message)).await {
@@ -130,7 +190,7 @@
                         Err(e) => {
                             sender.send(Error(format!("{}", e))).await;
                             if matches!(e, ProtocolError::Timeout) {
-                                Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
+                                Self::write(&mut self.stream, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
                             }
                             break;
                         }
@@ -141,8 +201,16 @@
     }
 }
 
+#[cfg(feature = "tls-connections")]
+struct TlsListener {
+    listener: TcpListener,
+    acceptor: TlsAcceptor,
+}
+
 pub struct NetworkLayer {
     listener: TcpListener,
+    #[cfg(feature = "tls-connections")]
+    tls: TlsListener,
     server_state: ServerState,
     clients: Slab<Sender<Bytes>>,
 }
@@ -151,36 +219,82 @@
     pub async fn run(&mut self) {
         let (update_tx, mut update_rx) = channel(128);
 
-        loop {
-            tokio::select! {
-                Ok((stream, addr)) = self.listener.accept() => {
-                    if let Some(client) = self.create_client(stream, addr).await {
+        async fn accept_plain_branch(
+            layer: &mut NetworkLayer,
+            value: (TcpStream, SocketAddr),
+            update_tx: Sender<ClientUpdate>,
+        ) {
+            let (stream, addr) = value;
+            if let Some(client) = layer.create_client(ClientStream::Tcp(stream), addr).await {
+                tokio::spawn(client.run(update_tx));
+            }
+        }
+
+        #[cfg(feature = "tls-connections")]
+        async fn accept_tls_branch(
+            layer: &mut NetworkLayer,
+            value: (TcpStream, SocketAddr),
+            update_tx: Sender<ClientUpdate>,
+        ) {
+            let (stream, addr) = value;
+            match layer.tls.acceptor.accept(stream).await {
+                Ok(stream) => {
+                    if let Some(client) = layer.create_client(ClientStream::Tls(stream), addr).await
+                    {
                         tokio::spawn(client.run(update_tx.clone()));
                     }
                 }
-                client_message = update_rx.recv(), if !self.clients.is_empty() => {
-                    use ClientUpdateData::*;
-                    match client_message {
-                        Some(ClientUpdate{ client_id, data: Message(message) } ) => {
-                            self.handle_message(client_id, message).await;
-                        }
-                        Some(ClientUpdate{ client_id, data: Error(e) } ) => {
-                            let mut response = handlers::Response::new(client_id);
-                            info!("Client {} error: {:?}", client_id, e);
-                            response.remove_client(client_id);
-                            handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
-                            self.handle_response(response).await;
-                        }
-                        None => unreachable!()
-                    }
+                Err(e) => {
+                    warn!("Unable to establish TLS connection: {}", e);
+                }
+            }
+        }
+
+        async fn client_message_branch(
+            layer: &mut NetworkLayer,
+            client_message: Option<ClientUpdate>,
+        ) {
+            use ClientUpdateData::*;
+            match client_message {
+                Some(ClientUpdate {
+                    client_id,
+                    data: Message(message),
+                }) => {
+                    layer.handle_message(client_id, message).await;
                 }
+                Some(ClientUpdate {
+                    client_id,
+                    data: Error(e),
+                }) => {
+                    let mut response = handlers::Response::new(client_id);
+                    info!("Client {} error: {:?}", client_id, e);
+                    response.remove_client(client_id);
+                    handlers::handle_client_loss(&mut layer.server_state, client_id, &mut response);
+                    layer.handle_response(response).await;
+                }
+                None => unreachable!(),
+            }
+        }
+
+        loop {
+            #[cfg(not(feature = "tls-connections"))]
+            tokio::select! {
+                Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
+                client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
+            }
+
+            #[cfg(feature = "tls-connections")]
+            tokio::select! {
+                Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
+                Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, update_tx.clone()).await,
+                client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
             }
         }
     }
 
     async fn create_client(
         &mut self,
-        stream: TcpStream,
+        stream: ClientStream,
         addr: SocketAddr,
     ) -> Option<NetworkClient> {
         let entry = self.clients.vacant_entry();
@@ -263,6 +377,10 @@
 
 pub struct NetworkLayerBuilder {
     listener: Option<TcpListener>,
+    #[cfg(feature = "tls-connections")]
+    tls_listener: Option<TcpListener>,
+    #[cfg(feature = "tls-connections")]
+    tls_acceptor: Option<TlsAcceptor>,
     clients_capacity: usize,
     rooms_capacity: usize,
 }
@@ -273,6 +391,10 @@
             clients_capacity: 1024,
             rooms_capacity: 512,
             listener: None,
+            #[cfg(feature = "tls-connections")]
+            tls_listener: None,
+            #[cfg(feature = "tls-connections")]
+            tls_acceptor: None,
         }
     }
 }
@@ -285,6 +407,22 @@
         }
     }
 
+    #[cfg(feature = "tls-connections")]
+    pub fn with_tls_acceptor(self, listener: TlsAcceptor) -> Self {
+        Self {
+            tls_acceptor: Option::from(listener),
+            ..self
+        }
+    }
+
+    #[cfg(feature = "tls-connections")]
+    pub fn with_tls_listener(self, listener: TlsAcceptor) -> Self {
+        Self {
+            tls_acceptor: Option::from(listener),
+            ..self
+        }
+    }
+
     pub fn build(self) -> NetworkLayer {
         let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
 
@@ -292,6 +430,11 @@
 
         NetworkLayer {
             listener: self.listener.expect("No listener provided"),
+            #[cfg(feature = "tls-connections")]
+            tls: TlsListener {
+                listener: self.tls_listener.expect("No TLS listener provided"),
+                acceptor: self.tls_acceptor.expect("No TLS acceptor provided"),
+            },
             server_state,
             clients,
         }