1 | /* |
2 | * IXWebSocketServer.cpp |
3 | * Author: Benjamin Sergeant |
4 | * Copyright (c) 2018 Machine Zone, Inc. All rights reserved. |
5 | */ |
6 | |
7 | #include "IXWebSocketServer.h" |
8 | |
9 | #include "IXNetSystem.h" |
10 | #include "IXSetThreadName.h" |
11 | #include "IXSocketConnect.h" |
12 | #include "IXWebSocket.h" |
13 | #include "IXWebSocketTransport.h" |
14 | #include <future> |
15 | #include <sstream> |
16 | #include <string.h> |
17 | |
18 | namespace ix |
19 | { |
20 | const int WebSocketServer::kDefaultHandShakeTimeoutSecs(3); // 3 seconds |
21 | const bool WebSocketServer::kDefaultEnablePong(true); |
22 | |
23 | WebSocketServer::WebSocketServer(int port, |
24 | const std::string& host, |
25 | int backlog, |
26 | size_t maxConnections, |
27 | int handshakeTimeoutSecs, |
28 | int addressFamily) |
29 | : SocketServer(port, host, backlog, maxConnections, addressFamily) |
30 | , _handshakeTimeoutSecs(handshakeTimeoutSecs) |
31 | , _enablePong(kDefaultEnablePong) |
32 | , _enablePerMessageDeflate(true) |
33 | { |
34 | } |
35 | |
36 | WebSocketServer::~WebSocketServer() |
37 | { |
38 | stop(); |
39 | } |
40 | |
41 | void WebSocketServer::stop() |
42 | { |
43 | stopAcceptingConnections(); |
44 | |
45 | auto clients = getClients(); |
46 | for (auto client : clients) |
47 | { |
48 | client->close(); |
49 | } |
50 | |
51 | SocketServer::stop(); |
52 | } |
53 | |
54 | void WebSocketServer::enablePong() |
55 | { |
56 | _enablePong = true; |
57 | } |
58 | |
59 | void WebSocketServer::disablePong() |
60 | { |
61 | _enablePong = false; |
62 | } |
63 | |
64 | void WebSocketServer::disablePerMessageDeflate() |
65 | { |
66 | _enablePerMessageDeflate = false; |
67 | } |
68 | |
69 | void WebSocketServer::setOnConnectionCallback(const OnConnectionCallback& callback) |
70 | { |
71 | _onConnectionCallback = callback; |
72 | } |
73 | |
74 | void WebSocketServer::setOnClientMessageCallback(const OnClientMessageCallback& callback) |
75 | { |
76 | _onClientMessageCallback = callback; |
77 | } |
78 | |
79 | void WebSocketServer::handleConnection(std::unique_ptr<Socket> socket, |
80 | std::shared_ptr<ConnectionState> connectionState) |
81 | { |
82 | setThreadName("WebSocketServer::" + connectionState->getId()); |
83 | |
84 | auto webSocket = std::make_shared<WebSocket>(); |
85 | if (_onConnectionCallback) |
86 | { |
87 | _onConnectionCallback(webSocket, connectionState); |
88 | |
89 | if (!webSocket->isOnMessageCallbackRegistered()) |
90 | { |
91 | logError("WebSocketServer Application developer error: Server callback improperly " |
92 | "registerered." ); |
93 | logError("Missing call to setOnMessageCallback inside setOnConnectionCallback." ); |
94 | connectionState->setTerminated(); |
95 | return; |
96 | } |
97 | } |
98 | else if (_onClientMessageCallback) |
99 | { |
100 | WebSocket* webSocketRawPtr = webSocket.get(); |
101 | webSocket->setOnMessageCallback( |
102 | [this, webSocketRawPtr, connectionState](const WebSocketMessagePtr& msg) { |
103 | _onClientMessageCallback(connectionState, *webSocketRawPtr, msg); |
104 | }); |
105 | } |
106 | else |
107 | { |
108 | logError( |
109 | "WebSocketServer Application developer error: No server callback is registerered." ); |
110 | logError("Missing call to setOnConnectionCallback or setOnClientMessageCallback." ); |
111 | connectionState->setTerminated(); |
112 | return; |
113 | } |
114 | |
115 | webSocket->disableAutomaticReconnection(); |
116 | |
117 | if (_enablePong) |
118 | { |
119 | webSocket->enablePong(); |
120 | } |
121 | else |
122 | { |
123 | webSocket->disablePong(); |
124 | } |
125 | |
126 | // Add this client to our client set |
127 | { |
128 | std::lock_guard<std::mutex> lock(_clientsMutex); |
129 | _clients.insert(webSocket); |
130 | } |
131 | |
132 | auto status = webSocket->connectToSocket( |
133 | std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate); |
134 | if (status.success) |
135 | { |
136 | // Process incoming messages and execute callbacks |
137 | // until the connection is closed |
138 | webSocket->run(); |
139 | } |
140 | else |
141 | { |
142 | std::stringstream ss; |
143 | ss << "WebSocketServer::handleConnection() HTTP status: " << status.http_status |
144 | << " error: " << status.errorStr; |
145 | logError(ss.str()); |
146 | } |
147 | |
148 | webSocket->setOnMessageCallback(nullptr); |
149 | |
150 | // Remove this client from our client set |
151 | { |
152 | std::lock_guard<std::mutex> lock(_clientsMutex); |
153 | if (_clients.erase(webSocket) != 1) |
154 | { |
155 | logError("Cannot delete client" ); |
156 | } |
157 | } |
158 | |
159 | connectionState->setTerminated(); |
160 | } |
161 | |
162 | std::set<std::shared_ptr<WebSocket>> WebSocketServer::getClients() |
163 | { |
164 | std::lock_guard<std::mutex> lock(_clientsMutex); |
165 | return _clients; |
166 | } |
167 | |
168 | size_t WebSocketServer::getConnectedClientsCount() |
169 | { |
170 | std::lock_guard<std::mutex> lock(_clientsMutex); |
171 | return _clients.size(); |
172 | } |
173 | |
174 | // |
175 | // Classic servers |
176 | // |
177 | void WebSocketServer::makeBroadcastServer() |
178 | { |
179 | setOnClientMessageCallback([this](std::shared_ptr<ConnectionState> connectionState, |
180 | WebSocket& webSocket, |
181 | const WebSocketMessagePtr& msg) { |
182 | auto remoteIp = connectionState->getRemoteIp(); |
183 | if (msg->type == ix::WebSocketMessageType::Message) |
184 | { |
185 | for (auto&& client : getClients()) |
186 | { |
187 | if (client.get() != &webSocket) |
188 | { |
189 | client->send(msg->str, msg->binary); |
190 | |
191 | // Make sure the OS send buffer is flushed before moving on |
192 | do |
193 | { |
194 | std::chrono::duration<double, std::milli> duration(500); |
195 | std::this_thread::sleep_for(duration); |
196 | } while (client->bufferedAmount() != 0); |
197 | } |
198 | } |
199 | } |
200 | }); |
201 | } |
202 | |
203 | bool WebSocketServer::listenAndStart() |
204 | { |
205 | auto res = listen(); |
206 | if (!res.first) |
207 | { |
208 | return false; |
209 | } |
210 | |
211 | start(); |
212 | return true; |
213 | } |
214 | |
215 | int WebSocketServer::getHandshakeTimeoutSecs() |
216 | { |
217 | return _handshakeTimeoutSecs; |
218 | } |
219 | |
220 | bool WebSocketServer::isPongEnabled() |
221 | { |
222 | return _enablePong; |
223 | } |
224 | |
225 | bool WebSocketServer::isPerMessageDeflateEnabled() |
226 | { |
227 | return _enablePerMessageDeflate; |
228 | } |
229 | } // namespace ix |
230 | |