1/*
2 * IXWebSocketServer.cpp
3 * Author: Benjamin Sergeant
4 * Copyright (c) 2018 Machine Zone, Inc. All rights reserved.
5 */
6
#include "IXWebSocketServer.h"

#include "IXNetSystem.h"
#include "IXSetThreadName.h"
#include "IXSocketConnect.h"
#include "IXWebSocket.h"
#include "IXWebSocketTransport.h"
#include <chrono>
#include <future>
#include <sstream>
#include <string.h>
#include <thread>
17
18namespace ix
19{
20 const int WebSocketServer::kDefaultHandShakeTimeoutSecs(3); // 3 seconds
21 const bool WebSocketServer::kDefaultEnablePong(true);
22
23 WebSocketServer::WebSocketServer(int port,
24 const std::string& host,
25 int backlog,
26 size_t maxConnections,
27 int handshakeTimeoutSecs,
28 int addressFamily)
29 : SocketServer(port, host, backlog, maxConnections, addressFamily)
30 , _handshakeTimeoutSecs(handshakeTimeoutSecs)
31 , _enablePong(kDefaultEnablePong)
32 , _enablePerMessageDeflate(true)
33 {
34 }
35
36 WebSocketServer::~WebSocketServer()
37 {
38 stop();
39 }
40
41 void WebSocketServer::stop()
42 {
43 stopAcceptingConnections();
44
45 auto clients = getClients();
46 for (auto client : clients)
47 {
48 client->close();
49 }
50
51 SocketServer::stop();
52 }
53
54 void WebSocketServer::enablePong()
55 {
56 _enablePong = true;
57 }
58
59 void WebSocketServer::disablePong()
60 {
61 _enablePong = false;
62 }
63
64 void WebSocketServer::disablePerMessageDeflate()
65 {
66 _enablePerMessageDeflate = false;
67 }
68
69 void WebSocketServer::setOnConnectionCallback(const OnConnectionCallback& callback)
70 {
71 _onConnectionCallback = callback;
72 }
73
74 void WebSocketServer::setOnClientMessageCallback(const OnClientMessageCallback& callback)
75 {
76 _onClientMessageCallback = callback;
77 }
78
/**
 * Serve one accepted connection from handshake to teardown. This is presumably
 * invoked by SocketServer on a per-connection thread, hence setThreadName().
 *
 * Flow:
 *  1. Name the current thread "WebSocketServer::<connection id>".
 *  2. Create a WebSocket and wire up message handling:
 *     - if _onConnectionCallback is set, call it; it must register a message
 *       callback, otherwise the connection is rejected;
 *     - else if _onClientMessageCallback is set, forward messages to it;
 *     - else reject the connection.
 *     A rejected connection is logged, marked terminated, and the function
 *     returns without attempting the handshake.
 *  3. Disable automatic reconnection and apply the server's pong setting.
 *  4. Add the WebSocket to _clients, run the handshake via connectToSocket()
 *     and, on success, block in run() until the connection closes.
 *  5. Clear the message callback, remove the WebSocket from _clients and mark
 *     the connection terminated.
 *
 * The WebSocket is inserted into _clients *before* the handshake. getClients()
 * can therefore return sockets whose handshake is still in progress, or has
 * failed but has not yet been removed.
 *
 * @param socket          Accepted socket. It is moved into connectToSocket();
 *                        on the early-return paths it is simply destroyed.
 * @param connectionState Per-connection state. It is passed to the user
 *                        callbacks and marked terminated on every return path.
 */
void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket,
                                       std::shared_ptr<ConnectionState> connectionState)
{
    setThreadName("WebSocketServer::" + connectionState->getId());

    auto webSocket = std::make_shared<WebSocket>();
    if (_onConnectionCallback)
    {
        // Let the application configure the socket (callbacks, options, ...).
        _onConnectionCallback(webSocket, connectionState);

        // Without a message callback the connection could never be serviced.
        if (!webSocket->isOnMessageCallbackRegistered())
        {
            logError("WebSocketServer Application developer error: Server callback improperly "
                     "registerered.");
            logError("Missing call to setOnMessageCallback inside setOnConnectionCallback.");
            connectionState->setTerminated();
            return;
        }
    }
    else if (_onClientMessageCallback)
    {
        // Capture a raw pointer, not the shared_ptr. The callback is stored on the
        // WebSocket itself, so capturing webSocket would presumably form an ownership
        // cycle. The callback is cleared (after run() returns, further down) before
        // this function drops its reference.
        // NOTE(review): capturing `this` assumes the server outlives every connection.
        WebSocket* webSocketRawPtr = webSocket.get();
        webSocket->setOnMessageCallback(
            [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) {
                _onClientMessageCallback(connectionState, *webSocketRawPtr, msg);
            });
    }
    else
    {
        logError(
            "WebSocketServer Application developer error: No server callback is registerered.");
        logError("Missing call to setOnConnectionCallback or setOnClientMessageCallback.");
        connectionState->setTerminated();
        return;
    }

    // Server-side sockets are tied to one accepted connection; never reconnect.
    webSocket->disableAutomaticReconnection();

    // Mirror the server-wide pong setting onto this connection.
    if (_enablePong)
    {
        webSocket->enablePong();
    }
    else
    {
        webSocket->disablePong();
    }

    // Add this client to our client set
    {
        std::lock_guard<std::mutex> lock(_clientsMutex);
        _clients.insert(webSocket);
    }

    // Opening handshake, bounded by _handshakeTimeoutSecs.
    auto status = webSocket->connectToSocket(
        std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate);
    if (status.success)
    {
        // Process incoming messages and execute callbacks
        // until the connection is closed
        webSocket->run();
    }
    else
    {
        std::stringstream ss;
        ss << "WebSocketServer::handleConnection() HTTP status: " << status.http_status
           << " error: " << status.errorStr;
        logError(ss.str());
    }

    // run() has returned (or never started): drop the message callback and
    // anything it captured. Other threads may still hold a shared_ptr to this
    // WebSocket via a getClients() snapshot.
    webSocket->setOnMessageCallback(nullptr);

    // Remove this client from our client set
    {
        std::lock_guard<std::mutex> lock(_clientsMutex);
        if (_clients.erase(webSocket) != 1)
        {
            logError("Cannot delete client");
        }
    }

    connectionState->setTerminated();
}
161
162 std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients()
163 {
164 std::lock_guard<std::mutex> lock(_clientsMutex);
165 return _clients;
166 }
167
168 size_t WebSocketServer::getConnectedClientsCount()
169 {
170 std::lock_guard<std::mutex> lock(_clientsMutex);
171 return _clients.size();
172 }
173
174 //
175 // Classic servers
176 //
177 void WebSocketServer::makeBroadcastServer()
178 {
179 setOnClientMessageCallback([this](std::shared_ptr<ConnectionState> connectionState,
180 WebSocket& webSocket,
181 const WebSocketMessagePtr& msg) {
182 auto remoteIp = connectionState->getRemoteIp();
183 if (msg->type == ix::WebSocketMessageType::Message)
184 {
185 for (auto&& client : getClients())
186 {
187 if (client.get() != &webSocket)
188 {
189 client->send(msg->str, msg->binary);
190
191 // Make sure the OS send buffer is flushed before moving on
192 do
193 {
194 std::chrono::duration<double, std::milli> duration(500);
195 std::this_thread::sleep_for(duration);
196 } while (client->bufferedAmount() != 0);
197 }
198 }
199 }
200 });
201 }
202
203 bool WebSocketServer::listenAndStart()
204 {
205 auto res = listen();
206 if (!res.first)
207 {
208 return false;
209 }
210
211 start();
212 return true;
213 }
214
215 int WebSocketServer::getHandshakeTimeoutSecs()
216 {
217 return _handshakeTimeoutSecs;
218 }
219
220 bool WebSocketServer::isPongEnabled()
221 {
222 return _enablePong;
223 }
224
225 bool WebSocketServer::isPerMessageDeflateEnabled()
226 {
227 return _enablePerMessageDeflate;
228 }
229} // namespace ix
230