1 use bytes::{Buf, Bytes}; |
1 use bytes::{Buf, Bytes}; |
2 use log::*; |
2 use log::*; |
3 use slab::Slab; |
3 use slab::Slab; |
|
4 use std::io::Error; |
|
5 use std::pin::Pin; |
|
6 use std::task::{Context, Poll}; |
4 use std::{ |
7 use std::{ |
5 iter::Iterator, |
8 iter::Iterator, |
6 net::{IpAddr, SocketAddr}, |
9 net::{IpAddr, SocketAddr}, |
7 time::Duration, |
10 time::Duration, |
8 }; |
11 }; |
9 use tokio::{ |
12 use tokio::{ |
10 io::AsyncReadExt, |
13 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, |
11 net::{TcpListener, TcpStream}, |
14 net::{TcpListener, TcpStream}, |
12 sync::mpsc::{channel, Receiver, Sender}, |
15 sync::mpsc::{channel, Receiver, Sender}, |
13 }; |
16 }; |
|
17 #[cfg(feature = "tls-connections")] |
|
18 use tokio_native_tls::{TlsAcceptor, TlsStream}; |
14 |
19 |
15 use crate::{ |
20 use crate::{ |
16 core::{ |
21 core::{ |
17 events::{TimedEvents, Timeout}, |
22 events::{TimedEvents, Timeout}, |
18 types::ClientId, |
23 types::ClientId, |
54 .await |
58 .await |
55 .is_ok() |
59 .is_ok() |
56 } |
60 } |
57 } |
61 } |
58 |
62 |
|
/// Transport used by a single client connection.
///
/// Wraps either a plain TCP stream or, when the `tls-connections` feature is
/// enabled, a TLS stream layered over TCP, so the rest of the networking code
/// can handle both through one type.
enum ClientStream {
    /// Unencrypted TCP connection.
    Tcp(TcpStream),
    /// TLS session running on top of a TCP connection.
    #[cfg(feature = "tls-connections")]
    Tls(TlsStream<TcpStream>),
}
|
68 |
|
// The `poll_*` implementations for `ClientStream` unwrap `Pin<&mut Self>` with
// `Pin::into_inner`, which requires `Self: Unpin`; this impl states that
// guarantee explicitly.
impl Unpin for ClientStream {}
|
70 |
|
71 impl AsyncRead for ClientStream { |
|
72 fn poll_read( |
|
73 self: Pin<&mut Self>, |
|
74 cx: &mut Context<'_>, |
|
75 buf: &mut ReadBuf<'_>, |
|
76 ) -> Poll<std::io::Result<()>> { |
|
77 use ClientStream::*; |
|
78 match Pin::into_inner(self) { |
|
79 Tcp(stream) => Pin::new(stream).poll_read(cx, buf), |
|
80 #[cfg(feature = "tls-connections")] |
|
81 Tls(stream) => Pin::new(stream).poll_read(cx, buf), |
|
82 } |
|
83 } |
|
84 } |
|
85 |
|
86 impl AsyncWrite for ClientStream { |
|
87 fn poll_write( |
|
88 self: Pin<&mut Self>, |
|
89 cx: &mut Context<'_>, |
|
90 buf: &[u8], |
|
91 ) -> Poll<Result<usize, Error>> { |
|
92 use ClientStream::*; |
|
93 match Pin::into_inner(self) { |
|
94 Tcp(stream) => Pin::new(stream).poll_write(cx, buf), |
|
95 #[cfg(feature = "tls-connections")] |
|
96 Tls(stream) => Pin::new(stream).poll_write(cx, buf), |
|
97 } |
|
98 } |
|
99 |
|
100 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { |
|
101 use ClientStream::*; |
|
102 match Pin::into_inner(self) { |
|
103 Tcp(stream) => Pin::new(stream).poll_flush(cx), |
|
104 #[cfg(feature = "tls-connections")] |
|
105 Tls(stream) => Pin::new(stream).poll_flush(cx), |
|
106 } |
|
107 } |
|
108 |
|
109 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { |
|
110 use ClientStream::*; |
|
111 match Pin::into_inner(self) { |
|
112 Tcp(stream) => Pin::new(stream).poll_shutdown(cx), |
|
113 #[cfg(feature = "tls-connections")] |
|
114 Tls(stream) => Pin::new(stream).poll_shutdown(cx), |
|
115 } |
|
116 } |
|
117 } |
|
118 |
/// Per-connection state owned by the task that services one client.
struct NetworkClient {
    /// Identifier of this client.
    id: ClientId,
    /// Transport (plain TCP or TLS) used to talk to the client.
    stream: ClientStream,
    /// Outgoing messages queued for delivery to this client.
    receiver: Receiver<Bytes>,
    /// Remote address of the connection.
    peer_addr: SocketAddr,
    /// Incremental decoder for incoming protocol messages.
    decoder: ProtocolDecoder,
}
66 |
126 |
67 impl NetworkClient { |
127 impl NetworkClient { |
68 fn new( |
128 fn new( |
69 id: ClientId, |
129 id: ClientId, |
70 socket: TcpStream, |
130 stream: ClientStream, |
71 peer_addr: SocketAddr, |
131 peer_addr: SocketAddr, |
72 receiver: Receiver<Bytes>, |
132 receiver: Receiver<Bytes>, |
73 ) -> Self { |
133 ) -> Self { |
74 Self { |
134 Self { |
75 id, |
135 id, |
76 socket, |
136 stream, |
77 peer_addr, |
137 peer_addr, |
78 receiver, |
138 receiver, |
79 decoder: ProtocolDecoder::new(PING_TIMEOUT), |
139 decoder: ProtocolDecoder::new(PING_TIMEOUT), |
80 } |
140 } |
81 } |
141 } |
82 |
142 |
83 async fn read( |
143 async fn read<T: AsyncRead + AsyncWrite + Unpin>( |
84 socket: &mut TcpStream, |
144 stream: &mut T, |
85 decoder: &mut ProtocolDecoder, |
145 decoder: &mut ProtocolDecoder, |
86 ) -> protocol::Result<HwProtocolMessage> { |
146 ) -> protocol::Result<HwProtocolMessage> { |
87 let result = decoder.read_from(socket).await; |
147 let result = decoder.read_from(stream).await; |
88 if matches!(result, Err(ProtocolError::Timeout)) { |
148 if matches!(result, Err(ProtocolError::Timeout)) { |
89 if Self::write(socket, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await { |
149 if Self::write(stream, Bytes::from(HwServerMessage::Ping.to_raw_protocol())).await { |
90 decoder.read_from(socket).await |
150 decoder.read_from(stream).await |
91 } else { |
151 } else { |
92 Err(ProtocolError::Eof) |
152 Err(ProtocolError::Eof) |
93 } |
153 } |
94 } else { |
154 } else { |
95 result |
155 result |
96 } |
156 } |
97 } |
157 } |
98 |
158 |
99 async fn write(socket: &mut TcpStream, mut data: Bytes) -> bool { |
159 async fn write<T: AsyncWrite + Unpin>(stream: &mut T, mut data: Bytes) -> bool { |
100 !data.has_remaining() || matches!(socket.write_buf(&mut data).await, Ok(n) if n > 0) |
160 !data.has_remaining() || matches!(stream.write_buf(&mut data).await, Ok(n) if n > 0) |
101 } |
161 } |
102 |
162 |
103 async fn run(mut self, sender: Sender<ClientUpdate>) { |
163 async fn run(mut self, sender: Sender<ClientUpdate>) { |
104 use ClientUpdateData::*; |
164 use ClientUpdateData::*; |
105 let mut sender = ClientUpdateSender { |
165 let mut sender = ClientUpdateSender { |
109 |
169 |
110 loop { |
170 loop { |
111 tokio::select! { |
171 tokio::select! { |
112 server_message = self.receiver.recv() => { |
172 server_message = self.receiver.recv() => { |
113 match server_message { |
173 match server_message { |
114 Some(message) => if !Self::write(&mut self.socket, message).await { |
174 Some(message) => if !Self::write(&mut self.stream, message).await { |
115 sender.send(Error("Connection reset by peer".to_string())).await; |
175 sender.send(Error("Connection reset by peer".to_string())).await; |
116 break; |
176 break; |
117 } |
177 } |
118 None => { |
178 None => { |
119 break; |
179 break; |
120 } |
180 } |
121 } |
181 } |
122 } |
182 } |
123 client_message = Self::read(&mut self.socket, &mut self.decoder) => { |
183 client_message = Self::read(&mut self.stream, &mut self.decoder) => { |
124 match client_message { |
184 match client_message { |
125 Ok(message) => { |
185 Ok(message) => { |
126 if !sender.send(Message(message)).await { |
186 if !sender.send(Message(message)).await { |
127 break; |
187 break; |
128 } |
188 } |
129 } |
189 } |
130 Err(e) => { |
190 Err(e) => { |
131 sender.send(Error(format!("{}", e))).await; |
191 sender.send(Error(format!("{}", e))).await; |
132 if matches!(e, ProtocolError::Timeout) { |
192 if matches!(e, ProtocolError::Timeout) { |
133 Self::write(&mut self.socket, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await; |
193 Self::write(&mut self.stream, Bytes::from(HwServerMessage::Bye("Ping timeout".to_string()).to_raw_protocol())).await; |
134 } |
194 } |
135 break; |
195 break; |
136 } |
196 } |
137 } |
197 } |
138 } |
198 } |
139 } |
199 } |
140 } |
200 } |
141 } |
201 } |
142 } |
202 } |
143 |
203 |
|
/// Listening socket for TLS clients together with the acceptor used to
/// perform the TLS handshake on each accepted TCP connection.
#[cfg(feature = "tls-connections")]
struct TlsListener {
    /// TCP listener on which TLS clients connect.
    listener: TcpListener,
    /// Acceptor that upgrades an accepted TCP stream to a TLS stream.
    acceptor: TlsAcceptor,
}
|
209 |
/// Top-level networking state: listening sockets, server state and the
/// outgoing channel of every connected client.
pub struct NetworkLayer {
    /// Listener for plain TCP clients.
    listener: TcpListener,
    /// Listener and acceptor for TLS clients (only with `tls-connections`).
    #[cfg(feature = "tls-connections")]
    tls: TlsListener,
    /// Server state passed to the message handlers.
    server_state: ServerState,
    /// Per-client senders for outgoing data, keyed by slab index.
    clients: Slab<Sender<Bytes>>,
}
149 |
217 |
150 impl NetworkLayer { |
218 impl NetworkLayer { |
151 pub async fn run(&mut self) { |
219 pub async fn run(&mut self) { |
152 let (update_tx, mut update_rx) = channel(128); |
220 let (update_tx, mut update_rx) = channel(128); |
153 |
221 |
154 loop { |
222 async fn accept_plain_branch( |
155 tokio::select! { |
223 layer: &mut NetworkLayer, |
156 Ok((stream, addr)) = self.listener.accept() => { |
224 value: (TcpStream, SocketAddr), |
157 if let Some(client) = self.create_client(stream, addr).await { |
225 update_tx: Sender<ClientUpdate>, |
|
226 ) { |
|
227 let (stream, addr) = value; |
|
228 if let Some(client) = layer.create_client(ClientStream::Tcp(stream), addr).await { |
|
229 tokio::spawn(client.run(update_tx)); |
|
230 } |
|
231 } |
|
232 |
|
233 #[cfg(feature = "tls-connections")] |
|
234 async fn accept_tls_branch( |
|
235 layer: &mut NetworkLayer, |
|
236 value: (TcpStream, SocketAddr), |
|
237 update_tx: Sender<ClientUpdate>, |
|
238 ) { |
|
239 let (stream, addr) = value; |
|
240 match layer.tls.acceptor.accept(stream).await { |
|
241 Ok(stream) => { |
|
242 if let Some(client) = layer.create_client(ClientStream::Tls(stream), addr).await |
|
243 { |
158 tokio::spawn(client.run(update_tx.clone())); |
244 tokio::spawn(client.run(update_tx.clone())); |
159 } |
245 } |
160 } |
246 } |
161 client_message = update_rx.recv(), if !self.clients.is_empty() => { |
247 Err(e) => { |
162 use ClientUpdateData::*; |
248 warn!("Unable to establish TLS connection: {}", e); |
163 match client_message { |
249 } |
164 Some(ClientUpdate{ client_id, data: Message(message) } ) => { |
250 } |
165 self.handle_message(client_id, message).await; |
251 } |
166 } |
252 |
167 Some(ClientUpdate{ client_id, data: Error(e) } ) => { |
253 async fn client_message_branch( |
168 let mut response = handlers::Response::new(client_id); |
254 layer: &mut NetworkLayer, |
169 info!("Client {} error: {:?}", client_id, e); |
255 client_message: Option<ClientUpdate>, |
170 response.remove_client(client_id); |
256 ) { |
171 handlers::handle_client_loss(&mut self.server_state, client_id, &mut response); |
257 use ClientUpdateData::*; |
172 self.handle_response(response).await; |
258 match client_message { |
173 } |
259 Some(ClientUpdate { |
174 None => unreachable!() |
260 client_id, |
175 } |
261 data: Message(message), |
176 } |
262 }) => { |
|
263 layer.handle_message(client_id, message).await; |
|
264 } |
|
265 Some(ClientUpdate { |
|
266 client_id, |
|
267 data: Error(e), |
|
268 }) => { |
|
269 let mut response = handlers::Response::new(client_id); |
|
270 info!("Client {} error: {:?}", client_id, e); |
|
271 response.remove_client(client_id); |
|
272 handlers::handle_client_loss(&mut layer.server_state, client_id, &mut response); |
|
273 layer.handle_response(response).await; |
|
274 } |
|
275 None => unreachable!(), |
|
276 } |
|
277 } |
|
278 |
|
279 loop { |
|
280 #[cfg(not(feature = "tls-connections"))] |
|
281 tokio::select! { |
|
282 Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await, |
|
283 client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await |
|
284 } |
|
285 |
|
286 #[cfg(feature = "tls-connections")] |
|
287 tokio::select! { |
|
288 Ok(value) = self.listener.accept() => accept_plain_branch(self, value, update_tx.clone()).await, |
|
289 Ok(value) = self.tls.listener.accept() => accept_tls_branch(self, value, update_tx.clone()).await, |
|
290 client_message = update_rx.recv(), if !self.clients.is_empty() => client_message_branch(self, client_message).await |
177 } |
291 } |
178 } |
292 } |
179 } |
293 } |
180 |
294 |
181 async fn create_client( |
295 async fn create_client( |
182 &mut self, |
296 &mut self, |
183 stream: TcpStream, |
297 stream: ClientStream, |
184 addr: SocketAddr, |
298 addr: SocketAddr, |
185 ) -> Option<NetworkClient> { |
299 ) -> Option<NetworkClient> { |
186 let entry = self.clients.vacant_entry(); |
300 let entry = self.clients.vacant_entry(); |
187 let client_id = entry.key(); |
301 let client_id = entry.key(); |
188 let (tx, rx) = channel(16); |
302 let (tx, rx) = channel(16); |