--- a/rust/hedgewars-server/src/server/network.rs Thu Apr 25 19:44:14 2019 +0200
+++ b/rust/hedgewars-server/src/server/network.rs Thu Apr 25 19:58:52 2019 +0200
@@ -11,7 +11,7 @@
use log::*;
use mio::{
net::{TcpListener, TcpStream},
- Poll, PollOpt, Ready, Token,
+ Evented, Poll, PollOpt, Ready, Token,
};
use mio_extras::timer;
use netbuf;
@@ -48,32 +48,29 @@
NeedsWrite,
NeedsRead,
Closed,
+ #[cfg(feature = "tls-connections")]
+ Connected,
}
type NetworkResult<T> = io::Result<(T, NetworkClientState)>;
-#[cfg(not(feature = "tls-connections"))]
pub enum ClientSocket {
Plain(TcpStream),
-}
-
-#[cfg(feature = "tls-connections")]
-pub enum ClientSocket {
+ #[cfg(feature = "tls-connections")]
SslHandshake(Option<MidHandshakeSslStream<TcpStream>>),
+ #[cfg(feature = "tls-connections")]
SslStream(SslStream<TcpStream>),
}
impl ClientSocket {
fn inner(&self) -> &TcpStream {
- #[cfg(not(feature = "tls-connections"))]
match self {
ClientSocket::Plain(stream) => stream,
- }
-
- #[cfg(feature = "tls-connections")]
- match self {
+ #[cfg(feature = "tls-connections")]
ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(),
+ #[cfg(feature = "tls-connections")]
ClientSocket::SslHandshake(None) => unreachable!(),
+ #[cfg(feature = "tls-connections")]
ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(),
}
}
@@ -117,7 +114,7 @@
"TLS handshake with {} ({}) completed",
self.id, self.peer_addr
);
- Ok(NetworkClientState::Idle)
+ Ok(NetworkClientState::Connected)
}
Err(HandshakeError::WouldBlock(new_handshake)) => {
self.socket = ClientSocket::SslHandshake(Some(new_handshake));
@@ -171,19 +168,16 @@
}
pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> {
- #[cfg(not(feature = "tls-connections"))]
match self.socket {
ClientSocket::Plain(ref mut stream) => {
NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
}
- }
-
- #[cfg(feature = "tls-connections")]
- match self.socket {
+ #[cfg(feature = "tls-connections")]
ClientSocket::SslHandshake(ref mut handshake_opt) => {
let handshake = std::mem::replace(handshake_opt, None).unwrap();
Ok((Vec::new(), self.handshake_impl(handshake)?))
}
+ #[cfg(feature = "tls-connections")]
ClientSocket::SslStream(ref mut stream) => {
NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr)
}
@@ -210,25 +204,18 @@
}
pub fn write(&mut self) -> NetworkResult<()> {
- let result = {
- #[cfg(not(feature = "tls-connections"))]
- match self.socket {
- ClientSocket::Plain(ref mut stream) => {
- NetworkClient::write_impl(&mut self.buf_out, stream)
- }
+ let result = match self.socket {
+ ClientSocket::Plain(ref mut stream) => {
+ NetworkClient::write_impl(&mut self.buf_out, stream)
}
-
#[cfg(feature = "tls-connections")]
- {
- match self.socket {
- ClientSocket::SslHandshake(ref mut handshake_opt) => {
- let handshake = std::mem::replace(handshake_opt, None).unwrap();
- Ok(((), self.handshake_impl(handshake)?))
- }
- ClientSocket::SslStream(ref mut stream) => {
- NetworkClient::write_impl(&mut self.buf_out, stream)
- }
- }
+ ClientSocket::SslHandshake(ref mut handshake_opt) => {
+ let handshake = std::mem::replace(handshake_opt, None).unwrap();
+ Ok(((), self.handshake_impl(handshake)?))
+ }
+ #[cfg(feature = "tls-connections")]
+ ClientSocket::SslStream(ref mut stream) => {
+ NetworkClient::write_impl(&mut self.buf_out, stream)
}
};
@@ -251,6 +238,7 @@
#[cfg(feature = "tls-connections")]
struct ServerSsl {
+ listener: TcpListener,
context: SslContext,
}
@@ -324,6 +312,10 @@
timer: timer::Timer<TimerData>,
}
+fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> {
+ poll.register(evented, token, Ready::readable(), PollOpt::edge())
+}
+
fn create_ping_timeout(
timer: &mut timer::Timer<TimerData>,
probes_count: u8,
@@ -343,29 +335,8 @@
}
impl NetworkLayer {
- pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer {
- let server = HWServer::new(clients_limit, rooms_limit);
- let clients = Slab::with_capacity(clients_limit);
- let pending = HashSet::with_capacity(2 * clients_limit);
- let pending_cache = Vec::with_capacity(2 * clients_limit);
- let timer = timer::Builder::default().build();
-
- NetworkLayer {
- listener,
- server,
- clients,
- pending,
- pending_cache,
- #[cfg(feature = "tls-connections")]
- ssl: NetworkLayer::create_ssl_context(),
- #[cfg(feature = "official-server")]
- io: IoLayer::new(),
- timer,
- }
- }
-
#[cfg(feature = "tls-connections")]
- fn create_ssl_context() -> ServerSsl {
+ fn create_ssl_context(listener: TcpListener) -> ServerSsl {
let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE);
builder.set_read_ahead(true);
@@ -378,24 +349,16 @@
builder.set_options(SslOptions::NO_COMPRESSION);
builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap();
ServerSsl {
+ listener,
context: builder.build(),
}
}
- pub fn register_server(&self, poll: &Poll) -> io::Result<()> {
- poll.register(
- &self.listener,
- utils::SERVER_TOKEN,
- Ready::readable(),
- PollOpt::edge(),
- )?;
-
- poll.register(
- &self.timer,
- utils::TIMER_TOKEN,
- Ready::readable(),
- PollOpt::edge(),
- )?;
+ pub fn register(&self, poll: &Poll) -> io::Result<()> {
+ register_read(poll, &self.listener, utils::SERVER_TOKEN)?;
+ #[cfg(feature = "tls-connections")]
+ register_read(poll, &self.listener, utils::SECURE_SERVER_TOKEN)?;
+ register_read(poll, &self.timer, utils::TIMER_TOKEN)?;
#[cfg(feature = "official-server")]
self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?;
@@ -448,6 +411,10 @@
}
fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) {
+ if response.is_empty() {
+ return;
+ }
+
debug!("{} pending server messages", response.len());
let output = response.extract_messages(&mut self.server);
for (clients, message) in output {
@@ -512,41 +479,41 @@
}
fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
- #[cfg(not(feature = "tls-connections"))]
- {
- Ok(ClientSocket::Plain(socket))
- }
+ Ok(ClientSocket::Plain(socket))
+ }
- #[cfg(feature = "tls-connections")]
- {
- let ssl = Ssl::new(&self.ssl.context).unwrap();
- let mut builder = SslStreamBuilder::new(ssl, socket);
- builder.set_accept_state();
- match builder.handshake() {
- Ok(stream) => Ok(ClientSocket::SslStream(stream)),
- Err(HandshakeError::WouldBlock(stream)) => {
- Ok(ClientSocket::SslHandshake(Some(stream)))
- }
- Err(e) => {
- debug!("OpenSSL handshake failed: {}", e);
- Err(Error::new(ErrorKind::Other, "Connection failure"))
- }
+ #[cfg(feature = "tls-connections")]
+ fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> {
+ let ssl = Ssl::new(&self.ssl.context).unwrap();
+ let mut builder = SslStreamBuilder::new(ssl, socket);
+ builder.set_accept_state();
+ match builder.handshake() {
+ Ok(stream) => Ok(ClientSocket::SslStream(stream)),
+ Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))),
+ Err(e) => {
+ debug!("OpenSSL handshake failed: {}", e);
+ Err(Error::new(ErrorKind::Other, "Connection failure"))
}
}
}
- pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> {
+ pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> {
let (client_socket, addr) = self.listener.accept()?;
info!("Connected: {}", addr);
- let client_id = self.register_client(poll, self.create_client_socket(client_socket)?, addr);
-
- let mut response = handlers::Response::new(client_id);
-
- handlers::handle_client_accept(&mut self.server, client_id, &mut response);
-
- if !response.is_empty() {
- self.handle_response(response, poll);
+ match server_token {
+ utils::SERVER_TOKEN => {
+ let client_id =
+ self.register_client(poll, self.create_client_socket(client_socket)?, addr);
+ let mut response = handlers::Response::new(client_id);
+ handlers::handle_client_accept(&mut self.server, client_id, &mut response);
+ self.handle_response(response, poll);
+ }
+ #[cfg(feature = "tls-connections")]
+ utils::SECURE_SERVER_TOKEN => {
+ self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr);
+ }
+ _ => unreachable!(),
}
Ok(())
@@ -595,6 +562,12 @@
self.pending.insert((client_id, state));
}
NetworkClientState::Closed => self.client_error(&poll, client_id)?,
+ #[cfg(feature = "tls-connections")]
+ NetworkClientState::Connected => {
+ let mut response = handlers::Response::new(client_id);
+ handlers::handle_client_accept(&mut self.server, client_id, &mut response);
+ self.handle_response(response, poll);
+ }
_ => {}
};
}
@@ -606,9 +579,7 @@
)?,
}
- if !response.is_empty() {
- self.handle_response(response, poll);
- }
+ self.handle_response(response, poll);
Ok(())
}
@@ -663,3 +634,60 @@
Ok(())
}
}
+
+pub struct NetworkLayerBuilder {
+ listener: Option<TcpListener>,
+ secure_listener: Option<TcpListener>,
+ clients_capacity: usize,
+ rooms_capacity: usize,
+}
+
+impl Default for NetworkLayerBuilder {
+ fn default() -> Self {
+ Self {
+ clients_capacity: 1024,
+ rooms_capacity: 512,
+ listener: None,
+ secure_listener: None,
+ }
+ }
+}
+
+impl NetworkLayerBuilder {
+ pub fn with_listener(self, listener: TcpListener) -> Self {
+ Self {
+ listener: Some(listener),
+ ..self
+ }
+ }
+
+ pub fn with_secure_listener(self, listener: TcpListener) -> Self {
+ Self {
+ secure_listener: Some(listener),
+ ..self
+ }
+ }
+
+ pub fn build(self) -> NetworkLayer {
+ let server = HWServer::new(self.clients_capacity, self.rooms_capacity);
+ let clients = Slab::with_capacity(self.clients_capacity);
+ let pending = HashSet::with_capacity(2 * self.clients_capacity);
+ let pending_cache = Vec::with_capacity(2 * self.clients_capacity);
+ let timer = timer::Builder::default().build();
+
+ NetworkLayer {
+ listener: self.listener.expect("No listener provided"),
+ server,
+ clients,
+ pending,
+ pending_cache,
+ #[cfg(feature = "tls-connections")]
+ ssl: NetworkLayer::create_ssl_context(
+ self.secure_listener.expect("No secure listener provided"),
+ ),
+ #[cfg(feature = "official-server")]
+ io: IoLayer::new(),
+ timer,
+ }
+ }
+}