separate out DB task
authoralfadur
Sat, 25 Mar 2023 03:29:22 +0300
changeset 15967 e514ceb5e7d6
parent 15966 c5c53ebb2d91
child 15968 ce47259d5c86
separate out DB task
rust/hedgewars-server/src/handlers/inanteroom.rs
rust/hedgewars-server/src/server/database.rs
rust/hedgewars-server/src/server/io.rs
rust/hedgewars-server/src/server/network.rs
--- a/rust/hedgewars-server/src/handlers/inanteroom.rs	Fri Mar 24 03:26:08 2023 +0300
+++ b/rust/hedgewars-server/src/handlers/inanteroom.rs	Sat Mar 25 03:29:22 2023 +0300
@@ -13,8 +13,6 @@
     HwProtocolMessage, HwProtocolMessage::LoadRoom, HwServerMessage::*,
 };
 use log::*;
-#[cfg(feature = "official-server")]
-use openssl::sha::sha1;
 use std::{
     fmt::{Formatter, LowerHex},
     num::NonZeroU16,
--- a/rust/hedgewars-server/src/server/database.rs	Fri Mar 24 03:26:08 2023 +0300
+++ b/rust/hedgewars-server/src/server/database.rs	Sat Mar 25 03:29:22 2023 +0300
@@ -1,5 +1,6 @@
 use mysql_async::{self, from_row_opt, params, prelude::*, Pool};
 use sha1::{Digest, Sha1};
+use tokio::sync::mpsc::{channel, Receiver, Sender};
 
 use crate::handlers::{AccountInfo, Sha1Digest};
 
@@ -25,14 +26,96 @@
 
 pub struct Achievements {}
 
+pub enum DatabaseQuery {
+    CheckRegistered {
+        nick: String,
+    },
+    GetAccount {
+        nick: String,
+        protocol: u16,
+        password_hash: String,
+        client_salt: String,
+        server_salt: String,
+    },
+    GetCheckerAccount {
+        nick: String,
+        password: String,
+    },
+    GetReplayFilename {
+        id: u32,
+    },
+}
+
+pub enum DatabaseResponse {
+    AccountRegistered(bool),
+    Account(Option<AccountInfo>),
+    CheckerAccount { is_registered: bool },
+}
+
 pub struct Database {
     pool: Pool,
+    query_rx: Receiver<DatabaseQuery>,
+    response_tx: Sender<DatabaseResponse>,
 }
 
 impl Database {
     pub fn new(url: &str) -> Self {
+        let (query_tx, query_rx) = channel(32);
+        let (response_tx, response_rx) = channel(32);
         Self {
             pool: Pool::new(url),
+            query_rx,
+            response_tx,
+        }
+    }
+
+    pub async fn run(&mut self) {
+        use DatabaseResponse::*;
+        loop {
+            let query = self.query_rx.recv().await;
+            if let Some(query) = query {
+                match query {
+                    DatabaseQuery::CheckRegistered { nick } => {
+                        let is_registered = self.get_is_registered(&nick).await.unwrap_or(false);
+                        self.response_tx
+                            .send(AccountRegistered(is_registered))
+                            .await;
+                    }
+                    DatabaseQuery::GetAccount {
+                        nick,
+                        protocol,
+                        password_hash,
+                        client_salt,
+                        server_salt,
+                    } => {
+                        let account = self
+                            .get_account(
+                                &nick,
+                                protocol,
+                                &password_hash,
+                                &client_salt,
+                                &server_salt,
+                            )
+                            .await
+                            .unwrap_or(None);
+                        self.response_tx.send(Account(account)).await;
+                    }
+                    DatabaseQuery::GetCheckerAccount { nick, password } => {
+                        let is_registered = self
+                            .get_checker_account(&nick, &password)
+                            .await
+                            .unwrap_or(false);
+                        self.response_tx
+                            .send(CheckerAccount { is_registered })
+                            .await;
+                    }
+                    DatabaseQuery::GetReplayFilename { id } => {
+                        let filename = self.get_replay_name(id).await;
+                    }
+                };
+            } else {
+                break;
+            }
         }
     }
 
@@ -40,9 +123,9 @@
         let mut connection = self.pool.get_conn().await?;
         let result = CHECK_ACCOUNT_EXISTS_QUERY
             .with(params! { "username" => nick })
-            .first(&mut connection)
+            .first::<u32, _>(&mut connection)
             .await?;
-        Ok(!result.is_empty())
+        Ok(!result.is_some())
     }
 
     pub async fn get_account(
--- a/rust/hedgewars-server/src/server/io.rs	Fri Mar 24 03:26:08 2023 +0300
+++ b/rust/hedgewars-server/src/server/io.rs	Sat Mar 25 03:29:22 2023 +0300
@@ -2,6 +2,7 @@
     fs::{File, OpenOptions},
     io::{Error, ErrorKind, Read, Result, Write},
     sync::{mpsc, Arc},
+    task::Waker,
     thread,
 };
 
@@ -23,8 +24,7 @@
         let (core_tx, io_rx) = mpsc::channel();
         let (io_tx, core_rx) = mpsc::channel();
 
-        let mut db = Database::new();
-        db.connect("localhost");
+        /*let mut db = Database::new("localhost");
 
         thread::spawn(move || {
             while let Ok((request_id, task)) = io_rx.recv() {
@@ -138,7 +138,7 @@
                 io_tx.send((request_id, response));
                 waker.wake();
             }
-        });
+        });*/
 
         Self { core_rx, core_tx }
     }
--- a/rust/hedgewars-server/src/server/network.rs	Fri Mar 24 03:26:08 2023 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Sat Mar 25 03:29:22 2023 +0300
@@ -241,7 +241,7 @@
                 Ok(stream) => {
                     if let Some(client) = layer.create_client(ClientStream::Tls(stream), addr).await
                     {
-                        tokio::spawn(client.run(update_tx.clone()));
+                        tokio::spawn(client.run(update_tx));
                     }
                 }
                 Err(e) => {