Add TLS support
authoralfadur
Thu, 06 Sep 2018 23:12:32 +0300
changeset 13804 c8fd12db6215
parent 13803 4664da990556
child 13805 0118b7412570
Add TLS support
gameServer2/Cargo.toml
gameServer2/src/main.rs
gameServer2/src/server/network.rs
--- a/gameServer2/Cargo.toml	Wed Sep 05 19:22:29 2018 +0300
+++ b/gameServer2/Cargo.toml	Thu Sep 06 23:12:32 2018 +0300
@@ -8,6 +8,8 @@
 
 [features]
 official-server = []
+tls-connections = ["openssl"]
+default = []
 
 [dependencies]
 rand = "0.5"
@@ -22,6 +24,7 @@
 serde = "1.0"
 serde_yaml = "0.8"
 serde_derive = "1.0"
+openssl = { version = "0.10", optional = true }
 
 [dev-dependencies]
 proptest = "0.8"
\ No newline at end of file
--- a/gameServer2/src/main.rs	Wed Sep 05 19:22:29 2018 +0300
+++ b/gameServer2/src/main.rs	Thu Sep 06 23:12:32 2018 +0300
@@ -17,6 +17,8 @@
 #[macro_use] extern crate bitflags;
 extern crate serde;
 extern crate serde_yaml;
+#[cfg(feature = "tls-connections")]
+extern crate openssl;
 #[macro_use] extern crate serde_derive;
 
 //use std::io::*;
--- a/gameServer2/src/server/network.rs	Wed Sep 05 19:22:29 2018 +0300
+++ b/gameServer2/src/server/network.rs	Thu Sep 06 23:12:32 2018 +0300
@@ -1,7 +1,7 @@
 extern crate slab;
 
 use std::{
-    io, io::{Error, ErrorKind, Write},
+    io, io::{Error, ErrorKind, Read, Write},
     net::{SocketAddr, IpAddr, Ipv4Addr},
     collections::HashSet,
     mem::{swap, replace}
@@ -22,6 +22,15 @@
     server::{HWServer},
     coretypes::ClientId
 };
+#[cfg(feature = "tls-connections")]
+use openssl::{
+    ssl::{
+        SslMethod, SslContext, Ssl, SslContextBuilder,
+        SslVerifyMode, SslFiletype, SslOptions,
+        SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
+    },
+    error::ErrorStack
+};
 
 const MAX_BYTES_PER_READ: usize = 2048;
 
@@ -35,16 +44,43 @@
 
 type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
 
