--- a/rust/hedgewars-server/src/protocol.rs Fri Apr 12 19:26:44 2019 +0300
+++ b/rust/hedgewars-server/src/protocol.rs Fri Apr 12 22:36:54 2019 +0300
@@ -1,3 +1,5 @@
+use crate::protocol::parser::message;
+use log::*;
use netbuf;
use nom::{Err, ErrorKind, IResult};
use std::io::{Read, Result};
@@ -10,6 +12,7 @@
pub struct ProtocolDecoder {
buf: netbuf::Buf,
consumed: usize,
+ is_recovering: bool,
}
impl ProtocolDecoder {
@@ -17,26 +20,55 @@
ProtocolDecoder {
buf: netbuf::Buf::new(),
consumed: 0,
+ is_recovering: false,
}
}
+ fn recover(&mut self) -> bool {
+ self.is_recovering = match parser::malformed_message(&self.buf[..]) {
+ Ok((tail, ())) => {
+ self.buf.consume(self.buf.len() - tail.len());
+ false
+ }
+ _ => {
+ self.buf.consume(self.buf.len());
+ true
+ }
+ };
+ !self.is_recovering
+ }
+
pub fn read_from<R: Read>(&mut self, stream: &mut R) -> Result<usize> {
- self.buf.read_from(stream)
+ let count = self.buf.read_from(stream)?;
+ if count > 0 && self.is_recovering {
+ self.recover();
+ }
+ Ok(count)
}
pub fn extract_messages(&mut self) -> Vec<messages::HWProtocolMessage> {
- let parse_result = parser::extract_messages(&self.buf[..]);
- match parse_result {
- Ok((tail, msgs)) => {
- self.consumed = self.buf.len() - self.consumed - tail.len();
- msgs
+ let mut messages = vec![];
+ let mut consumed = 0;
+ if !self.is_recovering {
+ loop {
+ match parser::message(&self.buf[consumed..]) {
+ Ok((tail, message)) => {
+ messages.push(message);
+ consumed += self.buf.len() - tail.len();
+ }
+ Err(nom::Err::Incomplete(_)) => break,
+ Err(nom::Err::Failure(e)) | Err(nom::Err::Error(e)) => {
+ debug!("Invalid message: {:?}", e);
+ self.buf.consume(consumed);
+ consumed = 0;
+ if !self.recover() || self.buf.is_empty() {
+ break;
+ }
+ }
+ }
}
- _ => unreachable!(),
}
- }
-
- pub fn sweep(&mut self) {
- self.buf.consume(self.consumed);
- self.consumed = 0;
+ self.buf.consume(consumed);
+ messages
}
}