33 Closed, |
42 Closed, |
34 } |
43 } |
35 |
44 |
36 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
45 type NetworkResult<T> = io::Result<(T, NetworkClientState)>; |
37 |
46 |
|
47 #[cfg(not(feature = "tls-connections"))] |
|
48 pub enum ClientSocket { |
|
49 Plain(TcpStream) |
|
50 } |
|
51 |
|
52 #[cfg(feature = "tls-connections")] |
|
53 pub enum ClientSocket { |
|
54 SslHandshake(Option<MidHandshakeSslStream<TcpStream>>), |
|
55 SslStream(SslStream<TcpStream>) |
|
56 } |
|
57 |
|
58 impl ClientSocket { |
|
59 fn inner(&self) -> &TcpStream { |
|
60 #[cfg(not(feature = "tls-connections"))] |
|
61 match self { |
|
62 ClientSocket::Plain(stream) => stream, |
|
63 } |
|
64 |
|
65 #[cfg(feature = "tls-connections")] |
|
66 match self { |
|
67 ClientSocket::SslHandshake(Some(builder)) => builder.get_ref(), |
|
68 ClientSocket::SslHandshake(None) => unreachable!(), |
|
69 ClientSocket::SslStream(ssl_stream) => ssl_stream.get_ref() |
|
70 } |
|
71 } |
|
72 } |
|
73 |
38 pub struct NetworkClient { |
74 pub struct NetworkClient { |
39 id: ClientId, |
75 id: ClientId, |
40 socket: TcpStream, |
76 socket: ClientSocket, |
41 peer_addr: SocketAddr, |
77 peer_addr: SocketAddr, |
42 decoder: ProtocolDecoder, |
78 decoder: ProtocolDecoder, |
43 buf_out: netbuf::Buf |
79 buf_out: netbuf::Buf |
44 } |
80 } |
45 |
81 |
46 impl NetworkClient { |
82 impl NetworkClient { |
47 pub fn new(id: ClientId, socket: TcpStream, peer_addr: SocketAddr) -> NetworkClient { |
83 pub fn new(id: ClientId, socket: ClientSocket, peer_addr: SocketAddr) -> NetworkClient { |
48 NetworkClient { |
84 NetworkClient { |
49 id, socket, peer_addr, |
85 id, socket, peer_addr, |
50 decoder: ProtocolDecoder::new(), |
86 decoder: ProtocolDecoder::new(), |
51 buf_out: netbuf::Buf::new() |
87 buf_out: netbuf::Buf::new() |
52 } |
88 } |
53 } |
89 } |
54 |
90 |
55 pub fn read_messages(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
91 fn read_impl<R: Read>(decoder: &mut ProtocolDecoder, source: &mut R, |
|
92 id: ClientId, addr: &SocketAddr) -> NetworkResult<Vec<HWProtocolMessage>> { |
56 let mut bytes_read = 0; |
93 let mut bytes_read = 0; |
57 let result = loop { |
94 let result = loop { |
58 match self.decoder.read_from(&mut self.socket) { |
95 match decoder.read_from(source) { |
59 Ok(bytes) => { |
96 Ok(bytes) => { |
60 debug!("Client {}: read {} bytes", self.id, bytes); |
97 debug!("Client {}: read {} bytes", id, bytes); |
61 bytes_read += bytes; |
98 bytes_read += bytes; |
62 if bytes == 0 { |
99 if bytes == 0 { |
63 let result = if bytes_read == 0 { |
100 let result = if bytes_read == 0 { |
64 info!("EOF for client {} ({})", self.id, self.peer_addr); |
101 info!("EOF for client {} ({})", id, addr); |
65 (Vec::new(), NetworkClientState::Closed) |
102 (Vec::new(), NetworkClientState::Closed) |
66 } else { |
103 } else { |
67 (self.decoder.extract_messages(), NetworkClientState::NeedsRead) |
104 (decoder.extract_messages(), NetworkClientState::NeedsRead) |
68 }; |
105 }; |
69 break Ok(result); |
106 break Ok(result); |
70 } |
107 } |
71 else if bytes_read >= MAX_BYTES_PER_READ { |
108 else if bytes_read >= MAX_BYTES_PER_READ { |
72 break Ok((self.decoder.extract_messages(), NetworkClientState::NeedsRead)) |
109 break Ok((decoder.extract_messages(), NetworkClientState::NeedsRead)) |
73 } |
110 } |
74 } |
111 } |
75 Err(ref error) if error.kind() == ErrorKind::WouldBlock => { |
112 Err(ref error) if error.kind() == ErrorKind::WouldBlock => { |
76 let messages = if bytes_read == 0 { |
113 let messages = if bytes_read == 0 { |
77 Vec::new() |
114 Vec::new() |
78 } else { |
115 } else { |
79 self.decoder.extract_messages() |
116 decoder.extract_messages() |
80 }; |
117 }; |
81 break Ok((messages, NetworkClientState::Idle)); |
118 break Ok((messages, NetworkClientState::Idle)); |
82 } |
119 } |
83 Err(error) => |
120 Err(error) => |
84 break Err(error) |
121 break Err(error) |
85 } |
122 } |
86 }; |
123 }; |
87 self.decoder.sweep(); |
124 decoder.sweep(); |
88 result |
125 result |
89 } |
126 } |
90 |
127 |
91 pub fn flush(&mut self) -> NetworkResult<()> { |
128 pub fn read(&mut self) -> NetworkResult<Vec<HWProtocolMessage>> { |
|
129 #[cfg(not(feature = "tls-connections"))] |
|
130 match self.socket { |
|
131 ClientSocket::Plain(ref mut stream) => |
|
132 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr), |
|
133 } |
|
134 |
|
135 #[cfg(feature = "tls-connections")] |
|
136 match self.socket { |
|
137 ClientSocket::SslHandshake(ref mut handshake_opt) => { |
|
138 let mut handshake = std::mem::replace(handshake_opt, None).unwrap(); |
|
139 |
|
140 match handshake.handshake() { |
|
141 Ok(stream) => { |
|
142 debug!("TLS handshake with {} ({}) completed", self.id, self.peer_addr); |
|
143 self.socket = ClientSocket::SslStream(stream); |
|
144 |
|
145 Ok((Vec::new(), NetworkClientState::Idle)) |
|
146 } |
|
147 Err(HandshakeError::WouldBlock(new_handshake)) => { |
|
148 *handshake_opt = Some(new_handshake); |
|
149 Ok((Vec::new(), NetworkClientState::Idle)) |
|
150 } |
|
151 Err(e) => { |
|
152 debug!("TLS handshake with {} ({}) failed", self.id, self.peer_addr); |
|
153 Err(Error::new(ErrorKind::Other, "Connection failure")) |
|
154 } |
|
155 } |
|
156 }, |
|
157 ClientSocket::SslStream(ref mut stream) => |
|
158 NetworkClient::read_impl(&mut self.decoder, stream, self.id, &self.peer_addr) |
|
159 } |
|
160 } |
|
161 |
|
162 fn write_impl<W: Write>(buf_out: &mut netbuf::Buf, destination: &mut W) -> NetworkResult<()> { |
92 let result = loop { |
163 let result = loop { |
93 match self.buf_out.write_to(&mut self.socket) { |
164 match buf_out.write_to(destination) { |
94 Ok(bytes) if self.buf_out.is_empty() || bytes == 0 => |
165 Ok(bytes) if buf_out.is_empty() || bytes == 0 => |
95 break Ok(((), NetworkClientState::Idle)), |
166 break Ok(((), NetworkClientState::Idle)), |
96 Ok(_) => (), |
167 Ok(_) => (), |
97 Err(ref error) if error.kind() == ErrorKind::Interrupted |
168 Err(ref error) if error.kind() == ErrorKind::Interrupted |
98 || error.kind() == ErrorKind::WouldBlock => { |
169 || error.kind() == ErrorKind::WouldBlock => { |
99 break Ok(((), NetworkClientState::NeedsWrite)); |
170 break Ok(((), NetworkClientState::NeedsWrite)); |
100 }, |
171 }, |
101 Err(error) => |
172 Err(error) => |
102 break Err(error) |
173 break Err(error) |
103 } |
174 } |
104 }; |
175 }; |
105 self.socket.flush()?; |
176 result |
|
177 } |
|
178 |
|
179 pub fn write(&mut self) -> NetworkResult<()> { |
|
180 let result = { |
|
181 #[cfg(not(feature = "tls-connections"))] |
|
182 match self.socket { |
|
183 ClientSocket::Plain(ref mut stream) => |
|
184 NetworkClient::write_impl(&mut self.buf_out, stream) |
|
185 } |
|
186 |
|
187 #[cfg(feature = "tls-connections")] { |
|
188 match self.socket { |
|
189 ClientSocket::SslHandshake(_) => |
|
190 Ok(((), NetworkClientState::Idle)), |
|
191 ClientSocket::SslStream(ref mut stream) => |
|
192 NetworkClient::write_impl(&mut self.buf_out, stream) |
|
193 } |
|
194 } |
|
195 }; |
|
196 |
|
197 self.socket.inner().flush()?; |
106 result |
198 result |
107 } |
199 } |
108 |
200 |
109 pub fn send_raw_msg(&mut self, msg: &[u8]) { |
201 pub fn send_raw_msg(&mut self, msg: &[u8]) { |
110 self.buf_out.write_all(msg).unwrap(); |
202 self.buf_out.write_all(msg).unwrap(); |
115 } |
207 } |
116 |
208 |
117 pub fn send_msg(&mut self, msg: &HWServerMessage) { |
209 pub fn send_msg(&mut self, msg: &HWServerMessage) { |
118 self.send_string(&msg.to_raw_protocol()); |
210 self.send_string(&msg.to_raw_protocol()); |
119 } |
211 } |
|
212 } |
|
213 |
|
/// TLS configuration shared by all client connections.
#[cfg(feature = "tls-connections")]
struct ServerSsl {
    /// Context from which a new `Ssl` session is created for every accepted connection.
    context: SslContext,
}
121 |
218 |
122 pub struct NetworkLayer { |
219 pub struct NetworkLayer { |
123 listener: TcpListener, |
220 listener: TcpListener, |
124 server: HWServer, |
221 server: HWServer, |
125 clients: Slab<NetworkClient>, |
222 clients: Slab<NetworkClient>, |
126 pending: HashSet<(ClientId, NetworkClientState)>, |
223 pending: HashSet<(ClientId, NetworkClientState)>, |
127 pending_cache: Vec<(ClientId, NetworkClientState)> |
224 pending_cache: Vec<(ClientId, NetworkClientState)>, |
|
225 #[cfg(feature = "tls-connections")] |
|
226 ssl: ServerSsl |
128 } |
227 } |
129 |
228 |
130 impl NetworkLayer { |
229 impl NetworkLayer { |
131 pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { |
230 pub fn new(listener: TcpListener, clients_limit: usize, rooms_limit: usize) -> NetworkLayer { |
132 let server = HWServer::new(clients_limit, rooms_limit); |
231 let server = HWServer::new(clients_limit, rooms_limit); |
133 let clients = Slab::with_capacity(clients_limit); |
232 let clients = Slab::with_capacity(clients_limit); |
134 let pending = HashSet::with_capacity(2 * clients_limit); |
233 let pending = HashSet::with_capacity(2 * clients_limit); |
135 let pending_cache = Vec::with_capacity(2 * clients_limit); |
234 let pending_cache = Vec::with_capacity(2 * clients_limit); |
136 NetworkLayer {listener, server, clients, pending, pending_cache} |
235 |
|
236 NetworkLayer { |
|
237 listener, server, clients, pending, pending_cache, |
|
238 #[cfg(feature = "tls-connections")] |
|
239 ssl: NetworkLayer::create_ssl_context() |
|
240 } |
|
241 } |
|
242 |
|
/// Builds the server-side TLS context from `ssl/cert.pem` and `ssl/key.pem`
/// (relative to the working directory).
///
/// Peer certificates are not verified, read-ahead is enabled, TLS compression
/// is disabled, and LOW/RC4/EXP cipher classes are excluded from the defaults.
///
/// # Panics
/// Panics with a message naming the failing step (and file, where relevant) if
/// OpenSSL setup fails or the certificate or private key cannot be loaded.
#[cfg(feature = "tls-connections")]
fn create_ssl_context() -> ServerSsl {
    const CERT_FILE: &str = "ssl/cert.pem";
    const KEY_FILE: &str = "ssl/key.pem";

    let mut builder = SslContextBuilder::new(SslMethod::tls())
        .expect("could not create TLS context builder");
    builder.set_verify(SslVerifyMode::NONE);
    builder.set_read_ahead(true);
    builder.set_certificate_file(CERT_FILE, SslFiletype::PEM)
        .unwrap_or_else(|e| panic!("could not load TLS certificate from {}: {}", CERT_FILE, e));
    builder.set_private_key_file(KEY_FILE, SslFiletype::PEM)
        .unwrap_or_else(|e| panic!("could not load TLS private key from {}: {}", KEY_FILE, e));
    builder.set_options(SslOptions::NO_COMPRESSION);
    builder.set_cipher_list("DEFAULT:!LOW:!RC4:!EXP")
        .expect("invalid TLS cipher list");
    ServerSsl { context: builder.build() }
}
138 |
254 |
139 pub fn register_server(&self, poll: &Poll) -> io::Result<()> { |
255 pub fn register_server(&self, poll: &Poll) -> io::Result<()> { |
140 poll.register(&self.listener, utils::SERVER, Ready::readable(), |
256 poll.register(&self.listener, utils::SERVER, Ready::readable(), |
141 PollOpt::edge()) |
257 PollOpt::edge()) |
142 } |
258 } |
143 |
259 |
144 fn deregister_client(&mut self, poll: &Poll, id: ClientId) { |
260 fn deregister_client(&mut self, poll: &Poll, id: ClientId) { |
145 let mut client_exists = false; |
261 let mut client_exists = false; |
146 if let Some(ref client) = self.clients.get(id) { |
262 if let Some(ref client) = self.clients.get(id) { |
147 poll.deregister(&client.socket) |
263 poll.deregister(client.socket.inner()) |
148 .expect("could not deregister socket"); |
264 .expect("could not deregister socket"); |
149 info!("client {} ({}) removed", client.id, client.peer_addr); |
265 info!("client {} ({}) removed", client.id, client.peer_addr); |
150 client_exists = true; |
266 client_exists = true; |
151 } |
267 } |
152 if client_exists { |
268 if client_exists { |
153 self.clients.remove(id); |
269 self.clients.remove(id); |
154 } |
270 } |
155 } |
271 } |
156 |
272 |
157 fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: TcpStream, addr: SocketAddr) { |
273 fn register_client(&mut self, poll: &Poll, id: ClientId, client_socket: ClientSocket, addr: SocketAddr) { |
158 poll.register(&client_socket, Token(id), |
274 poll.register(client_socket.inner(), Token(id), |
159 Ready::readable() | Ready::writable(), |
275 Ready::readable() | Ready::writable(), |
160 PollOpt::edge()) |
276 PollOpt::edge()) |
161 .expect("could not register socket with event loop"); |
277 .expect("could not register socket with event loop"); |
162 |
278 |
163 let entry = self.clients.vacant_entry(); |
279 let entry = self.clients.vacant_entry(); |
178 } |
294 } |
179 } |
295 } |
180 } |
296 } |
181 } |
297 } |
182 |
298 |
|
299 fn create_client_socket(&self, socket: TcpStream) -> io::Result<ClientSocket> { |
|
300 #[cfg(not(feature = "tls-connections"))] { |
|
301 Ok(ClientSocket::Plain(socket)) |
|
302 } |
|
303 |
|
304 #[cfg(feature = "tls-connections")] { |
|
305 let ssl = Ssl::new(&self.ssl.context).unwrap(); |
|
306 let mut builder = SslStreamBuilder::new(ssl, socket); |
|
307 builder.set_accept_state(); |
|
308 match builder.handshake() { |
|
309 Ok(stream) => |
|
310 Ok(ClientSocket::SslStream(stream)), |
|
311 Err(HandshakeError::WouldBlock(stream)) => |
|
312 Ok(ClientSocket::SslHandshake(Some(stream))), |
|
313 Err(e) => { |
|
314 debug!("OpenSSL handshake failed: {}", e); |
|
315 Err(Error::new(ErrorKind::Other, "Connection failure")) |
|
316 } |
|
317 } |
|
318 } |
|
319 } |
|
320 |
183 pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { |
321 pub fn accept_client(&mut self, poll: &Poll) -> io::Result<()> { |
184 let (client_socket, addr) = self.listener.accept()?; |
322 let (client_socket, addr) = self.listener.accept()?; |
185 info!("Connected: {}", addr); |
323 info!("Connected: {}", addr); |
186 |
324 |
187 let client_id = self.server.add_client(); |
325 let client_id = self.server.add_client(); |
188 self.register_client(poll, client_id, client_socket, addr); |
326 self.register_client(poll, client_id, self.create_client_socket(client_socket)?, addr); |
189 self.flush_server_messages(); |
327 self.flush_server_messages(); |
190 |
328 |
191 Ok(()) |
329 Ok(()) |
192 } |
330 } |
193 |
331 |