1/*
2 * IXSocketServer.cpp
3 * Author: Benjamin Sergeant
4 * Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
5 */
6
#include "IXSocketServer.h"

#include "IXNetSystem.h"
#include "IXSelectInterrupt.h"
#include "IXSelectInterruptFactory.h"
#include "IXSetThreadName.h"
#include "IXSocket.h"
#include "IXSocketConnect.h"
#include "IXSocketFactory.h"
#include <assert.h>
#include <chrono>
#include <sstream>
#include <stdio.h>
#include <string.h>
20
21namespace ix
22{
23 const int SocketServer::kDefaultPort(8080);
24 const std::string SocketServer::kDefaultHost("127.0.0.1");
25 const int SocketServer::kDefaultTcpBacklog(5);
26 const size_t SocketServer::kDefaultMaxConnections(128);
27 const int SocketServer::kDefaultAddressFamily(AF_INET);
28
29 SocketServer::SocketServer(
30 int port, const std::string& host, int backlog, size_t maxConnections, int addressFamily)
31 : _port(port)
32 , _host(host)
33 , _backlog(backlog)
34 , _maxConnections(maxConnections)
35 , _addressFamily(addressFamily)
36 , _serverFd(-1)
37 , _stop(false)
38 , _stopGc(false)
39 , _connectionStateFactory(&ConnectionState::createConnectionState)
40 , _acceptSelectInterrupt(createSelectInterrupt())
41 {
42 }
43
44 SocketServer::~SocketServer()
45 {
46 stop();
47 }
48
49 void SocketServer::logError(const std::string& str)
50 {
51 std::lock_guard<std::mutex> lock(_logMutex);
52 fprintf(stderr, "%s\n", str.c_str());
53 }
54
55 void SocketServer::logInfo(const std::string& str)
56 {
57 std::lock_guard<std::mutex> lock(_logMutex);
58 fprintf(stdout, "%s\n", str.c_str());
59 }
60
61 std::pair<bool, std::string> SocketServer::listen()
62 {
63 std::string acceptSelectInterruptInitErrorMsg;
64 if (!_acceptSelectInterrupt->init(acceptSelectInterruptInitErrorMsg))
65 {
66 std::stringstream ss;
67 ss << "SocketServer::listen() error in SelectInterrupt::init: "
68 << acceptSelectInterruptInitErrorMsg;
69
70 return std::make_pair(false, ss.str());
71 }
72
73 if (_addressFamily != AF_INET && _addressFamily != AF_INET6)
74 {
75 std::string errMsg("SocketServer::listen() AF_INET and AF_INET6 are currently "
76 "the only supported address families");
77 return std::make_pair(false, errMsg);
78 }
79
80 // Get a socket for accepting connections.
81 if ((_serverFd = socket(_addressFamily, SOCK_STREAM, 0)) < 0)
82 {
83 std::stringstream ss;
84 ss << "SocketServer::listen() error creating socket): " << strerror(Socket::getErrno());
85
86 return std::make_pair(false, ss.str());
87 }
88
89 // Make that socket reusable. (allow restarting this server at will)
90 int enable = 1;
91 if (setsockopt(_serverFd, SOL_SOCKET, SO_REUSEADDR, (char*) &enable, sizeof(enable)) < 0)
92 {
93 std::stringstream ss;
94 ss << "SocketServer::listen() error calling setsockopt(SO_REUSEADDR) "
95 << "at address " << _host << ":" << _port << " : " << strerror(Socket::getErrno());
96
97 Socket::closeSocket(_serverFd);
98 return std::make_pair(false, ss.str());
99 }
100
101 if (_addressFamily == AF_INET)
102 {
103 struct sockaddr_in server;
104 server.sin_family = _addressFamily;
105 server.sin_port = htons(_port);
106
107 if (ix::inet_pton(_addressFamily, _host.c_str(), &server.sin_addr.s_addr) <= 0)
108 {
109 std::stringstream ss;
110 ss << "SocketServer::listen() error calling inet_pton "
111 << "at address " << _host << ":" << _port << " : "
112 << strerror(Socket::getErrno());
113
114 Socket::closeSocket(_serverFd);
115 return std::make_pair(false, ss.str());
116 }
117
118 // Bind the socket to the server address.
119 if (bind(_serverFd, (struct sockaddr*) &server, sizeof(server)) < 0)
120 {
121 std::stringstream ss;
122 ss << "SocketServer::listen() error calling bind "
123 << "at address " << _host << ":" << _port << " : "
124 << strerror(Socket::getErrno());
125
126 Socket::closeSocket(_serverFd);
127 return std::make_pair(false, ss.str());
128 }
129 }
130 else // AF_INET6
131 {
132 struct sockaddr_in6 server;
133 server.sin6_family = _addressFamily;
134 server.sin6_port = htons(_port);
135
136 if (ix::inet_pton(_addressFamily, _host.c_str(), &server.sin6_addr) <= 0)
137 {
138 std::stringstream ss;
139 ss << "SocketServer::listen() error calling inet_pton "
140 << "at address " << _host << ":" << _port << " : "
141 << strerror(Socket::getErrno());
142
143 Socket::closeSocket(_serverFd);
144 return std::make_pair(false, ss.str());
145 }
146
147 // Bind the socket to the server address.
148 if (bind(_serverFd, (struct sockaddr*) &server, sizeof(server)) < 0)
149 {
150 std::stringstream ss;
151 ss << "SocketServer::listen() error calling bind "
152 << "at address " << _host << ":" << _port << " : "
153 << strerror(Socket::getErrno());
154
155 Socket::closeSocket(_serverFd);
156 return std::make_pair(false, ss.str());
157 }
158 }
159
160 //
161 // Listen for connections. Specify the tcp backlog.
162 //
163 if (::listen(_serverFd, _backlog) < 0)
164 {
165 std::stringstream ss;
166 ss << "SocketServer::listen() error calling listen "
167 << "at address " << _host << ":" << _port << " : " << strerror(Socket::getErrno());
168
169 Socket::closeSocket(_serverFd);
170 return std::make_pair(false, ss.str());
171 }
172
173 return std::make_pair(true, "");
174 }
175
176 void SocketServer::start()
177 {
178 _stop = false;
179
180 if (!_thread.joinable())
181 {
182 _thread = std::thread(&SocketServer::run, this);
183 }
184
185 if (!_gcThread.joinable())
186 {
187 _gcThread = std::thread(&SocketServer::runGC, this);
188 }
189 }
190
191 void SocketServer::wait()
192 {
193 std::unique_lock<std::mutex> lock(_conditionVariableMutex);
194 _conditionVariable.wait(lock);
195 }
196
197 void SocketServer::stopAcceptingConnections()
198 {
199 _stop = true;
200 }
201
202 void SocketServer::stop()
203 {
204 // Stop accepting connections, and close the 'accept' thread
205 if (_thread.joinable())
206 {
207 _stop = true;
208 // Wake up select
209 if (!_acceptSelectInterrupt->notify(SelectInterrupt::kCloseRequest))
210 {
211 logError("SocketServer::stop: Cannot wake up from select");
212 }
213
214 _thread.join();
215 _stop = false;
216 }
217
218 // Join all threads and make sure that all connections are terminated
219 if (_gcThread.joinable())
220 {
221 _stopGc = true;
222 _conditionVariableGC.notify_one();
223 _gcThread.join();
224 _stopGc = false;
225 }
226
227 _conditionVariable.notify_one();
228 Socket::closeSocket(_serverFd);
229 }
230
231 void SocketServer::setConnectionStateFactory(
232 const ConnectionStateFactory& connectionStateFactory)
233 {
234 _connectionStateFactory = connectionStateFactory;
235 }
236
237 //
238 // join the threads for connections that have been closed
239 //
240 // When a connection is closed by a client, the connection state terminated
241 // field becomes true, and we can use that to know that we can join that thread
242 // and remove it from our _connectionsThreads data structure (a list).
243 //
244 void SocketServer::closeTerminatedThreads()
245 {
246 std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
247 auto it = _connectionsThreads.begin();
248 auto itEnd = _connectionsThreads.end();
249
250 while (it != itEnd)
251 {
252 auto& connectionState = it->first;
253 auto& thread = it->second;
254
255 if (!connectionState->isTerminated())
256 {
257 ++it;
258 continue;
259 }
260
261 if (thread.joinable()) thread.join();
262 it = _connectionsThreads.erase(it);
263 }
264 }
265
266 void SocketServer::run()
267 {
268 // Set the socket to non blocking mode, so that accept calls are not blocking
269 SocketConnect::configure(_serverFd);
270
271 setThreadName("SocketServer::accept");
272
273 for (;;)
274 {
275 if (_stop) return;
276
277 // Use poll to check whether a new connection is in progress
278 int timeoutMs = -1;
279#ifdef _WIN32
280 // select cannot be interrupted on Windows so we need to pass a small timeout
281 timeoutMs = 10;
282#endif
283
284 bool readyToRead = true;
285 PollResultType pollResult =
286 Socket::poll(readyToRead, timeoutMs, _serverFd, _acceptSelectInterrupt);
287
288 if (pollResult == PollResultType::Error)
289 {
290 std::stringstream ss;
291 ss << "SocketServer::run() error in select: " << strerror(Socket::getErrno());
292 logError(ss.str());
293 continue;
294 }
295
296 if (pollResult != PollResultType::ReadyForRead)
297 {
298 continue;
299 }
300
301 // Accept a connection.
302 // FIXME: Is this working for ipv6 ?
303 struct sockaddr_in client; // client address information
304 int clientFd; // socket connected to client
305 socklen_t addressLen = sizeof(client);
306 memset(&client, 0, sizeof(client));
307
308 if ((clientFd = accept(_serverFd, (struct sockaddr*) &client, &addressLen)) < 0)
309 {
310 if (!Socket::isWaitNeeded())
311 {
312 // FIXME: that error should be propagated
313 int err = Socket::getErrno();
314 std::stringstream ss;
315 ss << "SocketServer::run() error accepting connection: " << err << ", "
316 << strerror(err);
317 logError(ss.str());
318 }
319 continue;
320 }
321
322 if (getConnectedClientsCount() >= _maxConnections)
323 {
324 std::stringstream ss;
325 ss << "SocketServer::run() reached max connections = " << _maxConnections << ". "
326 << "Not accepting connection";
327 logError(ss.str());
328
329 Socket::closeSocket(clientFd);
330
331 continue;
332 }
333
334 // Retrieve connection info, the ip address of the remote peer/client)
335 std::string remoteIp;
336 int remotePort;
337
338 if (_addressFamily == AF_INET)
339 {
340 char remoteIp4[INET_ADDRSTRLEN];
341 if (ix::inet_ntop(AF_INET, &client.sin_addr, remoteIp4, INET_ADDRSTRLEN) == nullptr)
342 {
343 int err = Socket::getErrno();
344 std::stringstream ss;
345 ss << "SocketServer::run() error calling inet_ntop (ipv4): " << err << ", "
346 << strerror(err);
347 logError(ss.str());
348
349 Socket::closeSocket(clientFd);
350
351 continue;
352 }
353
354 remotePort = ix::network_to_host_short(client.sin_port);
355 remoteIp = remoteIp4;
356 }
357 else // AF_INET6
358 {
359 char remoteIp6[INET6_ADDRSTRLEN];
360 if (ix::inet_ntop(AF_INET6, &client.sin_addr, remoteIp6, INET6_ADDRSTRLEN) ==
361 nullptr)
362 {
363 int err = Socket::getErrno();
364 std::stringstream ss;
365 ss << "SocketServer::run() error calling inet_ntop (ipv6): " << err << ", "
366 << strerror(err);
367 logError(ss.str());
368
369 Socket::closeSocket(clientFd);
370
371 continue;
372 }
373
374 remotePort = ix::network_to_host_short(client.sin_port);
375 remoteIp = remoteIp6;
376 }
377
378 std::shared_ptr<ConnectionState> connectionState;
379 if (_connectionStateFactory)
380 {
381 connectionState = _connectionStateFactory();
382 }
383 connectionState->setOnSetTerminatedCallback([this] { onSetTerminatedCallback(); });
384 connectionState->setRemoteIp(remoteIp);
385 connectionState->setRemotePort(remotePort);
386
387 if (_stop) return;
388
389 // create socket
390 std::string errorMsg;
391 bool tls = _socketTLSOptions.tls;
392 auto socket = createSocket(tls, clientFd, errorMsg, _socketTLSOptions);
393
394 if (socket == nullptr)
395 {
396 logError("SocketServer::run() cannot create socket: " + errorMsg);
397 Socket::closeSocket(clientFd);
398 continue;
399 }
400
401 // Set the socket to non blocking mode + other tweaks
402 SocketConnect::configure(clientFd);
403
404 if (!socket->accept(errorMsg))
405 {
406 logError("SocketServer::run() tls accept failed: " + errorMsg);
407 Socket::closeSocket(clientFd);
408 continue;
409 }
410
411 // Launch the handleConnection work asynchronously in its own thread.
412 std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
413 _connectionsThreads.push_back(std::make_pair(
414 connectionState,
415 std::thread(
416 &SocketServer::handleConnection, this, std::move(socket), connectionState)));
417 }
418 }
419
420 size_t SocketServer::getConnectionsThreadsCount()
421 {
422 std::lock_guard<std::mutex> lock(_connectionsThreadsMutex);
423 return _connectionsThreads.size();
424 }
425
426 void SocketServer::runGC()
427 {
428 setThreadName("SocketServer::GC");
429
430 for (;;)
431 {
432 // Garbage collection to shutdown/join threads for closed connections.
433 closeTerminatedThreads();
434
435 // We quit this thread if all connections are closed and we received
436 // a stop request by setting _stopGc to true.
437 if (_stopGc && getConnectionsThreadsCount() == 0)
438 {
439 break;
440 }
441
442 // Unless we are stopping the server, wait for a connection
443 // to be terminated to run the threads GC, instead of busy waiting
444 // with a sleep
445 if (!_stopGc)
446 {
447 std::unique_lock<std::mutex> lock(_conditionVariableMutexGC);
448 _conditionVariableGC.wait(lock);
449 }
450 }
451 }
452
453 void SocketServer::setTLSOptions(const SocketTLSOptions& socketTLSOptions)
454 {
455 _socketTLSOptions = socketTLSOptions;
456 }
457
458 void SocketServer::onSetTerminatedCallback()
459 {
460 // a connection got terminated, we can run the connection thread GC,
461 // so wake up the thread responsible for that
462 _conditionVariableGC.notify_one();
463 }
464
465 int SocketServer::getPort()
466 {
467 return _port;
468 }
469
470 std::string SocketServer::getHost()
471 {
472 return _host;
473 }
474
475 int SocketServer::getBacklog()
476 {
477 return _backlog;
478 }
479
480 std::size_t SocketServer::getMaxConnections()
481 {
482 return _maxConnections;
483 }
484
485 int SocketServer::getAddressFamily()
486 {
487 return _addressFamily;
488 }
489} // namespace ix
490