Begin attempt to convert checker into async using tokio
authorunc0rr
Sat, 10 Jul 2021 12:03:50 +0200
changeset 15835 ad79e5c0885c
parent 15834 8c39a11f7756
child 15836 d9db7b763bd1
Begin attempt to convert checker into async using tokio
rust/hedgewars-checker/Cargo.toml
rust/hedgewars-checker/src/main.rs
--- a/rust/hedgewars-checker/Cargo.toml	Thu Jul 01 00:17:05 2021 +0300
+++ b/rust/hedgewars-checker/Cargo.toml	Sat Jul 10 12:03:50 2021 +0200
@@ -15,3 +15,4 @@
 base64 = "0.13"
 hedgewars-network-protocol = { path = "../hedgewars-network-protocol" }
 anyhow = "1.0"
+tokio = {version="1.6", features = ["full"]}
--- a/rust/hedgewars-checker/src/main.rs	Thu Jul 01 00:17:05 2021 +0300
+++ b/rust/hedgewars-checker/src/main.rs	Sat Jul 10 12:03:50 2021 +0200
@@ -6,9 +6,10 @@
 use ini::Ini;
 use log::{debug, info, warn};
 use netbuf::Buf;
-use std::{io::Write, net::TcpStream, process::Command, str::FromStr};
+use std::{io::Write, str::FromStr};
+use tokio::{io, io::AsyncWriteExt, net::TcpStream, process::Command};
 
