diff -r 7d0f747afcb8 -r a4d505a32879 rust/hedgewars-server/src/protocol.rs --- a/rust/hedgewars-server/src/protocol.rs Tue Feb 01 02:23:35 2022 +0300 +++ b/rust/hedgewars-server/src/protocol.rs Tue Feb 01 20:58:35 2022 +0300 @@ -1,22 +1,62 @@ use bytes::{Buf, BufMut, BytesMut}; use log::*; -use std::{io, io::ErrorKind, marker::Unpin}; -use tokio::io::AsyncReadExt; +use std::{ + error::Error, + fmt::{Debug, Display, Formatter}, + io, + io::ErrorKind, + marker::Unpin, + time::Duration, +}; +use tokio::{io::AsyncReadExt, time::timeout}; +use crate::protocol::ProtocolError::Timeout; use hedgewars_network_protocol::{ messages::HwProtocolMessage, + parser::HwProtocolError, parser::{malformed_message, message}, }; +#[derive(Debug)] +pub enum ProtocolError { + Eof, + Timeout, + Network(Box), +} + +impl Display for ProtocolError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ProtocolError::Eof => write!(f, "Connection reset by peer"), + ProtocolError::Timeout => write!(f, "Read operation timed out"), + ProtocolError::Network(source) => write!(f, "{:?}", source), + } + } +} + +impl Error for ProtocolError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + if let Self::Network(source) = self { + Some(source.as_ref()) + } else { + None + } + } +} + +pub type Result = std::result::Result; + pub struct ProtocolDecoder { buffer: BytesMut, + read_timeout: Duration, is_recovering: bool, } impl ProtocolDecoder { - pub fn new() -> ProtocolDecoder { + pub fn new(read_timeout: Duration) -> ProtocolDecoder { ProtocolDecoder { buffer: BytesMut::with_capacity(1024), + read_timeout, is_recovering: false, } } @@ -57,17 +97,21 @@ pub async fn read_from( &mut self, stream: &mut R, - ) -> Option { + ) -> Result { + use ProtocolError::*; + loop { if !self.buffer.has_remaining() { - let count = stream.read_buf(&mut self.buffer).await.ok()?; - if count == 0 { - return None; - } + match timeout(self.read_timeout, stream.read_buf(&mut self.buffer)).await { + Err(_) => return Err(Timeout), + Ok(Err(e)) => return Err(Network(Box::new(e))), + Ok(Ok(0)) => return Err(Eof), + Ok(Ok(_)) => (), + }; } while !self.buffer.is_empty() { if let Some(result) = self.extract_message() { - return Some(result); + return Ok(result); } } }