1 | /* |
2 | * IXSocketServer.cpp |
3 | * Author: Benjamin Sergeant |
4 | * Copyright (c) 2018 Machine Zone, Inc. All rights reserved. |
5 | */ |
6 | |
#include "IXSocketServer.h"

#include "IXNetSystem.h"
#include "IXSelectInterrupt.h"
#include "IXSelectInterruptFactory.h"
#include "IXSetThreadName.h"
#include "IXSocket.h"
#include "IXSocketConnect.h"
#include "IXSocketFactory.h"
#include <assert.h>
#include <chrono>
#include <sstream>
#include <stdio.h>
#include <string.h>
#include <vector>
20 | |
21 | namespace ix |
22 | { |
23 | const int SocketServer::kDefaultPort(8080); |
24 | const std::string SocketServer::kDefaultHost("127.0.0.1" ); |
25 | const int SocketServer::kDefaultTcpBacklog(5); |
26 | const size_t SocketServer::kDefaultMaxConnections(128); |
27 | const int SocketServer::kDefaultAddressFamily(AF_INET); |
28 | |
29 | SocketServer::SocketServer( |
30 | int port, const std::string& host, int backlog, size_t maxConnections, int addressFamily) |
31 | : _port(port) |
32 | , _host(host) |
33 | , _backlog(backlog) |
34 | , _maxConnections(maxConnections) |
35 | , _addressFamily(addressFamily) |
36 | , _serverFd(-1) |
37 | , _stop(false) |
38 | , _stopGc(false) |
39 | , _connectionStateFactory(&ConnectionState::createConnectionState) |
40 | , _acceptSelectInterrupt(createSelectInterrupt()) |
41 | { |
42 | } |
43 | |
44 | SocketServer::~SocketServer() |
45 | { |
46 | stop(); |
47 | } |
48 | |
49 | void SocketServer::logError(const std::string& str) |
50 | { |
51 | std::lock_guard<std::mutex> lock(_logMutex); |
52 | fprintf(stderr, "%s\n" , str.c_str()); |
53 | } |
54 | |
55 | void SocketServer::logInfo(const std::string& str) |
56 | { |
57 | std::lock_guard<std::mutex> lock(_logMutex); |
58 | fprintf(stdout, "%s\n" , str.c_str()); |
59 | } |
60 | |
61 | std::pair<bool, std::string> SocketServer::listen() |
62 | { |
63 | std::string acceptSelectInterruptInitErrorMsg; |
64 | if (!_acceptSelectInterrupt->init(acceptSelectInterruptInitErrorMsg)) |
65 | { |
66 | std::stringstream ss; |
67 | ss << "SocketServer::listen() error in SelectInterrupt::init: " |
68 | << acceptSelectInterruptInitErrorMsg; |
69 | |
70 | return std::make_pair(false, ss.str()); |
71 | } |
72 | |
73 | if (_addressFamily != AF_INET && _addressFamily != AF_INET6) |
74 | { |
75 | std::string errMsg("SocketServer::listen() AF_INET and AF_INET6 are currently " |
76 | "the only supported address families" ); |
77 | return std::make_pair(false, errMsg); |
78 | } |
79 | |
80 | // Get a socket for accepting connections. |
81 | if ((_serverFd = socket(_addressFamily, SOCK_STREAM, 0)) < 0) |
82 | { |
83 | std::stringstream ss; |
84 | ss << "SocketServer::listen() error creating socket): " << strerror(Socket::getErrno()); |
85 | |
86 | return std::make_pair(false, ss.str()); |
87 | } |
88 | |
89 | // Make that socket reusable. (allow restarting this server at will) |
90 | int enable = 1; |
91 | if (setsockopt(_serverFd, SOL_SOCKET, SO_REUSEADDR, (char*) &enable, sizeof(enable)) < 0) |
92 | { |
93 | std::stringstream ss; |
94 | ss << "SocketServer::listen() error calling setsockopt(SO_REUSEADDR) " |
95 | << "at address " << _host << ":" << _port << " : " << strerror(Socket::getErrno()); |
96 | |
97 | Socket::closeSocket(_serverFd); |
98 | return std::make_pair(false, ss.str()); |
99 | } |
100 | |
101 | if (_addressFamily == AF_INET) |
102 | { |
103 | struct sockaddr_in server; |
104 | server.sin_family = _addressFamily; |
105 | server.sin_port = htons(_port); |
106 | |
107 | if (ix::inet_pton(_addressFamily, _host.c_str(), &server.sin_addr.s_addr) <= 0) |
108 | { |
109 | std::stringstream ss; |
110 | ss << "SocketServer::listen() error calling inet_pton " |
111 | << "at address " << _host << ":" << _port << " : " |
112 | << strerror(Socket::getErrno()); |
113 | |
114 | Socket::closeSocket(_serverFd); |
115 | return std::make_pair(false, ss.str()); |
116 | } |
117 | |
118 | // Bind the socket to the server address. |
119 | if (bind(_serverFd, (struct sockaddr*) &server, sizeof(server)) < 0) |
120 | { |
121 | std::stringstream ss; |
122 | ss << "SocketServer::listen() error calling bind " |
123 | << "at address " << _host << ":" << _port << " : " |
124 | << strerror(Socket::getErrno()); |
125 | |
126 | Socket::closeSocket(_serverFd); |
127 | return std::make_pair(false, ss.str()); |
128 | } |
129 | } |
130 | else // AF_INET6 |
131 | { |
132 | struct sockaddr_in6 server; |
133 | server.sin6_family = _addressFamily; |
134 | server.sin6_port = htons(_port); |
135 | |
136 | if (ix::inet_pton(_addressFamily, _host.c_str(), &server.sin6_addr) <= 0) |
137 | { |
138 | std::stringstream ss; |
139 | ss << "SocketServer::listen() error calling inet_pton " |
140 | << "at address " << _host << ":" << _port << " : " |
141 | << strerror(Socket::getErrno()); |
142 | |
143 | Socket::closeSocket(_serverFd); |
144 | return std::make_pair(false, ss.str()); |
145 | } |
146 | |
147 | // Bind the socket to the server address. |
148 | if (bind(_serverFd, (struct sockaddr*) &server, sizeof(server)) < 0) |
149 | { |
150 | std::stringstream ss; |
151 | ss << "SocketServer::listen() error calling bind " |
152 | << "at address " << _host << ":" << _port << " : " |
153 | << strerror(Socket::getErrno()); |
154 | |
155 | Socket::closeSocket(_serverFd); |
156 | return std::make_pair(false, ss.str()); |
157 | } |
158 | } |
159 | |
160 | // |
161 | // Listen for connections. Specify the tcp backlog. |
162 | // |
163 | if (::listen(_serverFd, _backlog) < 0) |
164 | { |
165 | std::stringstream ss; |
166 | ss << "SocketServer::listen() error calling listen " |
167 | << "at address " << _host << ":" << _port << " : " << strerror(Socket::getErrno()); |
168 | |
169 | Socket::closeSocket(_serverFd); |
170 | return std::make_pair(false, ss.str()); |
171 | } |
172 | |
173 | return std::make_pair(true, "" ); |
174 | } |
175 | |
176 | void SocketServer::start() |
177 | { |
178 | _stop = false; |
179 | |
180 | if (!_thread.joinable()) |
181 | { |
182 | _thread = std::thread(&SocketServer::run, this); |
183 | } |
184 | |
185 | if (!_gcThread.joinable()) |
186 | { |
187 | _gcThread = std::thread(&SocketServer::runGC, this); |
188 | } |
189 | } |
190 | |
191 | void SocketServer::wait() |
192 | { |
193 | std::unique_lock<std::mutex> lock(_conditionVariableMutex); |
194 | _conditionVariable.wait(lock); |
195 | } |
196 | |
197 | void SocketServer::stopAcceptingConnections() |
198 | { |
199 | _stop = true; |
200 | } |
201 | |
202 | void SocketServer::stop() |
203 | { |
204 | // Stop accepting connections, and close the 'accept' thread |
205 | if (_thread.joinable()) |
206 | { |
207 | _stop = true; |
208 | // Wake up select |
209 | if (!_acceptSelectInterrupt->notify(SelectInterrupt::kCloseRequest)) |
210 | { |
211 | logError("SocketServer::stop: Cannot wake up from select" ); |
212 | } |
213 | |
214 | _thread.join(); |
215 | _stop = false; |
216 | } |
217 | |
218 | // Join all threads and make sure that all connections are terminated |
219 | if (_gcThread.joinable()) |
220 | { |
221 | _stopGc = true; |
222 | _conditionVariableGC.notify_one(); |
223 | _gcThread.join(); |
224 | _stopGc = false; |
225 | } |
226 | |
227 | _conditionVariable.notify_one(); |
228 | Socket::closeSocket(_serverFd); |
229 | } |
230 | |
231 | void SocketServer::setConnectionStateFactory( |
232 | const ConnectionStateFactory& connectionStateFactory) |
233 | { |
234 | _connectionStateFactory = connectionStateFactory; |
235 | } |
236 | |
237 | // |
238 | // join the threads for connections that have been closed |
239 | // |
240 | // When a connection is closed by a client, the connection state terminated |
241 | // field becomes true, and we can use that to know that we can join that thread |
242 | // and remove it from our _connectionsThreads data structure (a list). |
243 | // |
244 | void SocketServer::closeTerminatedThreads() |
245 | { |
246 | std::lock_guard<std::mutex> lock(_connectionsThreadsMutex); |
247 | auto it = _connectionsThreads.begin(); |
248 | auto itEnd = _connectionsThreads.end(); |
249 | |
250 | while (it != itEnd) |
251 | { |
252 | auto& connectionState = it->first; |
253 | auto& thread = it->second; |
254 | |
255 | if (!connectionState->isTerminated()) |
256 | { |
257 | ++it; |
258 | continue; |
259 | } |
260 | |
261 | if (thread.joinable()) thread.join(); |
262 | it = _connectionsThreads.erase(it); |
263 | } |
264 | } |
265 | |
266 | void SocketServer::run() |
267 | { |
268 | // Set the socket to non blocking mode, so that accept calls are not blocking |
269 | SocketConnect::configure(_serverFd); |
270 | |
271 | setThreadName("SocketServer::accept" ); |
272 | |
273 | for (;;) |
274 | { |
275 | if (_stop) return; |
276 | |
277 | // Use poll to check whether a new connection is in progress |
278 | int timeoutMs = -1; |
279 | #ifdef _WIN32 |
280 | // select cannot be interrupted on Windows so we need to pass a small timeout |
281 | timeoutMs = 10; |
282 | #endif |
283 | |
284 | bool readyToRead = true; |
285 | PollResultType pollResult = |
286 | Socket::poll(readyToRead, timeoutMs, _serverFd, _acceptSelectInterrupt); |
287 | |
288 | if (pollResult == PollResultType::Error) |
289 | { |
290 | std::stringstream ss; |
291 | ss << "SocketServer::run() error in select: " << strerror(Socket::getErrno()); |
292 | logError(ss.str()); |
293 | continue; |
294 | } |
295 | |
296 | if (pollResult != PollResultType::ReadyForRead) |
297 | { |
298 | continue; |
299 | } |
300 | |
301 | // Accept a connection. |
302 | // FIXME: Is this working for ipv6 ? |
303 | struct sockaddr_in client; // client address information |
304 | int clientFd; // socket connected to client |
305 | socklen_t addressLen = sizeof(client); |
306 | memset(&client, 0, sizeof(client)); |
307 | |
308 | if ((clientFd = accept(_serverFd, (struct sockaddr*) &client, &addressLen)) < 0) |
309 | { |
310 | if (!Socket::isWaitNeeded()) |
311 | { |
312 | // FIXME: that error should be propagated |
313 | int err = Socket::getErrno(); |
314 | std::stringstream ss; |
315 | ss << "SocketServer::run() error accepting connection: " << err << ", " |
316 | << strerror(err); |
317 | logError(ss.str()); |
318 | } |
319 | continue; |
320 | } |
321 | |
322 | if (getConnectedClientsCount() >= _maxConnections) |
323 | { |
324 | std::stringstream ss; |
325 | ss << "SocketServer::run() reached max connections = " << _maxConnections << ". " |
326 | << "Not accepting connection" ; |
327 | logError(ss.str()); |
328 | |
329 | Socket::closeSocket(clientFd); |
330 | |
331 | continue; |
332 | } |
333 | |
334 | // Retrieve connection info, the ip address of the remote peer/client) |
335 | std::string remoteIp; |
336 | int remotePort; |
337 | |
338 | if (_addressFamily == AF_INET) |
339 | { |
340 | char remoteIp4[INET_ADDRSTRLEN]; |
341 | if (ix::inet_ntop(AF_INET, &client.sin_addr, remoteIp4, INET_ADDRSTRLEN) == nullptr) |
342 | { |
343 | int err = Socket::getErrno(); |
344 | std::stringstream ss; |
345 | ss << "SocketServer::run() error calling inet_ntop (ipv4): " << err << ", " |
346 | << strerror(err); |
347 | logError(ss.str()); |
348 | |
349 | Socket::closeSocket(clientFd); |
350 | |
351 | continue; |
352 | } |
353 | |
354 | remotePort = ix::network_to_host_short(client.sin_port); |
355 | remoteIp = remoteIp4; |
356 | } |
357 | else // AF_INET6 |
358 | { |
359 | char remoteIp6[INET6_ADDRSTRLEN]; |
360 | if (ix::inet_ntop(AF_INET6, &client.sin_addr, remoteIp6, INET6_ADDRSTRLEN) == |
361 | nullptr) |
362 | { |
363 | int err = Socket::getErrno(); |
364 | std::stringstream ss; |
365 | ss << "SocketServer::run() error calling inet_ntop (ipv6): " << err << ", " |
366 | << strerror(err); |
367 | logError(ss.str()); |
368 | |
369 | Socket::closeSocket(clientFd); |
370 | |
371 | continue; |
372 | } |
373 | |
374 | remotePort = ix::network_to_host_short(client.sin_port); |
375 | remoteIp = remoteIp6; |
376 | } |
377 | |
378 | std::shared_ptr<ConnectionState> connectionState; |
379 | if (_connectionStateFactory) |
380 | { |
381 | connectionState = _connectionStateFactory(); |
382 | } |
383 | connectionState->setOnSetTerminatedCallback([this] { onSetTerminatedCallback(); }); |
384 | connectionState->setRemoteIp(remoteIp); |
385 | connectionState->setRemotePort(remotePort); |
386 | |
387 | if (_stop) return; |
388 | |
389 | // create socket |
390 | std::string errorMsg; |
391 | bool tls = _socketTLSOptions.tls; |
392 | auto socket = createSocket(tls, clientFd, errorMsg, _socketTLSOptions); |
393 | |
394 | if (socket == nullptr) |
395 | { |
396 | logError("SocketServer::run() cannot create socket: " + errorMsg); |
397 | Socket::closeSocket(clientFd); |
398 | continue; |
399 | } |
400 | |
401 | // Set the socket to non blocking mode + other tweaks |
402 | SocketConnect::configure(clientFd); |
403 | |
404 | if (!socket->accept(errorMsg)) |
405 | { |
406 | logError("SocketServer::run() tls accept failed: " + errorMsg); |
407 | Socket::closeSocket(clientFd); |
408 | continue; |
409 | } |
410 | |
411 | // Launch the handleConnection work asynchronously in its own thread. |
412 | std::lock_guard<std::mutex> lock(_connectionsThreadsMutex); |
413 | _connectionsThreads.push_back(std::make_pair( |
414 | connectionState, |
415 | std::thread( |
416 | &SocketServer::handleConnection, this, std::move(socket), connectionState))); |
417 | } |
418 | } |
419 | |
420 | size_t SocketServer::getConnectionsThreadsCount() |
421 | { |
422 | std::lock_guard<std::mutex> lock(_connectionsThreadsMutex); |
423 | return _connectionsThreads.size(); |
424 | } |
425 | |
426 | void SocketServer::runGC() |
427 | { |
428 | setThreadName("SocketServer::GC" ); |
429 | |
430 | for (;;) |
431 | { |
432 | // Garbage collection to shutdown/join threads for closed connections. |
433 | closeTerminatedThreads(); |
434 | |
435 | // We quit this thread if all connections are closed and we received |
436 | // a stop request by setting _stopGc to true. |
437 | if (_stopGc && getConnectionsThreadsCount() == 0) |
438 | { |
439 | break; |
440 | } |
441 | |
442 | // Unless we are stopping the server, wait for a connection |
443 | // to be terminated to run the threads GC, instead of busy waiting |
444 | // with a sleep |
445 | if (!_stopGc) |
446 | { |
447 | std::unique_lock<std::mutex> lock(_conditionVariableMutexGC); |
448 | _conditionVariableGC.wait(lock); |
449 | } |
450 | } |
451 | } |
452 | |
453 | void SocketServer::setTLSOptions(const SocketTLSOptions& socketTLSOptions) |
454 | { |
455 | _socketTLSOptions = socketTLSOptions; |
456 | } |
457 | |
458 | void SocketServer::onSetTerminatedCallback() |
459 | { |
460 | // a connection got terminated, we can run the connection thread GC, |
461 | // so wake up the thread responsible for that |
462 | _conditionVariableGC.notify_one(); |
463 | } |
464 | |
465 | int SocketServer::getPort() |
466 | { |
467 | return _port; |
468 | } |
469 | |
470 | std::string SocketServer::getHost() |
471 | { |
472 | return _host; |
473 | } |
474 | |
475 | int SocketServer::getBacklog() |
476 | { |
477 | return _backlog; |
478 | } |
479 | |
480 | std::size_t SocketServer::getMaxConnections() |
481 | { |
482 | return _maxConnections; |
483 | } |
484 | |
485 | int SocketServer::getAddressFamily() |
486 | { |
487 | return _addressFamily; |
488 | } |
489 | } // namespace ix |
490 | |