+#[cfg(not(feature = "tls-connections"))]
+pub enum ClientSocket {
+    Plain(TcpStream)
+}
+
+#[cfg(feature = "tls-connections")]
+pub enum ClientSocket {
+    SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
+    SslStream(SslStream<TcpStream>)
+}
+
+impl ClientSocket {
+    fn inner(&self) -> &TcpStream {
+        #[cfg(not(feature = "tls-connections"))]
+        match self {
+            ClientSocket::Plain(stream) => stream,
+        }
+
+        #[cfg(feature = "tls-connections")]
+        match self {
+            ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+            ClientSocket::SslHandshake(None) => unreachable!(),
+            ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
+        }
+    }
+}
+
 pub struct NetworkClient {
     id: ClientId,
-    socket: TcpStream,
+    socket: ClientSocket,
     peer_addr: SocketAddr,
     decoder: ProtocolDecoder,
     buf_out: netbuf::Buf
 }
 
 impl NetworkClient {
-    pub fn new(id: ClientId, socket: TcpStream, peer_addr: SocketAddr) -> NetworkClient {
+    pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
         NetworkClient {
             id, socket, peer_addr,
             decoder: ProtocolDecoder::new(),
@@ -52,31 +88,32 @@
         }
     }
 
-    pub fn read_messages(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
+    fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
+                          id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
         let mut bytes_read = 0;
         let result = loop {
-            match self.decoder.read_from(&mut self.socket) {
+            match decoder.read_from(source) {
                 Ok(bytes) => {
-                    debug!("Client {}: read {} bytes", self.id, bytes);
+                    debug!("Client {}: read {} bytes", id, bytes);
                     bytes_read += bytes;
                     if bytes == 0 {
                         let result = if bytes_read == 0 {
-                            info!("EOF for client {} ({})", self.id, self.peer_addr);
+                            info!("EOF for client {} ({})", id, addr);
                             (Vec::new(), NetworkClientState::Closed)
                         } else {
-                            (self.decoder.extract_messages(), NetworkClientState::NeedsRead)
+                            (decoder.extract_messages(), NetworkClientState::NeedsRead)
                         };
                         break Ok(result);
                     }
                     else if bytes_read >= MAX_BYTES_PER_READ {
-                        break Ok((self.decoder.extract_messages(), NetworkClientState::NeedsRead))
+                        break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
                     }
                 }
                 Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
                     let messages =  if bytes_read == 0 {
                         Vec::new()
                     } else {
-                        self.decoder.extract_messages()
+                        decoder.extract_messages()
                     };
                     break Ok((messages, NetworkClientState::Idle));
                 }
@@ -84,14 +121,48 @@
                     break Err(error)
             }
         };
-        self.decoder.sweep();
+        decoder.sweep();
         result
     }
 
-    pub fn flush(&mut self) -> NetworkResult<()> {
+    pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
+        #[cfg(not(feature = "tls-connections"))]
+        match self.socket {
+            ClientSocket::Plain(ref mut stream) =>
+                NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
+        }
+
+        #[cfg(feature = "tls-connections")]
+        match self.socket {
+            ClientSocket::SslHandshake(ref mut handshake_opt) => {
+                let mut handshake = std::mem::replace(handshake_opt, None).unwrap();
+
+                match handshake.handshake() {
+                    Ok(stream) => {
+                        debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
+                        self.socket = ClientSocket::SslStream(stream);
+
+                        Ok((Vec::new(), NetworkClientState::Idle))
+                    }
+                    Err(HandshakeError::WouldBlock(new_handshake)) => {
+                        *handshake_opt = Some(new_handshake);
+                        Ok((Vec::new(), NetworkClientState::Idle))
+                    }
+                    Err(e) => {
+                        debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
+                        Err(Error::new(ErrorKind::Other, "Connection failure"))
+                    }
+                }
+            },
+            ClientSocket::SslStream(ref mut stream) =>
+                NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+        }
+    }
+
+    fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
         let result = loop {
-            match self.buf_out.write_to(&mut self.socket) {
-                Ok(bytes) if self.buf_out.is_empty() || bytes == 0 =>
+            match buf_out.write_to(destination) {
+                Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
                     break Ok(((), NetworkClientState::Idle)),
                 Ok(_) => (),
                 Err(ref error) if error.kind() == ErrorKind::Interrupted
@@ -102,7 +173,28 @@
                     break Err(error)
             }
         };
-        self.socket.flush()?;
+        result
+    }
+
+    pub fn write(&mut self) -> NetworkResult<()> {
+        let result = {
+            #[cfg(not(feature = "tls-connections"))]
+            match self.socket {
+                ClientSocket::Plain(ref mut stream) =>
+                    NetworkClient::write_impl(&mut self.buf_out, stream)
+            }
+
+            #[cfg(feature = "tls-connections")] {
+                match self.socket {
+                    ClientSocket::SslHandshake(_) =>
+                        Ok(((), NetworkClientState::Idle)),
+                    ClientSocket::SslStream(ref mut stream) =>
+                        NetworkClient::write_impl(&mut self.buf_out, stream)
+                }
+            }
+        };
+
+        self.socket.inner().flush()?;
         result
     }
 
@@ -119,12 +211,19 @@
     }
 }
 
