1 | /* |
2 | * IXWebSocket.cpp |
3 | * Author: Benjamin Sergeant |
4 | * Copyright (c) 2017-2018 Machine Zone, Inc. All rights reserved. |
5 | */ |
6 | |
7 | #include "IXWebSocket.h" |
8 | |
9 | #include "IXExponentialBackoff.h" |
10 | #include "IXSetThreadName.h" |
11 | #include "IXUniquePtr.h" |
12 | #include "IXUtf8Validator.h" |
13 | #include "IXWebSocketHandshake.h" |
14 | #include <cassert> |
15 | #include <cmath> |
16 | |
17 | |
namespace
{
    // Shared empty payload for the messages this file emits without data
    // (Open, Close and Error messages).
    const std::string emptyMsg{};
} // namespace
22 | |
23 | |
24 | namespace ix |
25 | { |
26 | OnTrafficTrackerCallback WebSocket::_onTrafficTrackerCallback = nullptr; |
27 | const int WebSocket::kDefaultHandShakeTimeoutSecs(60); |
28 | const int WebSocket::kDefaultPingIntervalSecs(-1); |
29 | const bool WebSocket::kDefaultEnablePong(true); |
30 | const uint32_t WebSocket::kDefaultMaxWaitBetweenReconnectionRetries(10 * 1000); // 10s |
31 | const uint32_t WebSocket::kDefaultMinWaitBetweenReconnectionRetries(1); // 1 ms |
32 | |
33 | WebSocket::WebSocket() |
34 | : _onMessageCallback(OnMessageCallback()) |
35 | , _stop(false) |
36 | , _automaticReconnection(true) |
37 | , _maxWaitBetweenReconnectionRetries(kDefaultMaxWaitBetweenReconnectionRetries) |
38 | , _minWaitBetweenReconnectionRetries(kDefaultMinWaitBetweenReconnectionRetries) |
39 | , _handshakeTimeoutSecs(kDefaultHandShakeTimeoutSecs) |
40 | , _enablePong(kDefaultEnablePong) |
41 | , _pingIntervalSecs(kDefaultPingIntervalSecs) |
42 | { |
43 | _ws.setOnCloseCallback( |
44 | [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) { |
45 | _onMessageCallback( |
46 | ix::make_unique<WebSocketMessage>(WebSocketMessageType::Close, |
47 | emptyMsg, |
48 | wireSize, |
49 | WebSocketErrorInfo(), |
50 | WebSocketOpenInfo(), |
51 | WebSocketCloseInfo(code, reason, remote))); |
52 | }); |
53 | } |
54 | |
55 | WebSocket::~WebSocket() |
56 | { |
57 | stop(); |
58 | _ws.setOnCloseCallback(nullptr); |
59 | } |
60 | |
61 | void WebSocket::setUrl(const std::string& url) |
62 | { |
63 | std::lock_guard<std::mutex> lock(_configMutex); |
64 | _url = url; |
65 | } |
66 | |
67 | void WebSocket::setHandshakeTimeout(int handshakeTimeoutSecs) |
68 | { |
69 | _handshakeTimeoutSecs = handshakeTimeoutSecs; |
70 | } |
71 | |
72 | void WebSocket::(const WebSocketHttpHeaders& ) |
73 | { |
74 | std::lock_guard<std::mutex> lock(_configMutex); |
75 | _extraHeaders = headers; |
76 | } |
77 | |
78 | const std::string WebSocket::getUrl() const |
79 | { |
80 | std::lock_guard<std::mutex> lock(_configMutex); |
81 | return _url; |
82 | } |
83 | |
84 | void WebSocket::setPerMessageDeflateOptions( |
85 | const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions) |
86 | { |
87 | std::lock_guard<std::mutex> lock(_configMutex); |
88 | _perMessageDeflateOptions = perMessageDeflateOptions; |
89 | } |
90 | |
91 | void WebSocket::setTLSOptions(const SocketTLSOptions& socketTLSOptions) |
92 | { |
93 | std::lock_guard<std::mutex> lock(_configMutex); |
94 | _socketTLSOptions = socketTLSOptions; |
95 | } |
96 | |
97 | const WebSocketPerMessageDeflateOptions WebSocket::getPerMessageDeflateOptions() const |
98 | { |
99 | std::lock_guard<std::mutex> lock(_configMutex); |
100 | return _perMessageDeflateOptions; |
101 | } |
102 | |
103 | void WebSocket::setPingInterval(int pingIntervalSecs) |
104 | { |
105 | std::lock_guard<std::mutex> lock(_configMutex); |
106 | _pingIntervalSecs = pingIntervalSecs; |
107 | } |
108 | |
109 | int WebSocket::getPingInterval() const |
110 | { |
111 | std::lock_guard<std::mutex> lock(_configMutex); |
112 | return _pingIntervalSecs; |
113 | } |
114 | |
115 | void WebSocket::enablePong() |
116 | { |
117 | std::lock_guard<std::mutex> lock(_configMutex); |
118 | _enablePong = true; |
119 | } |
120 | |
121 | void WebSocket::disablePong() |
122 | { |
123 | std::lock_guard<std::mutex> lock(_configMutex); |
124 | _enablePong = false; |
125 | } |
126 | |
127 | void WebSocket::enablePerMessageDeflate() |
128 | { |
129 | std::lock_guard<std::mutex> lock(_configMutex); |
130 | WebSocketPerMessageDeflateOptions perMessageDeflateOptions(true); |
131 | _perMessageDeflateOptions = perMessageDeflateOptions; |
132 | } |
133 | |
134 | void WebSocket::disablePerMessageDeflate() |
135 | { |
136 | std::lock_guard<std::mutex> lock(_configMutex); |
137 | WebSocketPerMessageDeflateOptions perMessageDeflateOptions(false); |
138 | _perMessageDeflateOptions = perMessageDeflateOptions; |
139 | } |
140 | |
141 | void WebSocket::setMaxWaitBetweenReconnectionRetries(uint32_t maxWaitBetweenReconnectionRetries) |
142 | { |
143 | std::lock_guard<std::mutex> lock(_configMutex); |
144 | _maxWaitBetweenReconnectionRetries = maxWaitBetweenReconnectionRetries; |
145 | } |
146 | |
147 | void WebSocket::setMinWaitBetweenReconnectionRetries(uint32_t minWaitBetweenReconnectionRetries) |
148 | { |
149 | std::lock_guard<std::mutex> lock(_configMutex); |
150 | _minWaitBetweenReconnectionRetries = minWaitBetweenReconnectionRetries; |
151 | } |
152 | |
153 | uint32_t WebSocket::getMaxWaitBetweenReconnectionRetries() const |
154 | { |
155 | std::lock_guard<std::mutex> lock(_configMutex); |
156 | return _maxWaitBetweenReconnectionRetries; |
157 | } |
158 | |
159 | uint32_t WebSocket::getMinWaitBetweenReconnectionRetries() const |
160 | { |
161 | std::lock_guard<std::mutex> lock(_configMutex); |
162 | return _minWaitBetweenReconnectionRetries; |
163 | } |
164 | |
165 | void WebSocket::start() |
166 | { |
167 | if (_thread.joinable()) return; // we've already been started |
168 | |
169 | _thread = std::thread(&WebSocket::run, this); |
170 | } |
171 | |
172 | void WebSocket::stop(uint16_t code, const std::string& reason) |
173 | { |
174 | close(code, reason); |
175 | |
176 | if (_thread.joinable()) |
177 | { |
178 | // wait until working thread will exit |
179 | // it will exit after close operation is finished |
180 | _stop = true; |
181 | _sleepCondition.notify_one(); |
182 | _thread.join(); |
183 | _stop = false; |
184 | } |
185 | } |
186 | |
187 | WebSocketInitResult WebSocket::connect(int timeoutSecs) |
188 | { |
189 | { |
190 | std::lock_guard<std::mutex> lock(_configMutex); |
191 | _ws.configure( |
192 | _perMessageDeflateOptions, _socketTLSOptions, _enablePong, _pingIntervalSecs); |
193 | } |
194 | |
195 | WebSocketHttpHeaders (_extraHeaders); |
196 | std::string ; |
197 | auto subProtocols = getSubProtocols(); |
198 | if (!subProtocols.empty()) |
199 | { |
200 | // |
201 | // Sub Protocol strings are comma separated. |
202 | // Python code to do that is: |
203 | // >>> ','.join(['json', 'msgpack']) |
204 | // 'json,msgpack' |
205 | // |
206 | int i = 0; |
207 | for (auto subProtocol : subProtocols) |
208 | { |
209 | if (i++ != 0) |
210 | { |
211 | subProtocolsHeader += "," ; |
212 | } |
213 | subProtocolsHeader += subProtocol; |
214 | } |
215 | headers["Sec-WebSocket-Protocol" ] = subProtocolsHeader; |
216 | } |
217 | |
218 | WebSocketInitResult status = _ws.connectToUrl(_url, headers, timeoutSecs); |
219 | if (!status.success) |
220 | { |
221 | return status; |
222 | } |
223 | |
224 | _onMessageCallback(ix::make_unique<WebSocketMessage>( |
225 | WebSocketMessageType::Open, |
226 | emptyMsg, |
227 | 0, |
228 | WebSocketErrorInfo(), |
229 | WebSocketOpenInfo(status.uri, status.headers, status.protocol), |
230 | WebSocketCloseInfo())); |
231 | |
232 | if (_pingIntervalSecs > 0) |
233 | { |
234 | // Send a heart beat right away |
235 | _ws.sendHeartBeat(); |
236 | } |
237 | |
238 | return status; |
239 | } |
240 | |
241 | WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr<Socket> socket, |
242 | int timeoutSecs, |
243 | bool enablePerMessageDeflate) |
244 | { |
245 | { |
246 | std::lock_guard<std::mutex> lock(_configMutex); |
247 | _ws.configure( |
248 | _perMessageDeflateOptions, _socketTLSOptions, _enablePong, _pingIntervalSecs); |
249 | } |
250 | |
251 | WebSocketInitResult status = |
252 | _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate); |
253 | if (!status.success) |
254 | { |
255 | return status; |
256 | } |
257 | |
258 | _onMessageCallback( |
259 | ix::make_unique<WebSocketMessage>(WebSocketMessageType::Open, |
260 | emptyMsg, |
261 | 0, |
262 | WebSocketErrorInfo(), |
263 | WebSocketOpenInfo(status.uri, status.headers), |
264 | WebSocketCloseInfo())); |
265 | |
266 | if (_pingIntervalSecs > 0) |
267 | { |
268 | // Send a heart beat right away |
269 | _ws.sendHeartBeat(); |
270 | } |
271 | |
272 | return status; |
273 | } |
274 | |
275 | bool WebSocket::isConnected() const |
276 | { |
277 | return getReadyState() == ReadyState::Open; |
278 | } |
279 | |
280 | bool WebSocket::isClosing() const |
281 | { |
282 | return getReadyState() == ReadyState::Closing; |
283 | } |
284 | |
285 | void WebSocket::close(uint16_t code, const std::string& reason) |
286 | { |
287 | _ws.close(code, reason); |
288 | } |
289 | |
290 | void WebSocket::checkConnection(bool firstConnectionAttempt) |
291 | { |
292 | using millis = std::chrono::duration<double, std::milli>; |
293 | |
294 | uint32_t retries = 0; |
295 | millis duration(0); |
296 | |
297 | // Try to connect perpertually |
298 | while (true) |
299 | { |
300 | if (isConnected() || isClosing() || _stop) |
301 | { |
302 | break; |
303 | } |
304 | |
305 | if (!firstConnectionAttempt && !_automaticReconnection) |
306 | { |
307 | // Do not attempt to reconnect |
308 | break; |
309 | } |
310 | |
311 | firstConnectionAttempt = false; |
312 | |
313 | // Only sleep if we are retrying |
314 | if (duration.count() > 0) |
315 | { |
316 | std::unique_lock<std::mutex> lock(_sleepMutex); |
317 | _sleepCondition.wait_for(lock, duration); |
318 | } |
319 | |
320 | if (_stop) |
321 | { |
322 | break; |
323 | } |
324 | |
325 | // Try to connect synchronously |
326 | ix::WebSocketInitResult status = connect(_handshakeTimeoutSecs); |
327 | |
328 | if (!status.success) |
329 | { |
330 | WebSocketErrorInfo connectErr; |
331 | |
332 | if (_automaticReconnection) |
333 | { |
334 | duration = |
335 | millis(calculateRetryWaitMilliseconds(retries++, |
336 | _maxWaitBetweenReconnectionRetries, |
337 | _minWaitBetweenReconnectionRetries)); |
338 | |
339 | connectErr.wait_time = duration.count(); |
340 | connectErr.retries = retries; |
341 | } |
342 | |
343 | connectErr.reason = status.errorStr; |
344 | connectErr.http_status = status.http_status; |
345 | |
346 | _onMessageCallback(ix::make_unique<WebSocketMessage>(WebSocketMessageType::Error, |
347 | emptyMsg, |
348 | 0, |
349 | connectErr, |
350 | WebSocketOpenInfo(), |
351 | WebSocketCloseInfo())); |
352 | } |
353 | } |
354 | } |
355 | |
356 | void WebSocket::run() |
357 | { |
358 | setThreadName(getUrl()); |
359 | |
360 | bool firstConnectionAttempt = true; |
361 | |
362 | while (true) |
363 | { |
364 | // 1. Make sure we are always connected |
365 | checkConnection(firstConnectionAttempt); |
366 | |
367 | firstConnectionAttempt = false; |
368 | |
369 | // if here we are closed then checkConnection was not able to connect |
370 | if (getReadyState() == ReadyState::Closed) |
371 | { |
372 | break; |
373 | } |
374 | |
375 | // We can avoid to poll if we want to stop and are not closing |
376 | if (_stop && !isClosing()) break; |
377 | |
378 | // 2. Poll to see if there's any new data available |
379 | WebSocketTransport::PollResult pollResult = _ws.poll(); |
380 | |
381 | // 3. Dispatch the incoming messages |
382 | _ws.dispatch( |
383 | pollResult, |
384 | [this](const std::string& msg, |
385 | size_t wireSize, |
386 | bool decompressionError, |
387 | WebSocketTransport::MessageKind messageKind) { |
388 | WebSocketMessageType webSocketMessageType{WebSocketMessageType::Error}; |
389 | switch (messageKind) |
390 | { |
391 | case WebSocketTransport::MessageKind::MSG_TEXT: |
392 | case WebSocketTransport::MessageKind::MSG_BINARY: |
393 | { |
394 | webSocketMessageType = WebSocketMessageType::Message; |
395 | } |
396 | break; |
397 | |
398 | case WebSocketTransport::MessageKind::PING: |
399 | { |
400 | webSocketMessageType = WebSocketMessageType::Ping; |
401 | } |
402 | break; |
403 | |
404 | case WebSocketTransport::MessageKind::PONG: |
405 | { |
406 | webSocketMessageType = WebSocketMessageType::Pong; |
407 | } |
408 | break; |
409 | |
410 | case WebSocketTransport::MessageKind::FRAGMENT: |
411 | { |
412 | webSocketMessageType = WebSocketMessageType::Fragment; |
413 | } |
414 | break; |
415 | } |
416 | |
417 | WebSocketErrorInfo webSocketErrorInfo; |
418 | webSocketErrorInfo.decompressionError = decompressionError; |
419 | |
420 | bool binary = messageKind == WebSocketTransport::MessageKind::MSG_BINARY; |
421 | |
422 | _onMessageCallback(ix::make_unique<WebSocketMessage>(webSocketMessageType, |
423 | msg, |
424 | wireSize, |
425 | webSocketErrorInfo, |
426 | WebSocketOpenInfo(), |
427 | WebSocketCloseInfo(), |
428 | binary)); |
429 | |
430 | WebSocket::invokeTrafficTrackerCallback(wireSize, true); |
431 | }); |
432 | } |
433 | } |
434 | |
435 | void WebSocket::setOnMessageCallback(const OnMessageCallback& callback) |
436 | { |
437 | _onMessageCallback = callback; |
438 | } |
439 | |
440 | bool WebSocket::isOnMessageCallbackRegistered() const |
441 | { |
442 | return _onMessageCallback != nullptr; |
443 | } |
444 | |
445 | void WebSocket::setTrafficTrackerCallback(const OnTrafficTrackerCallback& callback) |
446 | { |
447 | _onTrafficTrackerCallback = callback; |
448 | } |
449 | |
450 | void WebSocket::resetTrafficTrackerCallback() |
451 | { |
452 | setTrafficTrackerCallback(nullptr); |
453 | } |
454 | |
455 | void WebSocket::invokeTrafficTrackerCallback(size_t size, bool incoming) |
456 | { |
457 | if (_onTrafficTrackerCallback) |
458 | { |
459 | _onTrafficTrackerCallback(size, incoming); |
460 | } |
461 | } |
462 | |
463 | WebSocketSendInfo WebSocket::send(const std::string& data, |
464 | bool binary, |
465 | const OnProgressCallback& onProgressCallback) |
466 | { |
467 | return (binary) ? sendBinary(data, onProgressCallback) : sendText(data, onProgressCallback); |
468 | } |
469 | |
470 | WebSocketSendInfo WebSocket::sendBinary(const std::string& data, |
471 | const OnProgressCallback& onProgressCallback) |
472 | { |
473 | return sendMessage(data, SendMessageKind::Binary, onProgressCallback); |
474 | } |
475 | |
476 | WebSocketSendInfo WebSocket::sendBinary(const IXWebSocketSendData& data, |
477 | const OnProgressCallback& onProgressCallback) |
478 | { |
479 | return sendMessage(data, SendMessageKind::Binary, onProgressCallback); |
480 | } |
481 | |
482 | WebSocketSendInfo WebSocket::sendUtf8Text(const std::string& text, |
483 | const OnProgressCallback& onProgressCallback) |
484 | { |
485 | return sendMessage(text, SendMessageKind::Text, onProgressCallback); |
486 | } |
487 | |
488 | WebSocketSendInfo WebSocket::sendUtf8Text(const IXWebSocketSendData& text, |
489 | const OnProgressCallback& onProgressCallback) |
490 | { |
491 | return sendMessage(text, SendMessageKind::Text, onProgressCallback); |
492 | } |
493 | |
494 | WebSocketSendInfo WebSocket::sendText(const std::string& text, |
495 | const OnProgressCallback& onProgressCallback) |
496 | { |
497 | if (!validateUtf8(text)) |
498 | { |
499 | close(WebSocketCloseConstants::kInvalidFramePayloadData, |
500 | WebSocketCloseConstants::kInvalidFramePayloadDataMessage); |
501 | return false; |
502 | } |
503 | return sendMessage(text, SendMessageKind::Text, onProgressCallback); |
504 | } |
505 | |
506 | WebSocketSendInfo WebSocket::ping(const std::string& text) |
507 | { |
508 | // Standard limit ping message size |
509 | constexpr size_t pingMaxPayloadSize = 125; |
510 | if (text.size() > pingMaxPayloadSize) return WebSocketSendInfo(false); |
511 | |
512 | return sendMessage(text, SendMessageKind::Ping); |
513 | } |
514 | |
515 | WebSocketSendInfo WebSocket::sendMessage(const IXWebSocketSendData& message, |
516 | SendMessageKind sendMessageKind, |
517 | const OnProgressCallback& onProgressCallback) |
518 | { |
519 | if (!isConnected()) return WebSocketSendInfo(false); |
520 | |
521 | // |
522 | // It is OK to read and write on the same socket in 2 different threads. |
523 | // https://stackoverflow.com/questions/1981372/are-parallel-calls-to-send-recv-on-the-same-socket-valid |
524 | // |
525 | // This makes it so that messages are sent right away, and we dont need |
526 | // a timeout while we poll to keep wake ups to a minimum (which helps |
527 | // with battery life), and use the system select call to notify us when |
528 | // incoming messages are arriving / there's data to be received. |
529 | // |
530 | std::lock_guard<std::mutex> lock(_writeMutex); |
531 | WebSocketSendInfo webSocketSendInfo; |
532 | |
533 | switch (sendMessageKind) |
534 | { |
535 | case SendMessageKind::Text: |
536 | { |
537 | webSocketSendInfo = _ws.sendText(message, onProgressCallback); |
538 | } |
539 | break; |
540 | |
541 | case SendMessageKind::Binary: |
542 | { |
543 | webSocketSendInfo = _ws.sendBinary(message, onProgressCallback); |
544 | } |
545 | break; |
546 | |
547 | case SendMessageKind::Ping: |
548 | { |
549 | webSocketSendInfo = _ws.sendPing(message); |
550 | } |
551 | break; |
552 | } |
553 | |
554 | WebSocket::invokeTrafficTrackerCallback(webSocketSendInfo.wireSize, false); |
555 | |
556 | return webSocketSendInfo; |
557 | } |
558 | |
559 | ReadyState WebSocket::getReadyState() const |
560 | { |
561 | switch (_ws.getReadyState()) |
562 | { |
563 | case ix::WebSocketTransport::ReadyState::OPEN: return ReadyState::Open; |
564 | case ix::WebSocketTransport::ReadyState::CONNECTING: return ReadyState::Connecting; |
565 | case ix::WebSocketTransport::ReadyState::CLOSING: return ReadyState::Closing; |
566 | case ix::WebSocketTransport::ReadyState::CLOSED: return ReadyState::Closed; |
567 | default: return ReadyState::Closed; |
568 | } |
569 | } |
570 | |
571 | std::string WebSocket::readyStateToString(ReadyState readyState) |
572 | { |
573 | switch (readyState) |
574 | { |
575 | case ReadyState::Open: return "OPEN" ; |
576 | case ReadyState::Connecting: return "CONNECTING" ; |
577 | case ReadyState::Closing: return "CLOSING" ; |
578 | case ReadyState::Closed: return "CLOSED" ; |
579 | default: return "UNKNOWN" ; |
580 | } |
581 | } |
582 | |
583 | void WebSocket::enableAutomaticReconnection() |
584 | { |
585 | _automaticReconnection = true; |
586 | } |
587 | |
588 | void WebSocket::disableAutomaticReconnection() |
589 | { |
590 | _automaticReconnection = false; |
591 | } |
592 | |
593 | bool WebSocket::isAutomaticReconnectionEnabled() const |
594 | { |
595 | return _automaticReconnection; |
596 | } |
597 | |
598 | size_t WebSocket::bufferedAmount() const |
599 | { |
600 | return _ws.bufferedAmount(); |
601 | } |
602 | |
603 | void WebSocket::addSubProtocol(const std::string& subProtocol) |
604 | { |
605 | std::lock_guard<std::mutex> lock(_configMutex); |
606 | _subProtocols.push_back(subProtocol); |
607 | } |
608 | |
609 | const std::vector<std::string>& WebSocket::getSubProtocols() |
610 | { |
611 | std::lock_guard<std::mutex> lock(_configMutex); |
612 | return _subProtocols; |
613 | } |
614 | } // namespace ix |
615 | |