--- a/gameServer2/src/server/network.rs Wed Sep 05 19:22:29 2018 +0300
+++ b/gameServer2/src/server/network.rs Thu Sep 06 23:12:32 2018 +0300
@@ -1,7 +1,7 @@
extern crate slab;
use std::{
- io, io::{Error, ErrorKind, Write},
+ io, io::{Error, ErrorKind, Read, Write},
net::{SocketAddr, IpAddr, Ipv4Addr},
collections::HashSet,
mem::{swap, replace}
@@ -22,6 +22,15 @@
server::{HWServer},
coretypes::ClientId
};
+#[cfg(feature = "tls-connections")]
+use openssl::{
+ ssl::{
+ SslMethod, SslContext, Ssl, SslContextBuilder,
+ SslVerifyMode, SslFiletype, SslOptions,
+ SslStreamBuilder, HandshakeError, MidHandshakeSslStream, SslStream
+ },
+ error::ErrorStack
+};
const MAX_BYTES_PER_READ: usize = 2048;
@@ -35,16 +44,43 @@
type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
+#[cfg(not(feature = "tls-connections"))]
+pub enum ClientSocket {
+ Plain(TcpStream)
+}
+
+#[cfg(feature = "tls-connections")]
+pub enum ClientSocket {
+ SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
+ SslStream(SslStream<TcpStream>)
+}
+
+impl ClientSocket {
+ fn inner(&self) -> &TcpStream {
+ #[cfg(not(feature = "tls-connections"))]
+ match self {
+ ClientSocket::Plain(stream) => stream,
+ }
+
+ #[cfg(feature = "tls-connections")]
+ match self {
+ ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+ ClientSocket::SslHandshake(None) => unreachable!(),
+ ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref()
+ }
+ }
+}
+
pub struct NetworkClient {
id: ClientId,
- socket: TcpStream,
+ socket: ClientSocket,
peer_addr: SocketAddr,
decoder: ProtocolDecoder,
buf_out: netbuf::Buf
}
impl NetworkClient {
- pub fn new(id: ClientId, socket: TcpStream, peer_addr: SocketAddr) -> NetworkClient {
+ pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient {
NetworkClient {
id, socket, peer_addr,
decoder: ProtocolDecoder::new(),
@@ -52,31 +88,32 @@
}
}
- pub fn read_messages(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
+ fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R,
+ id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> {
let mut bytes_read = 0;
let result = loop {
- match self.decoder.read_from(&mut self.socket) {
+ match decoder.read_from(source) {
Ok(bytes) => {
- debug!("Client {}: read {} bytes", self.id, bytes);
+ debug!("Client {}: read {} bytes", id, bytes);
bytes_read += bytes;
if bytes == 0 {
let result = if bytes_read == 0 {
- info!("EOF for client {} ({})", self.id, self.peer_addr);
+ info!("EOF for client {} ({})", id, addr);
(Vec::new(), NetworkClientState::Closed)
} else {
- (self.decoder.extract_messages(), NetworkClientState::NeedsRead)
+ (decoder.extract_messages(), NetworkClientState::NeedsRead)
};
break Ok(result);
}
else if bytes_read >= MAX_BYTES_PER_READ {
- break Ok((self.decoder.extract_messages(), NetworkClientState::NeedsRead))
+ break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead))
}
}
Err(ref error) if error.kind() == ErrorKind::WouldBlock => {
let messages = if bytes_read == 0 {
Vec::new()
} else {
- self.decoder.extract_messages()
+ decoder.extract_messages()
};
break Ok((messages, NetworkClientState::Idle));
}
@@ -84,14 +121,48 @@
break Err(error)
}
};
- self.decoder.sweep();
+ decoder.sweep();
result
}
- pub fn flush(&mut self) -> NetworkResult<()> {
+ pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
+ #[cfg(not(feature = "tls-connections"))]
+ match self.socket {
+ ClientSocket::Plain(ref mut stream) =>
+ NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr),
+ }
+
+ #[cfg(feature = "tls-connections")]
+ match self.socket {
+ ClientSocket::SslHandshake(ref mut handshake_opt) => {
+ let mut handshake = std::mem::replace(handshake_opt, None).unwrap();
+
+ match handshake.handshake() {
+ Ok(stream) => {
+ debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr);
+ self.socket = ClientSocket::SslStream(stream);
+
+ Ok((Vec::new(), NetworkClientState::Idle))
+ }
+ Err(HandshakeError::WouldBlock(new_handshake)) => {
+ *handshake_opt = Some(new_handshake);
+ Ok((Vec::new(), NetworkClientState::Idle))
+ }
+ Err(e) => {
+ debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr);
+ Err(Error::new(ErrorKind::Other, "Connection failure"))
+ }
+ }
+ },
+ ClientSocket::SslStream(ref mut stream) =>
+ NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
+ }
+ }
+
+ fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
let result = loop {
- match self.buf_out.write_to(&mut self.socket) {
- Ok(bytes) if self.buf_out.is_empty() || bytes == 0 =>
+ match buf_out.write_to(destination) {
+ Ok(bytes) if buf_out.is_empty() || bytes == 0 =>
break Ok(((), NetworkClientState::Idle)),
Ok(_) => (),
Err(ref error) if error.kind() == ErrorKind::Interrupted
@@ -102,7 +173,28 @@
break Err(error)
}
};
- self.socket.flush()?;
+ result
+ }
+
+ pub fn write(&mut self) -> NetworkResult<()> {
+ let result = {
+ #[cfg(not(feature = "tls-connections"))]
+ match self.socket {
+ ClientSocket::Plain(ref mut stream) =>
+ NetworkClient::write_impl(&mut self.buf_out, stream)
+ }
+
+ #[cfg(feature = "tls-connections")] {
+ match self.socket {
+ ClientSocket::SslHandshake(_) =>
+ Ok(((), NetworkClientState::Idle)),
+ ClientSocket::SslStream(ref mut stream) =>
+ NetworkClient::write_impl(&mut self.buf_out, stream)
+ }
+ }
+ };
+
+ self.socket.inner().flush()?;
result
}
@@ -119,12 +211,19 @@
}
}
+#[cfg(feature = "tls-connections")]
+struct ServerSsl {
+ context: SslContext
+}
+
pub struct NetworkLayer {
listener: TcpListener,
server: HWServer,
clients: Slab<NetworkClient>,
pending: HashSet<(ClientId, NetworkClientState)>,
- pending_cache: Vec<(ClientId, NetworkClientState)>
+ pending_cache: Vec<(ClientId, NetworkClientState)>,
+ #[cfg(feature = "tls-connections")]
+ ssl: ServerSsl
}
impl NetworkLayer {
@@ -133,7 +232,24 @@
let clients = Slab::with_capacity(clients_limit);
let pending = HashSet::with_capacity(2 * clients_limit);
let pending_cache = Vec::with_capacity(2 * clients_limit);
- NetworkLayer {listener, server, clients, pending, pending_cache}
+
+ NetworkLayer {
+ listener, server, clients, pending, pending_cache,
+ #[cfg(feature = "tls-connections")]
+ ssl: NetworkLayer::create_ssl_context()
+ }
+ }
+
+ #[cfg(feature = "tls-connections")]
+ fn create_ssl_context() -> ServerSsl {
+ let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
+ builder.set_verify(SslVerifyMode::NONE);
+ builder.set_read_ahead(true);
+ builder.set_certificate_file("ssl/cert.pem", SslFiletype::PEM).unwrap();
+ builder.set_private_key_file("ssl/key.pem", SslFiletype::PEM).unwrap();
+ builder.set_options(SslOptions::NO_COMPRESSION);
+ builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
+ ServerSsl { context: builder.build() }
}
pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
@@ -144,7 +260,7 @@
fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
let mut client_exists = false;
if let Some(ref client) = self.clients.get(id) {
- poll.deregister(&client.socket)
+ poll.deregister(client.socket.inner())
.expect("could not deregister socket");
info!("client {} ({}) removed", client.id, client.peer_addr);
client_exists = true;
@@ -154,8 +270,8 @@
}
}
- fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: TcpStream, addr: SocketAddr) {
- poll.register(&client_socket, Token(id),
+ fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) {
+ poll.register(client_socket.inner(), Token(id),
Ready::readable() | Ready::writable(),
PollOpt::edge())
.expect("could not register socket with event loop");
@@ -180,12 +296,34 @@
}
}
+ fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
+ #[cfg(not(feature = "tls-connections"))] {
+ Ok(ClientSocket::Plain(socket))
+ }
+
+ #[cfg(feature = "tls-connections")] {
+ let ssl = Ssl::new(&self.ssl.context).unwrap();
+ let mut builder = SslStreamBuilder::new(ssl, socket);
+ builder.set_accept_state();
+ match builder.handshake() {
+ Ok(stream) =>
+ Ok(ClientSocket::SslStream(stream)),
+ Err(HandshakeError::WouldBlock(stream)) =>
+ Ok(ClientSocket::SslHandshake(Some(stream))),
+ Err(e) => {
+ debug!("OpenSSL handshake failed: {}", e);
+ Err(Error::new(ErrorKind::Other, "Connection failure"))
+ }
+ }
+ }
+ }
+
pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
let (client_socket, addr) = self.listener.accept()?;
info!("Connected: {}", addr);
let client_id = self.server.add_client();
- self.register_client(poll, client_id, client_socket, addr);
+ self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr);
self.flush_server_messages();
Ok(())
@@ -205,7 +343,7 @@
client_id: ClientId) -> io::Result<()> {
let messages =
if let Some(ref mut client) = self.clients.get_mut(client_id) {
- client.read_messages()
+ client.read()
} else {
warn!("invalid readable client: {}", client_id);
Ok((Vec::new(), NetworkClientState::Idle))
@@ -246,7 +384,7 @@
client_id: ClientId) -> io::Result<()> {
let result =
if let Some(ref mut client) = self.clients.get_mut(client_id) {
- client.flush()
+ client.write()
} else {
warn!("invalid writable client: {}", client_id);
Ok(((), NetworkClientState::Idle))