--- a/rust/hedgewars-server/src/protocol.rs Mon Jan 31 18:24:49 2022 +0300
+++ b/rust/hedgewars-server/src/protocol.rs Tue Feb 01 02:23:35 2022 +0300
@@ -1,67 +1,75 @@
+use bytes::{Buf, BufMut, BytesMut};
+use log::*;
+use std::{io, io::ErrorKind, marker::Unpin};
+use tokio::io::AsyncReadExt;
+
use hedgewars_network_protocol::{
messages::HwProtocolMessage,
parser::{malformed_message, message},
};
-use log::*;
-use netbuf;
-use std::io::{Read, Result};
pub struct ProtocolDecoder {
- buf: netbuf::Buf,
+ buffer: BytesMut,
is_recovering: bool,
}
impl ProtocolDecoder {
pub fn new() -> ProtocolDecoder {
ProtocolDecoder {
- buf: netbuf::Buf::new(),
+ buffer: BytesMut::with_capacity(1024),
is_recovering: false,
}
}
fn recover(&mut self) -> bool {
- self.is_recovering = match malformed_message(&self.buf[..]) {
+ self.is_recovering = match malformed_message(&self.buffer[..]) {
Ok((tail, ())) => {
- let length = tail.len();
- self.buf.consume(self.buf.len() - length);
+ let remaining = tail.len();
+ self.buffer.advance(self.buffer.len() - remaining);
false
}
_ => {
- self.buf.consume(self.buf.len());
+ self.buffer.clear();
true
}
};
!self.is_recovering
}
- pub fn read_from<R: Read>(&mut self, stream: &mut R) -> Result<usize> {
- let count = self.buf.read_from(stream)?;
- if count > 0 && self.is_recovering {
- self.recover();
- }
- Ok(count)
- }
-
- pub fn extract_messages(&mut self) -> Vec<HwProtocolMessage> {
- let mut messages = vec![];
- if !self.is_recovering {
- while !self.buf.is_empty() {
- match message(&self.buf[..]) {
- Ok((tail, message)) => {
- messages.push(message);
- let length = tail.len();
- self.buf.consume(self.buf.len() - length);
- }
- Err(nom::Err::Incomplete(_)) => break,
- Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
- debug!("Invalid message: {:?}", e);
- if !self.recover() || self.buf.is_empty() {
- break;
- }
- }
+ fn extract_message(&mut self) -> Option<HwProtocolMessage> {
+ if !self.is_recovering || self.recover() {
+ match message(&self.buffer[..]) {
+ Ok((tail, message)) => {
+ let remaining = tail.len();
+ self.buffer.advance(self.buffer.len() - remaining);
+ return Some(message);
+ }
+ Err(nom::Err::Incomplete(_)) => {}
+ Err(nom::Err::Failure(e) | nom::Err::Error(e)) => {
+ debug!("Invalid message: {:?}", e);
+ self.recover();
}
}
}
- messages
+ None
+ }
+
+ pub async fn read_from<R: AsyncReadExt + Unpin>(
+ &mut self,
+ stream: &mut R,
+ ) -> Option<HwProtocolMessage> {
+ loop {
+ if !self.buffer.has_remaining() {
+ let count = stream.read_buf(&mut self.buffer).await.ok()?;
+ if count == 0 {
+ return None;
+ }
+ }
+ while !self.buffer.is_empty() {
+ if let Some(result) = self.extract_message() {
+ return Some(result);
+ }
+ }
+ }
}
}