+#[cfg(feature = "tls-connections")]
+struct ServerSsl {
+    context: SslContext
+}
+
 pub struct NetworkLayer {
     listener: TcpListener,
     server: HWServer,
     clients: Slab<NetworkClient>,
     pending: HashSet<(ClientId, NetworkClientState)>,
-    pending_cache: Vec<(ClientId, NetworkClientState)>
+    pending_cache: Vec<(ClientId, NetworkClientState)>,
+    #[cfg(feature = "tls-connections")]
+    ssl: ServerSsl
 }
 
 impl NetworkLayer {
@@ -133,7 +232,24 @@
         let clients = Slab::with_capacity(clients_limit);
         let pending = HashSet::with_capacity(2 * clients_limit);
         let pending_cache = Vec::with_capacity(2 * clients_limit);
-        NetworkLayer {listener, server, clients, pending, pending_cache}
+
+        NetworkLayer {
+            listener, server, clients, pending, pending_cache,
+            #[cfg(feature = "tls-connections")]
+            ssl: NetworkLayer::create_ssl_context()
+        }
+    }
+
+    #[cfg(feature = "tls-connections")]
+    fn create_ssl_context() -> ServerSsl {
+        let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
+        builder.set_verify(SslVerifyMode::NONE);
+        builder.set_read_ahead(true);
+        builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
+        builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
+        builder.set_options(SslOptions::NO_COMPRESSION);
+        builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
+        ServerSsl { context: builder.build() }
     }
 
     pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
@@ -144,7 +260,7 @@
     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
         let mut client_exists = false;
         if let Some(ref client) = self.clients.get(id) {
-            poll.deregister(&client.socket)
+            poll.deregister(client.socket.inner())
                 .expect("could not deregister socket");
             info!("client {} ({}) removed", client.id, client.peer_addr);
             client_exists = true;
@@ -154,8 +270,8 @@
         }
     }
 
-    fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: TcpStream, addr: SocketAddr) {
-        poll.register(&client_socket, Token(id),
+    fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
+        poll.register(client_socket.inner(), Token(id),
                       Ready::readable() | Ready::writable(),
                       PollOpt::edge())
             .expect("could not register socket with event loop");
@@ -180,12 +296,34 @@
         }
     }
 
+    fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
+        #[cfg(not(feature = "tls-connections"))] {
+            Ok(ClientSocket::Plain(socket))
+        }
+
+        #[cfg(feature = "tls-connections")] {
+            let ssl = Ssl::new(&self.ssl.context).unwrap();
+            let mut builder = SslStreamBuilder::new(ssl, socket);
+            builder.set_accept_state();
+            match builder.handshake() {
+                Ok(stream) =>
+                    Ok(ClientSocket::SslStream(stream)),
+                Err(HandshakeError::WouldBlock(stream)) =>
+                    Ok(ClientSocket::SslHandshake(Some(stream))),
+                Err(e) => {
+                    debug!("OpenSSL handshake failed: {}", e);
+                    Err(Error::new(ErrorKind::Other, "Connection failure"))
+                }
+            }
+        }
+    }
+
     pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
         let (client_socket, addr) = self.listener.accept()?;
         info!("Connected: {}", addr);
 
         let client_id = self.server.add_client();
-        self.register_client(poll, client_id, client_socket, addr);
+        self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
         self.flush_server_messages();
 
         Ok(())
@@ -205,7 +343,7 @@
                            client_id: ClientId) -> io::Result<()> {
         let messages =
             if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                client.read_messages()
+                client.read()
             } else {
                 warn!("invalid readable client: {}", client_id);
                 Ok((Vec::new(), NetworkClientState::Idle))
@@ -246,7 +384,7 @@
                            client_id: ClientId) -> io::Result<()> {
         let result =
             if let Some(ref mut client) = self.clients.get_mut(client_id) {
-                client.flush()
+                client.write()
             } else {
                 warn!("invalid writable client: {}", client_id);
                 Ok(((), NetworkClientState::Idle))