-fn check(executable: &str, data_prefix: &str, buffer: &[String]) -> Result<Vec<String>> {
+async fn check(executable: &str, data_prefix: &str, buffer: &[String]) -> Result<Vec<String>> {
     let mut replay = tempfile::NamedTempFile::new()?;
 
     for line in buffer.into_iter() {
@@ -31,7 +32,12 @@
         .arg("--nosound")
         .arg("--stats-only")
         .arg(temp_file_path)
-        .output()?;
+        //.spawn()?
+        //.wait_with_output()
+        .output()
+        .await?;
+
+    debug!("Engine finished!");
 
     let mut result = Vec::new();
 
@@ -40,6 +46,8 @@
         .split(|b| *b == '\n' as u8)
         .skip_while(|l| *l != b"WINNERS" && *l != b"DRAW");
 
+    debug!("Engine lines: {:?}", &engine_lines);
+
     loop {
         match engine_lines.next() {
             Some(b"DRAW") => result.push("DRAW".to_owned()),
@@ -73,6 +81,8 @@
         }
     }
 
+    println!("Engine lines: {:?}", &result);
+
     if result.len() > 0 {
         Ok(result)
     } else {
@@ -80,7 +90,7 @@
     }
 }
 
-fn connect_and_run(
+async fn connect_and_run(
     username: &str,
     password: &str,
     protocol_number: u16,
@@ -89,63 +99,110 @@
 ) -> Result<()> {
     info!("Connecting...");
 
-    let mut stream = TcpStream::connect("hedgewars.org:46631")?;
-    stream.set_nonblocking(false)?;
+    let mut stream = TcpStream::connect("hedgewars.org:46631").await?;
 
     let mut buf = Buf::new();
 
+    let mut replay_lines: Option<Vec<String>> = None;
+
     loop {
-        buf.read_from(&mut stream)?;
+        let r = if let Some(ref lines) = replay_lines {
+            let r = tokio::select! {
+                _ = stream.readable() => None,
+                r = check(executable, data_prefix, &lines) => Some(r)
+            };
+
+            r
+        } else {
+            stream.readable().await?;
+            None
+        };
+
+        println!("Loop: {:?}", &r);
+
+        if let Some(execution_result) = r {
+            replay_lines = None;
+
+            match execution_result {
+                Ok(result) => {
+                    info!("Checked");
+                    debug!("Check result: [{:?}]", result);
+
+                    stream
+                        .write(
+                            ClientMessage::CheckedOk(result)
+                                .to_raw_protocol()
+                                .as_bytes(),
+                        )
+                        .await?;
+                    stream
+                        .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())
+                        .await?;
+                }
+                Err(e) => {
+                    info!("Check failed: {:?}", e);
+                    stream
+                        .write(
+                            ClientMessage::CheckedFail("error".to_owned())
+                                .to_raw_protocol()
+                                .as_bytes(),
+                        )
+                        .await?;
+                    stream
+                        .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())
+                        .await?;
+                }
+            }
+        } else {
+            let mut msg = [0; 4096];
+            // Try to read data, this may still fail with `WouldBlock`
+            // if the readiness event is a false positive.
+            match stream.try_read(&mut msg) {
+                Ok(n) => {
+                    //println!("{:?}", &msg);
+                    buf.write_all(&msg[0..n])?;
+                }
+                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
+                Err(e) => {
+                    return Err(e.into());
+                }
+            }
+        }
 
         while let Ok((tail, msg)) = parser::server_message(buf.as_ref()) {
             let tail_len = tail.len();
             buf.consume(buf.len() - tail_len);
 
+            println!("Message from server: {:?}", &msg);
+
             match msg {
                 Connected(_, _) => {
                     info!("Connected");
-                    stream.write(
-                        ClientMessage::Checker(
-                            protocol_number,
-                            username.to_owned(),
-                            password.to_owned(),
+                    stream
+                        .write(
+                            ClientMessage::Checker(
+                                protocol_number,
+                                username.to_owned(),
+                                password.to_owned(),
+                            )
+                            .to_raw_protocol()
+                            .as_bytes(),
                         )
-                        .to_raw_protocol()
-                        .as_bytes(),
-                    )?;
+                        .await?;
                 }
                 Ping => {
-                    stream.write(ClientMessage::Pong.to_raw_protocol().as_bytes())?;
+                    stream
+                        .write(ClientMessage::Pong.to_raw_protocol().as_bytes())
+                        .await?;
                 }
                 LogonPassed => {
-                    stream.write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())?;
+                    stream
+                        .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())
+                        .await?;
                 }
                 Replay(lines) => {
                     info!("Got a replay");
-                    match check(executable, data_prefix, &lines) {
-                        Ok(result) => {
-                            info!("Checked");
-                            debug!("Check result: [{:?}]", result);
-
-                            stream.write(
-                                ClientMessage::CheckedOk(result)
-                                    .to_raw_protocol()
-                                    .as_bytes(),
-                            )?;
-                            stream
-                                .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())?;
-                        }
-                        Err(e) => {
-                            info!("Check failed: {:?}", e);
-                            stream.write(
-                                ClientMessage::CheckedFail("error".to_owned())
-                                    .to_raw_protocol()
-                                    .as_bytes(),
-                            )?;
-                            stream
-                                .write(ClientMessage::CheckerReady.to_raw_protocol().as_bytes())?;
-                        }
-                    }
+                    replay_lines = Some(lines);
                 }
                 Bye(message) => {
                     warn!("Received BYE: {}", message);
@@ -181,13 +238,14 @@
     }
 }
 
-fn get_protocol_number(executable: &str) -> std::io::Result<u16> {
-    let output = Command::new(executable).arg("--protocol").output()?;
+async fn get_protocol_number(executable: &str) -> Result<u16> {
+    let output = Command::new(executable).arg("--protocol").output().await?;
 
     Ok(u16::from_str(&String::from_utf8(output.stdout).unwrap().trim()).unwrap_or(55))
 }
 
-fn main() {
+#[tokio::main]
+async fn main() -> Result<()> {
     stderrlog::new()
         .verbosity(3)
         .timestamp(stderrlog::Timestamp::Second)
@@ -217,9 +275,9 @@
     info!("Executable: {}", exe);
     info!("Data dir: {}", prefix);
 
-    let protocol_number = get_protocol_number(&exe.as_str()).unwrap_or_default();
+    let protocol_number = get_protocol_number(&exe.as_str()).await.unwrap_or_default();
 
     info!("Using protocol number {}", protocol_number);
 
-    connect_and_run(&username, &password, protocol_number, &exe, &prefix).unwrap();
+    connect_and_run(&username, &password, protocol_number, &exe, &prefix).await
 }