--- a/rust/hedgewars-server/src/server/network.rs Thu Mar 23 23:41:26 2023 +0300
+++ b/rust/hedgewars-server/src/server/network.rs Fri Mar 24 03:26:08 2023 +0300
@@ -1,16 +1,21 @@
use bytes::{Buf, Bytes};
use log::*;
use slab::Slab;
+use std::io::Error;
+use std::pin::Pin;
+use std::task::{Context, Poll};
use std::{
iter::Iterator,
net::{IpAddr, SocketAddr},
time::Duration,
};
use tokio::{
- io::AsyncReadExt,
+ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
net::{TcpListener, TcpStream},
sync::mpsc::{channel, Receiver, Sender},
};
+#[cfg(feature = "tls-connections")]
+use tokio_native_tls::{TlsAcceptor, TlsStream};
use crate::{
core::{
@@ -25,7 +30,6 @@
use hedgewars_network_protocol::{
messages::HwServerMessage::Redirect, messages::*, parser::server_message,
};
-use tokio::io::AsyncWriteExt;
const PING_TIMEOUT: Duration = Duration::from_secs(15);
@@ -56,9 +60,65 @@
}
}
+enum ClientStream {
+ Tcp(TcpStream),
+ #[cfg(feature = "tls-connections")]
+ Tls(TlsStream<TcpStream>),
+}
+
+impl Unpin for ClientStream {}
+
+impl AsyncRead for ClientStream {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ use ClientStream::*;
+ match Pin::into_inner(self) {
+ Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
+ #[cfg(feature = "tls-connections")]
+ Tls(stream) => Pin::new(stream).poll_read(cx, buf),
+ }
+ }
+}
+
+impl AsyncWrite for ClientStream {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, Error>> {
+ use ClientStream::*;
+ match Pin::into_inner(self) {
+ Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
+ #[cfg(feature = "tls-connections")]
+ Tls(stream) => Pin::new(stream).poll_write(cx, buf),
+ }
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+ use ClientStream::*;
+ match Pin::into_inner(self) {
+ Tcp(stream) => Pin::new(stream).poll_flush(cx),
+ #[cfg(feature = "tls-connections")]
+ Tls(stream) => Pin::new(stream).poll_flush(cx),
+ }
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+ use ClientStream::*;
+ match Pin::into_inner(self) {
+ Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
+ #[cfg(feature = "tls-connections")]
+ Tls(stream) => Pin::new(stream).poll_shutdown(cx),
+ }
+ }
+}
+
struct NetworkClient {
id: ClientId,
- socket: TcpStream,
+ stream: ClientStream,
receiver: Receiver<Bytes>,
peer_addr: SocketAddr,
decoder: ProtocolDecoder,
@@ -67,27 +127,27 @@
impl NetworkClient {
fn new(
id: ClientId,
- socket: TcpStream,
+ stream: ClientStream,
peer_addr: SocketAddr,
receiver: Receiver<Bytes>,
) -> Self {
Self {
id,
- socket,
+ stream,
peer_addr,
receiver,
decoder: ProtocolDecoder::new(PING_TIMEOUT),
}
}
- async fn read(
- socket: &mut TcpStream,
+ async fn read<T: AsyncRead + AsyncWrite + Unpin>(
+ stream: &mut T,
decoder: &mut ProtocolDecoder,
) -> protocol::Result<HwProtocolMessage> {
- let result = decoder.read_from(socket).await;
+ let result = decoder.read_from(stream).await;
if matches!(result, Err(ProtocolError::Timeout)) {
- if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
- decoder.read_from(socket).await
+ if Self::write(stream, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await {
+ decoder.read_from(stream).await
} else {
Err(ProtocolError::Eof)
}
@@ -96,8 +156,8 @@
}
}
- async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool {
- !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0)
+ async fn write<T: AsyncWrite + Unpin>(stream: &mut T, mut data: Bytes) -> bool {
+ !data.has_remaining() || matches!(stream.write_buf(&mut data).await, Ok(n) if n > 0)
}
async fn run(mut self, sender: Sender<ClientUpdate>) {
@@ -111,7 +171,7 @@
tokio::select! {
server_message = self.receiver.recv() => {
match server_message {
- Some(message) => if !Self::write(&mut self.socket, message).await {
+ Some(message) => if !Self::write(&mut self.stream, message).await {
sender.send(Error("Connection reset by peer".to_string())).await;
break;
}
@@ -120,7 +180,7 @@
}
}
}
- client_message = Self::read(&mut self.socket, &mut self.decoder) => {
+ client_message = Self::read(&mut self.stream, &mut self.decoder) => {
match client_message {
Ok(message) => {
if !sender.send(Message(message)).await {
@@ -130,7 +190,7 @@
Err(e) => {
sender.send(Error(format!("{}", e))).await;
if matches!(e, ProtocolError::Timeout) {
- Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
+ Self::write(&mut self.stream, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await;
}
break;
}
@@ -141,8 +201,16 @@
}
}
+#[cfg(feature = "tls-connections")]
+struct TlsListener {
+ listener: TcpListener,
+ acceptor: TlsAcceptor,
+}
+
pub struct NetworkLayer {
listener: TcpListener,
+ #[cfg(feature = "tls-connections")]
+ tls: TlsListener,
server_state: ServerState,
clients: Slab<Sender<Bytes>>,
}
@@ -151,36 +219,82 @@
pub async fn run(&mut self) {
let (update_tx, mut update_rx) = channel(128);
- loop {
- tokio::select! {
- Ok((stream, addr)) = self.listener.accept() => {
- if let Some(client) = self.create_client(stream, addr).await {
+ async fn accept_plain_branch(
+ layer: &mut NetworkLayer,
+ value: (TcpStream, SocketAddr),
+ update_tx: Sender<ClientUpdate>,
+ ) {
+ let (stream, addr) = value;
+ if let Some(client) = layer.create_client(ClientStream::Tcp(stream), addr).await {
+ tokio::spawn(client.run(update_tx));
+ }
+ }
+
+ #[cfg(feature = "tls-connections")]
+ async fn accept_tls_branch(
+ layer: &mut NetworkLayer,
+ value: (TcpStream, SocketAddr),
+ update_tx: Sender<ClientUpdate>,
+ ) {
+ let (stream, addr) = value;
+ match layer.tls.acceptor.accept(stream).await {
+ Ok(stream) => {
+ if let Some(client) = layer.create_client(ClientStream::Tls(stream), addr).await
+ {
tokio::spawn(client.run(update_tx.clone()));
}
}
- client_message = update_rx.recv(), if !self.clients.is_empty() => {
- use ClientUpdateData::*;
- match client_message {
- Some(ClientUpdate{ client_id, data: Message(message) } ) => {
- self.handle_message(client_id, message).await;
- }
- Some(ClientUpdate{ client_id, data: Error(e) } ) => {
- let mut response = handlers::Response::new(client_id);
- info!("Client {} error: {:?}", client_id, e);
- response.remove_client(client_id);
- handlers::handle_client_loss(&mut self.server_state, client_id, &mut response);
- self.handle_response(response).await;
- }
- None => unreachable!()
- }
+ Err(e) => {
+ warn!("Unable to establish TLS connection: {}", e);
+ }
+ }
+ }
+
+ async fn client_message_branch(
+ layer: &mut NetworkLayer,
+ client_message: Option<ClientUpdate>,
+ ) {
+ use ClientUpdateData::*;
+ match client_message {
+ Some(ClientUpdate {
+ client_id,
+ data: Message(message),
+ }) => {
+ layer.handle_message(client_id, message).await;
}
+ Some(ClientUpdate {
+ client_id,
+ data: Error(e),
+ }) => {
+ let mut response = handlers::Response::new(client_id);
+ info!("Client {} error: {:?}", client_id, e);
+ response.remove_client(client_id);
+ handlers::handle_client_loss(&mut layer.server_state, client_id, &mut response);
+ layer.handle_response(response).await;
+ }
+ None => unreachable!(),
+ }
+ }
+
+ loop {
+ #[cfg(not(feature = "tls-connections"))]
+ tokio::select! {
+ Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
+ client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
+ }
+
+ #[cfg(feature = "tls-connections")]
+ tokio::select! {
+ Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await,
+ Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, update_tx.clone()).await,
+ client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await
}
}
}
async fn create_client(
&mut self,
- stream: TcpStream,
+ stream: ClientStream,
addr: SocketAddr,
) -> Option<NetworkClient> {
let entry = self.clients.vacant_entry();
@@ -263,6 +377,10 @@
pub struct NetworkLayerBuilder {
listener: Option<TcpListener>,
+ #[cfg(feature = "tls-connections")]
+ tls_listener: Option<TcpListener>,
+ #[cfg(feature = "tls-connections")]
+ tls_acceptor: Option<TlsAcceptor>,
clients_capacity: usize,
rooms_capacity: usize,
}
@@ -273,6 +391,10 @@
clients_capacity: 1024,
rooms_capacity: 512,
listener: None,
+ #[cfg(feature = "tls-connections")]
+ tls_listener: None,
+ #[cfg(feature = "tls-connections")]
+ tls_acceptor: None,
}
}
}
@@ -285,6 +407,22 @@
}
}
+ #[cfg(feature = "tls-connections")]
+ pub fn with_tls_acceptor(self, listener: TlsAcceptor) -> Self {
+ Self {
+ tls_acceptor: Option::from(listener),
+ ..self
+ }
+ }
+
+ #[cfg(feature = "tls-connections")]
+ pub fn with_tls_listener(self, listener: TlsAcceptor) -> Self {
+ Self {
+ tls_acceptor: Option::from(listener),
+ ..self
+ }
+ }
+
pub fn build(self) -> NetworkLayer {
let server_state = ServerState::new(self.clients_capacity, self.rooms_capacity);
@@ -292,6 +430,11 @@
NetworkLayer {
listener: self.listener.expect("No listener provided"),
+ #[cfg(feature = "tls-connections")]
+ tls: TlsListener {
+ listener: self.tls_listener.expect("No TLS listener provided"),
+ acceptor: self.tls_acceptor.expect("No TLS acceptor provided"),
+ },
server_state,
clients,
}