46 pub enum NetworkClientState { |
46 pub enum NetworkClientState { |
47 Idle, |
47 Idle, |
48 NeedsWrite, |
48 NeedsWrite, |
49 NeedsRead, |
49 NeedsRead, |
50 Closed, |
50 Closed, |
|
51 #[cfg(feature = "tls-connections")] |
|
52 Connected, |
51 } |
53 } |
52 |
54 |
53 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
55 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
54 |
56 |
55 #[cfg(not(feature = "tls-connections"))] |
|
56 pub enum ClientSocket { |
57 pub enum ClientSocket { |
57 Plain(TcpStream), |
58 Plain(TcpStream), |
58 } |
59 #[cfg(feature = "tls-connections")] |
59 |
|
60 #[cfg(feature = "tls-connections")] |
|
61 pub enum ClientSocket { |
|
62 SslHandshake(Option<MidHandshakeSslStream<TcpStream>>), |
60 SslHandshake(Option<MidHandshakeSslStream<TcpStream>>), |
|
61 #[cfg(feature = "tls-connections")] |
63 SslStream(SslStream<TcpStream>), |
62 SslStream(SslStream<TcpStream>), |
64 } |
63 } |
65 |
64 |
66 impl ClientSocket { |
65 impl ClientSocket { |
67 fn inner(&self) -> &TcpStream { |
66 fn inner(&self) -> &TcpStream { |
68 #[cfg(not(feature = "tls-connections"))] |
|
69 match self { |
67 match self { |
70 ClientSocket::Plain(stream) => stream, |
68 ClientSocket::Plain(stream) => stream, |
71 } |
69 #[cfg(feature = "tls-connections")] |
72 |
|
73 #[cfg(feature = "tls-connections")] |
|
74 match self { |
|
75 ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), |
70 ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), |
|
71 #[cfg(feature = "tls-connections")] |
76 ClientSocket::SslHandshake(None) => unreachable!(), |
72 ClientSocket::SslHandshake(None) => unreachable!(), |
|
73 #[cfg(feature = "tls-connections")] |
77 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), |
74 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref(), |
78 } |
75 } |
79 } |
76 } |
80 } |
77 } |
81 |
78 |
115 self.socket = ClientSocket::SslStream(stream); |
112 self.socket = ClientSocket::SslStream(stream); |
116 debug!( |
113 debug!( |
117 "TLS handshake with {} ({}) completed", |
114 "TLS handshake with {} ({}) completed", |
118 self.id, self.peer_addr |
115 self.id, self.peer_addr |
119 ); |
116 ); |
120 Ok(NetworkClientState::Idle) |
117 Ok(NetworkClientState::Connected) |
121 } |
118 } |
122 Err(HandshakeError::WouldBlock(new_handshake)) => { |
119 Err(HandshakeError::WouldBlock(new_handshake)) => { |
123 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
120 self.socket = ClientSocket::SslHandshake(Some(new_handshake)); |
124 Ok(NetworkClientState::Idle) |
121 Ok(NetworkClientState::Idle) |
125 } |
122 } |
169 }; |
166 }; |
170 result |
167 result |
171 } |
168 } |
172 |
169 |
173 pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
170 pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
174 #[cfg(not(feature = "tls-connections"))] |
|
175 match self.socket { |
171 match self.socket { |
176 ClientSocket::Plain(ref mut stream) => { |
172 ClientSocket::Plain(ref mut stream) => { |
177 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
173 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
178 } |
174 } |
179 } |
175 #[cfg(feature = "tls-connections")] |
180 |
|
181 #[cfg(feature = "tls-connections")] |
|
182 match self.socket { |
|
183 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
176 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
184 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
177 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
185 Ok((Vec::new(), self.handshake_impl(handshake)?)) |
178 Ok((Vec::new(), self.handshake_impl(handshake)?)) |
186 } |
179 } |
|
180 #[cfg(feature = "tls-connections")] |
187 ClientSocket::SslStream(ref mut stream) => { |
181 ClientSocket::SslStream(ref mut stream) => { |
188 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
182 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
189 } |
183 } |
190 } |
184 } |
191 } |
185 } |
208 }; |
202 }; |
209 result |
203 result |
210 } |
204 } |
211 |
205 |
212 pub fn write(&mut self) -> NetworkResult<()> { |
206 pub fn write(&mut self) -> NetworkResult<()> { |
213 let result = { |
207 let result = match self.socket { |
214 #[cfg(not(feature = "tls-connections"))] |
208 ClientSocket::Plain(ref mut stream) => { |
215 match self.socket { |
209 NetworkClient::write_impl(&mut self.buf_out, stream) |
216 ClientSocket::Plain(ref mut stream) => { |
210 } |
217 NetworkClient::write_impl(&mut self.buf_out, stream) |
211 #[cfg(feature = "tls-connections")] |
218 } |
212 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
219 } |
213 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
220 |
214 Ok(((), self.handshake_impl(handshake)?)) |
221 #[cfg(feature = "tls-connections")] |
215 } |
222 { |
216 #[cfg(feature = "tls-connections")] |
223 match self.socket { |
217 ClientSocket::SslStream(ref mut stream) => { |
224 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
218 NetworkClient::write_impl(&mut self.buf_out, stream) |
225 let handshake = std::mem::replace(handshake_opt, None).unwrap(); |
|
226 Ok(((), self.handshake_impl(handshake)?)) |
|
227 } |
|
228 ClientSocket::SslStream(ref mut stream) => { |
|
229 NetworkClient::write_impl(&mut self.buf_out, stream) |
|
230 } |
|
231 } |
|
232 } |
219 } |
233 }; |
220 }; |
234 |
221 |
235 self.socket.inner().flush()?; |
222 self.socket.inner().flush()?; |
236 result |
223 result |
322 #[cfg(feature = "official-server")] |
310 #[cfg(feature = "official-server")] |
323 io: IoLayer, |
311 io: IoLayer, |
324 timer: timer::Timer<TimerData>, |
312 timer: timer::Timer<TimerData>, |
325 } |
313 } |
326 |
314 |
|
315 fn register_read<E: Evented>(poll: &Poll, evented: &E, token: mio::Token) -> io::Result<()> { |
|
316 poll.register(evented, token, Ready::readable(), PollOpt::edge()) |
|
317 } |
|
318 |
327 fn create_ping_timeout( |
319 fn create_ping_timeout( |
328 timer: &mut timer::Timer<TimerData>, |
320 timer: &mut timer::Timer<TimerData>, |
329 probes_count: u8, |
321 probes_count: u8, |
330 client_id: ClientId, |
322 client_id: ClientId, |
331 ) -> timer::Timeout { |
323 ) -> timer::Timeout { |
341 TimerData(TimeoutEvent::DropClient, client_id), |
333 TimerData(TimeoutEvent::DropClient, client_id), |
342 ) |
334 ) |
343 } |
335 } |
344 |
336 |
345 impl NetworkLayer { |
337 impl NetworkLayer { |
346 pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { |
|
347 let server = HWServer::new(clients_limit, rooms_limit); |
|
348 let clients = Slab::with_capacity(clients_limit); |
|
349 let pending = HashSet::with_capacity(2 * clients_limit); |
|
350 let pending_cache = Vec::with_capacity(2 * clients_limit); |
|
351 let timer = timer::Builder::default().build(); |
|
352 |
|
353 NetworkLayer { |
|
354 listener, |
|
355 server, |
|
356 clients, |
|
357 pending, |
|
358 pending_cache, |
|
359 #[cfg(feature = "tls-connections")] |
|
360 ssl: NetworkLayer::create_ssl_context(), |
|
361 #[cfg(feature = "official-server")] |
|
362 io: IoLayer::new(), |
|
363 timer, |
|
364 } |
|
365 } |
|
366 |
|
367 #[cfg(feature = "tls-connections")] |
338 #[cfg(feature = "tls-connections")] |
368 fn create_ssl_context() -> ServerSsl { |
339 fn create_ssl_context(listener: TcpListener) -> ServerSsl { |
369 let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); |
340 let mut builder = SslContextBuilder::new(SslMethod::tls()).unwrap(); |
370 builder.set_verify(SslVerifyMode::NONE); |
341 builder.set_verify(SslVerifyMode::NONE); |
371 builder.set_read_ahead(true); |
342 builder.set_read_ahead(true); |
372 builder |
343 builder |
373 .set_certificate_file("ssl/cert.pem", SslFiletype::PEM) |
344 .set_certificate_file("ssl/cert.pem", SslFiletype::PEM) |
376 .set_private_key_file("ssl/key.pem", SslFiletype::PEM) |
347 .set_private_key_file("ssl/key.pem", SslFiletype::PEM) |
377 .unwrap(); |
348 .unwrap(); |
378 builder.set_options(SslOptions::NO_COMPRESSION); |
349 builder.set_options(SslOptions::NO_COMPRESSION); |
379 builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); |
350 builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP").unwrap(); |
380 ServerSsl { |
351 ServerSsl { |
|
352 listener, |
381 context: builder.build(), |
353 context: builder.build(), |
382 } |
354 } |
383 } |
355 } |
384 |
356 |
385 pub fn register_server(&self, poll: &Poll) -> io::Result<()> { |
357 pub fn register(&self, poll: &Poll) -> io::Result<()> { |
386 poll.register( |
358 register_read(poll, &self.listener, utils::SERVER_TOKEN)?; |
387 &self.listener, |
359 #[cfg(feature = "tls-connections")] |
388 utils::SERVER_TOKEN, |
360 register_read(poll, &self.listener, utils::SECURE_SERVER_TOKEN)?; |
389 Ready::readable(), |
361 register_read(poll, &self.timer, utils::TIMER_TOKEN)?; |
390 PollOpt::edge(), |
|
391 )?; |
|
392 |
|
393 poll.register( |
|
394 &self.timer, |
|
395 utils::TIMER_TOKEN, |
|
396 Ready::readable(), |
|
397 PollOpt::edge(), |
|
398 )?; |
|
399 |
362 |
400 #[cfg(feature = "official-server")] |
363 #[cfg(feature = "official-server")] |
401 self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?; |
364 self.io.io_thread.register_rx(poll, utils::IO_TOKEN)?; |
402 |
365 |
403 Ok(()) |
366 Ok(()) |
446 |
409 |
447 client_id |
410 client_id |
448 } |
411 } |
449 |
412 |
450 fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) { |
413 fn handle_response(&mut self, mut response: handlers::Response, poll: &Poll) { |
|
414 if response.is_empty() { |
|
415 return; |
|
416 } |
|
417 |
451 debug!("{} pending server messages", response.len()); |
418 debug!("{} pending server messages", response.len()); |
452 let output = response.extract_messages(&mut self.server); |
419 let output = response.extract_messages(&mut self.server); |
453 for (clients, message) in output { |
420 for (clients, message) in output { |
454 debug!("Message {:?} to {:?}", message, clients); |
421 debug!("Message {:?} to {:?}", message, clients); |
455 let msg_string = message.to_raw_protocol(); |
422 let msg_string = message.to_raw_protocol(); |
510 handlers::handle_io_result(&mut self.server, client_id, &mut response, result); |
477 handlers::handle_io_result(&mut self.server, client_id, &mut response, result); |
511 } |
478 } |
512 } |
479 } |
513 |
480 |
514 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
481 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
515 #[cfg(not(feature = "tls-connections"))] |
482 Ok(ClientSocket::Plain(socket)) |
516 { |
483 } |
517 Ok(ClientSocket::Plain(socket)) |
484 |
518 } |
485 #[cfg(feature = "tls-connections")] |
519 |
486 fn create_client_secure_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
520 #[cfg(feature = "tls-connections")] |
487 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
521 { |
488 let mut builder = SslStreamBuilder::new(ssl, socket); |
522 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
489 builder.set_accept_state(); |
523 let mut builder = SslStreamBuilder::new(ssl, socket); |
490 match builder.handshake() { |
524 builder.set_accept_state(); |
491 Ok(stream) => Ok(ClientSocket::SslStream(stream)), |
525 match builder.handshake() { |
492 Err(HandshakeError::WouldBlock(stream)) => Ok(ClientSocket::SslHandshake(Some(stream))), |
526 Ok(stream) => Ok(ClientSocket::SslStream(stream)), |
493 Err(e) => { |
527 Err(HandshakeError::WouldBlock(stream)) => { |
494 debug!("OpenSSL handshake failed: {}", e); |
528 Ok(ClientSocket::SslHandshake(Some(stream))) |
495 Err(Error::new(ErrorKind::Other, "Connection failure")) |
529 } |
496 } |
530 Err(e) => { |
497 } |
531 debug!("OpenSSL handshake failed: {}", e); |
498 } |
532 Err(Error::new(ErrorKind::Other, "Connection failure")) |
499 |
533 } |
500 pub fn accept_client(&mut self, poll: &Poll, server_token: mio::Token) -> io::Result<()> { |
534 } |
|
535 } |
|
536 } |
|
537 |
|
538 pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { |
|
539 let (client_socket, addr) = self.listener.accept()?; |
501 let (client_socket, addr) = self.listener.accept()?; |
540 info!("Connected: {}", addr); |
502 info!("Connected: {}", addr); |
541 |
503 |
542 let client_id = self.register_client(poll, self.create_client_socket(client_socket)?, addr); |
504 match server_token { |
543 |
505 utils::SERVER_TOKEN => { |
544 let mut response = handlers::Response::new(client_id); |
506 let client_id = |
545 |
507 self.register_client(poll, self.create_client_socket(client_socket)?, addr); |
546 handlers::handle_client_accept(&mut self.server, client_id, &mut response); |
508 let mut response = handlers::Response::new(client_id); |
547 |
509 handlers::handle_client_accept(&mut self.server, client_id, &mut response); |
548 if !response.is_empty() { |
510 self.handle_response(response, poll); |
549 self.handle_response(response, poll); |
511 } |
|
512 #[cfg(feature = "tls-connections")] |
|
513 utils::SECURE_SERVER_TOKEN => { |
|
514 self.register_client(poll, self.create_client_secure_socket(client_socket)?, addr); |
|
515 } |
|
516 _ => unreachable!(), |
550 } |
517 } |
551 |
518 |
552 Ok(()) |
519 Ok(()) |
553 } |
520 } |
554 |
521 |
593 match state { |
560 match state { |
594 NetworkClientState::NeedsRead => { |
561 NetworkClientState::NeedsRead => { |
595 self.pending.insert((client_id, state)); |
562 self.pending.insert((client_id, state)); |
596 } |
563 } |
597 NetworkClientState::Closed => self.client_error(&poll, client_id)?, |
564 NetworkClientState::Closed => self.client_error(&poll, client_id)?, |
|
565 #[cfg(feature = "tls-connections")] |
|
566 NetworkClientState::Connected => { |
|
567 let mut response = handlers::Response::new(client_id); |
|
568 handlers::handle_client_accept(&mut self.server, client_id, &mut response); |
|
569 self.handle_response(response, poll); |
|
570 } |
598 _ => {} |
571 _ => {} |
599 }; |
572 }; |
600 } |
573 } |
601 Err(e) => self.operation_failed( |
574 Err(e) => self.operation_failed( |
602 poll, |
575 poll, |
661 swap(&mut cache, &mut self.pending_cache); |
632 swap(&mut cache, &mut self.pending_cache); |
662 } |
633 } |
663 Ok(()) |
634 Ok(()) |
664 } |
635 } |
665 } |
636 } |
|
637 |
|
638 pub struct NetworkLayerBuilder { |
|
639 listener: Option<TcpListener>, |
|
640 secure_listener: Option<TcpListener>, |
|
641 clients_capacity: usize, |
|
642 rooms_capacity: usize, |
|
643 } |
|
644 |
|
645 impl Default for NetworkLayerBuilder { |
|
646 fn default() -> Self { |
|
647 Self { |
|
648 clients_capacity: 1024, |
|
649 rooms_capacity: 512, |
|
650 listener: None, |
|
651 secure_listener: None, |
|
652 } |
|
653 } |
|
654 } |
|
655 |
|
656 impl NetworkLayerBuilder { |
|
657 pub fn with_listener(self, listener: TcpListener) -> Self { |
|
658 Self { |
|
659 listener: Some(listener), |
|
660 ..self |
|
661 } |
|
662 } |
|
663 |
|
664 pub fn with_secure_listener(self, listener: TcpListener) -> Self { |
|
665 Self { |
|
666 secure_listener: Some(listener), |
|
667 ..self |
|
668 } |
|
669 } |
|
670 |
|
671 pub fn build(self) -> NetworkLayer { |
|
672 let server = HWServer::new(self.clients_capacity, self.rooms_capacity); |
|
673 let clients = Slab::with_capacity(self.clients_capacity); |
|
674 let pending = HashSet::with_capacity(2 * self.clients_capacity); |
|
675 let pending_cache = Vec::with_capacity(2 * self.clients_capacity); |
|
676 let timer = timer::Builder::default().build(); |
|
677 |
|
678 NetworkLayer { |
|
679 listener: self.listener.expect("No listener provided"), |
|
680 server, |
|
681 clients, |
|
682 pending, |
|
683 pending_cache, |
|
684 #[cfg(feature = "tls-connections")] |
|
685 ssl: NetworkLayer::create_ssl_context( |
|
686 self.secure_listener.expect("No secure listener provided"), |
|
687 ), |
|
688 #[cfg(feature = "official-server")] |
|
689 io: IoLayer::new(), |
|
690 timer, |
|
691 } |
|
692 } |
|
693 } |