rust/hedgewars-server/src/server/network.rs
changeset 15174 e705ac360785
parent 15173 21e87882df1c
child 15175 f1c2289d40bd
--- a/rust/hedgewars-server/src/server/network.rs	Wed Jun 19 00:49:45 2019 +0300
+++ b/rust/hedgewars-server/src/server/network.rs	Wed Jun 19 01:19:10 2019 +0300
@@ -84,6 +84,7 @@
     decoder: ProtocolDecoder,
     buf_out: netbuf::Buf,
     timeout: timer::Timeout,
+    pending_close: bool,
 }
 
 impl NetworkClient {
@@ -100,6 +101,7 @@
             decoder: ProtocolDecoder::new(),
             buf_out: netbuf::Buf::new(),
             timeout,
+            pending_close: false,
         }
     }
 
@@ -185,11 +187,20 @@
         }
     }
 
-    fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> {
+    fn write_impl<W: Write>(
+        buf_out: &mut netbuf::Buf,
+        destination: &mut W,
+        close_on_empty: bool,
+    ) -> NetworkResult<()> {
         let result = loop {
             match buf_out.write_to(destination) {
                 Ok(bytes) if buf_out.is_empty() || bytes == 0 => {
-                    break Ok(((), NetworkClientState::Idle));
+                    let status = if buf_out.is_empty() && close_on_empty {
+                        NetworkClientState::Closed
+                    } else {
+                        NetworkClientState::Idle
+                    };
+                    break Ok(((), status));
                 }
                 Ok(_) => (),
                 Err(ref error)
@@ -207,7 +218,7 @@
     pub fn write(&mut self) -> NetworkResult<()> {
         let result = match self.socket {
             ClientSocket::Plain(ref mut stream) => {
-                NetworkClient::write_impl(&mut self.buf_out, stream)
+                NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close)
             }
             #[cfg(feature = "tls-connections")]
             ClientSocket::SslHandshake(ref mut handshake_opt) => {
@@ -216,7 +227,7 @@
             }
             #[cfg(feature = "tls-connections")]
             ClientSocket::SslStream(ref mut stream) => {
-                NetworkClient::write_impl(&mut self.buf_out, stream)
+                NetworkClient::write_impl(&mut self.buf_out, stream, self.pending_close)
             }
         };
 
@@ -235,6 +246,10 @@
     pub fn replace_timeout(&mut self, timeout: timer::Timeout) -> timer::Timeout {
         replace(&mut self.timeout, timeout)
     }
+
+    pub fn has_pending_sends(&self) -> bool {
+        !self.buf_out.is_empty()
+    }
 }
 
 #[cfg(feature = "tls-connections")]
@@ -349,11 +364,28 @@
     }
 
     fn deregister_client(&mut self, poll: &Poll, id: ClientId) {
-        if let Some(ref client) = self.clients.get(id) {
+        if let Some(ref mut client) = self.clients.get_mut(id) {
             poll.deregister(client.socket.inner())
                 .expect("could not deregister socket");
-            info!("client {} ({}) removed", client.id, client.peer_addr);
-            self.clients.remove(id);
+            if client.has_pending_sends() {
+                info!(
+                    "client {} ({}) pending removal",
+                    client.id, client.peer_addr
+                );
+                client.pending_close = true;
+                poll.register(
+                    client.socket.inner(),
+                    Token(id),
+                    Ready::writable(),
+                    PollOpt::edge(),
+                )
+                .unwrap_or_else(|_| {
+                    self.clients.remove(id);
+                });
+            } else {
+                info!("client {} ({}) removed", client.id, client.peer_addr);
+                self.clients.remove(id);
+            }
             #[cfg(feature = "official-server")]
             self.io.cancel(id);
         }
@@ -583,7 +615,10 @@
             Ok(((), state)) if state == NetworkClientState::NeedsWrite => {
                 self.pending.insert((client_id, state));
             }
-            Ok(_) => {}
+            Ok(((), state)) if state == NetworkClientState::Closed => {
+                self.deregister_client(poll, client_id);
+            }
+            Ok(_) => (),
             Err(e) => {
                 self.operation_failed(poll, client_id, &e, "Error while writing to client socket")?
             }