1/*
2 * IXWebSocket.cpp
3 * Author: Benjamin Sergeant
4 * Copyright (c) 2017-2018 Machine Zone, Inc. All rights reserved.
5 */
6
7#include "IXWebSocket.h"
8
9#include "IXExponentialBackoff.h"
10#include "IXSetThreadName.h"
11#include "IXUniquePtr.h"
12#include "IXUtf8Validator.h"
13#include "IXWebSocketHandshake.h"
14#include <cassert>
15#include <cmath>
16
17
namespace
{
    // Single shared empty payload, passed wherever this file builds a
    // WebSocketMessage that carries no text (Open, Close and Error events).
    const std::string emptyMsg{};
} // namespace
22
23
24namespace ix
25{
26 OnTrafficTrackerCallback WebSocket::_onTrafficTrackerCallback = nullptr;
27 const int WebSocket::kDefaultHandShakeTimeoutSecs(60);
28 const int WebSocket::kDefaultPingIntervalSecs(-1);
29 const bool WebSocket::kDefaultEnablePong(true);
30 const uint32_t WebSocket::kDefaultMaxWaitBetweenReconnectionRetries(10 * 1000); // 10s
31 const uint32_t WebSocket::kDefaultMinWaitBetweenReconnectionRetries(1); // 1 ms
32
33 WebSocket::WebSocket()
34 : _onMessageCallback(OnMessageCallback())
35 , _stop(false)
36 , _automaticReconnection(true)
37 , _maxWaitBetweenReconnectionRetries(kDefaultMaxWaitBetweenReconnectionRetries)
38 , _minWaitBetweenReconnectionRetries(kDefaultMinWaitBetweenReconnectionRetries)
39 , _handshakeTimeoutSecs(kDefaultHandShakeTimeoutSecs)
40 , _enablePong(kDefaultEnablePong)
41 , _pingIntervalSecs(kDefaultPingIntervalSecs)
42 {
43 _ws.setOnCloseCallback(
44 [this](uint16_t code, const std::string& reason, size_t wireSize, bool remote) {
45 _onMessageCallback(
46 ix::make_unique<WebSocketMessage>(WebSocketMessageType::Close,
47 emptyMsg,
48 wireSize,
49 WebSocketErrorInfo(),
50 WebSocketOpenInfo(),
51 WebSocketCloseInfo(code, reason, remote)));
52 });
53 }
54
55 WebSocket::~WebSocket()
56 {
57 stop();
58 _ws.setOnCloseCallback(nullptr);
59 }
60
61 void WebSocket::setUrl(const std::string& url)
62 {
63 std::lock_guard<std::mutex> lock(_configMutex);
64 _url = url;
65 }
66
67 void WebSocket::setHandshakeTimeout(int handshakeTimeoutSecs)
68 {
69 _handshakeTimeoutSecs = handshakeTimeoutSecs;
70 }
71
72 void WebSocket::setExtraHeaders(const WebSocketHttpHeaders& headers)
73 {
74 std::lock_guard<std::mutex> lock(_configMutex);
75 _extraHeaders = headers;
76 }
77
78 const std::string WebSocket::getUrl() const
79 {
80 std::lock_guard<std::mutex> lock(_configMutex);
81 return _url;
82 }
83
84 void WebSocket::setPerMessageDeflateOptions(
85 const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions)
86 {
87 std::lock_guard<std::mutex> lock(_configMutex);
88 _perMessageDeflateOptions = perMessageDeflateOptions;
89 }
90
91 void WebSocket::setTLSOptions(const SocketTLSOptions& socketTLSOptions)
92 {
93 std::lock_guard<std::mutex> lock(_configMutex);
94 _socketTLSOptions = socketTLSOptions;
95 }
96
97 const WebSocketPerMessageDeflateOptions WebSocket::getPerMessageDeflateOptions() const
98 {
99 std::lock_guard<std::mutex> lock(_configMutex);
100 return _perMessageDeflateOptions;
101 }
102
103 void WebSocket::setPingInterval(int pingIntervalSecs)
104 {
105 std::lock_guard<std::mutex> lock(_configMutex);
106 _pingIntervalSecs = pingIntervalSecs;
107 }
108
109 int WebSocket::getPingInterval() const
110 {
111 std::lock_guard<std::mutex> lock(_configMutex);
112 return _pingIntervalSecs;
113 }
114
115 void WebSocket::enablePong()
116 {
117 std::lock_guard<std::mutex> lock(_configMutex);
118 _enablePong = true;
119 }
120
121 void WebSocket::disablePong()
122 {
123 std::lock_guard<std::mutex> lock(_configMutex);
124 _enablePong = false;
125 }
126
127 void WebSocket::enablePerMessageDeflate()
128 {
129 std::lock_guard<std::mutex> lock(_configMutex);
130 WebSocketPerMessageDeflateOptions perMessageDeflateOptions(true);
131 _perMessageDeflateOptions = perMessageDeflateOptions;
132 }
133
134 void WebSocket::disablePerMessageDeflate()
135 {
136 std::lock_guard<std::mutex> lock(_configMutex);
137 WebSocketPerMessageDeflateOptions perMessageDeflateOptions(false);
138 _perMessageDeflateOptions = perMessageDeflateOptions;
139 }
140
141 void WebSocket::setMaxWaitBetweenReconnectionRetries(uint32_t maxWaitBetweenReconnectionRetries)
142 {
143 std::lock_guard<std::mutex> lock(_configMutex);
144 _maxWaitBetweenReconnectionRetries = maxWaitBetweenReconnectionRetries;
145 }
146
147 void WebSocket::setMinWaitBetweenReconnectionRetries(uint32_t minWaitBetweenReconnectionRetries)
148 {
149 std::lock_guard<std::mutex> lock(_configMutex);
150 _minWaitBetweenReconnectionRetries = minWaitBetweenReconnectionRetries;
151 }
152
153 uint32_t WebSocket::getMaxWaitBetweenReconnectionRetries() const
154 {
155 std::lock_guard<std::mutex> lock(_configMutex);
156 return _maxWaitBetweenReconnectionRetries;
157 }
158
159 uint32_t WebSocket::getMinWaitBetweenReconnectionRetries() const
160 {
161 std::lock_guard<std::mutex> lock(_configMutex);
162 return _minWaitBetweenReconnectionRetries;
163 }
164
165 void WebSocket::start()
166 {
167 if (_thread.joinable()) return; // we've already been started
168
169 _thread = std::thread(&WebSocket::run, this);
170 }
171
172 void WebSocket::stop(uint16_t code, const std::string& reason)
173 {
174 close(code, reason);
175
176 if (_thread.joinable())
177 {
178 // wait until working thread will exit
179 // it will exit after close operation is finished
180 _stop = true;
181 _sleepCondition.notify_one();
182 _thread.join();
183 _stop = false;
184 }
185 }
186
187 WebSocketInitResult WebSocket::connect(int timeoutSecs)
188 {
189 {
190 std::lock_guard<std::mutex> lock(_configMutex);
191 _ws.configure(
192 _perMessageDeflateOptions, _socketTLSOptions, _enablePong, _pingIntervalSecs);
193 }
194
195 WebSocketHttpHeaders headers(_extraHeaders);
196 std::string subProtocolsHeader;
197 auto subProtocols = getSubProtocols();
198 if (!subProtocols.empty())
199 {
200 //
201 // Sub Protocol strings are comma separated.
202 // Python code to do that is:
203 // >>> ','.join(['json', 'msgpack'])
204 // 'json,msgpack'
205 //
206 int i = 0;
207 for (auto subProtocol : subProtocols)
208 {
209 if (i++ != 0)
210 {
211 subProtocolsHeader += ",";
212 }
213 subProtocolsHeader += subProtocol;
214 }
215 headers["Sec-WebSocket-Protocol"] = subProtocolsHeader;
216 }
217
218 WebSocketInitResult status = _ws.connectToUrl(_url, headers, timeoutSecs);
219 if (!status.success)
220 {
221 return status;
222 }
223
224 _onMessageCallback(ix::make_unique<WebSocketMessage>(
225 WebSocketMessageType::Open,
226 emptyMsg,
227 0,
228 WebSocketErrorInfo(),
229 WebSocketOpenInfo(status.uri, status.headers, status.protocol),
230 WebSocketCloseInfo()));
231
232 if (_pingIntervalSecs > 0)
233 {
234 // Send a heart beat right away
235 _ws.sendHeartBeat();
236 }
237
238 return status;
239 }
240
241 WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr<Socket> socket,
242 int timeoutSecs,
243 bool enablePerMessageDeflate)
244 {
245 {
246 std::lock_guard<std::mutex> lock(_configMutex);
247 _ws.configure(
248 _perMessageDeflateOptions, _socketTLSOptions, _enablePong, _pingIntervalSecs);
249 }
250
251 WebSocketInitResult status =
252 _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate);
253 if (!status.success)
254 {
255 return status;
256 }
257
258 _onMessageCallback(
259 ix::make_unique<WebSocketMessage>(WebSocketMessageType::Open,
260 emptyMsg,
261 0,
262 WebSocketErrorInfo(),
263 WebSocketOpenInfo(status.uri, status.headers),
264 WebSocketCloseInfo()));
265
266 if (_pingIntervalSecs > 0)
267 {
268 // Send a heart beat right away
269 _ws.sendHeartBeat();
270 }
271
272 return status;
273 }
274
275 bool WebSocket::isConnected() const
276 {
277 return getReadyState() == ReadyState::Open;
278 }
279
280 bool WebSocket::isClosing() const
281 {
282 return getReadyState() == ReadyState::Closing;
283 }
284
285 void WebSocket::close(uint16_t code, const std::string& reason)
286 {
287 _ws.close(code, reason);
288 }
289
290 void WebSocket::checkConnection(bool firstConnectionAttempt)
291 {
292 using millis = std::chrono::duration<double, std::milli>;
293
294 uint32_t retries = 0;
295 millis duration(0);
296
297 // Try to connect perpertually
298 while (true)
299 {
300 if (isConnected() || isClosing() || _stop)
301 {
302 break;
303 }
304
305 if (!firstConnectionAttempt && !_automaticReconnection)
306 {
307 // Do not attempt to reconnect
308 break;
309 }
310
311 firstConnectionAttempt = false;
312
313 // Only sleep if we are retrying
314 if (duration.count() > 0)
315 {
316 std::unique_lock<std::mutex> lock(_sleepMutex);
317 _sleepCondition.wait_for(lock, duration);
318 }
319
320 if (_stop)
321 {
322 break;
323 }
324
325 // Try to connect synchronously
326 ix::WebSocketInitResult status = connect(_handshakeTimeoutSecs);
327
328 if (!status.success)
329 {
330 WebSocketErrorInfo connectErr;
331
332 if (_automaticReconnection)
333 {
334 duration =
335 millis(calculateRetryWaitMilliseconds(retries++,
336 _maxWaitBetweenReconnectionRetries,
337 _minWaitBetweenReconnectionRetries));
338
339 connectErr.wait_time = duration.count();
340 connectErr.retries = retries;
341 }
342
343 connectErr.reason = status.errorStr;
344 connectErr.http_status = status.http_status;
345
346 _onMessageCallback(ix::make_unique<WebSocketMessage>(WebSocketMessageType::Error,
347 emptyMsg,
348 0,
349 connectErr,
350 WebSocketOpenInfo(),
351 WebSocketCloseInfo()));
352 }
353 }
354 }
355
356 void WebSocket::run()
357 {
358 setThreadName(getUrl());
359
360 bool firstConnectionAttempt = true;
361
362 while (true)
363 {
364 // 1. Make sure we are always connected
365 checkConnection(firstConnectionAttempt);
366
367 firstConnectionAttempt = false;
368
369 // if here we are closed then checkConnection was not able to connect
370 if (getReadyState() == ReadyState::Closed)
371 {
372 break;
373 }
374
375 // We can avoid to poll if we want to stop and are not closing
376 if (_stop && !isClosing()) break;
377
378 // 2. Poll to see if there's any new data available
379 WebSocketTransport::PollResult pollResult = _ws.poll();
380
381 // 3. Dispatch the incoming messages
382 _ws.dispatch(
383 pollResult,
384 [this](const std::string& msg,
385 size_t wireSize,
386 bool decompressionError,
387 WebSocketTransport::MessageKind messageKind) {
388 WebSocketMessageType webSocketMessageType{WebSocketMessageType::Error};
389 switch (messageKind)
390 {
391 case WebSocketTransport::MessageKind::MSG_TEXT:
392 case WebSocketTransport::MessageKind::MSG_BINARY:
393 {
394 webSocketMessageType = WebSocketMessageType::Message;
395 }
396 break;
397
398 case WebSocketTransport::MessageKind::PING:
399 {
400 webSocketMessageType = WebSocketMessageType::Ping;
401 }
402 break;
403
404 case WebSocketTransport::MessageKind::PONG:
405 {
406 webSocketMessageType = WebSocketMessageType::Pong;
407 }
408 break;
409
410 case WebSocketTransport::MessageKind::FRAGMENT:
411 {
412 webSocketMessageType = WebSocketMessageType::Fragment;
413 }
414 break;
415 }
416
417 WebSocketErrorInfo webSocketErrorInfo;
418 webSocketErrorInfo.decompressionError = decompressionError;
419
420 bool binary = messageKind == WebSocketTransport::MessageKind::MSG_BINARY;
421
422 _onMessageCallback(ix::make_unique<WebSocketMessage>(webSocketMessageType,
423 msg,
424 wireSize,
425 webSocketErrorInfo,
426 WebSocketOpenInfo(),
427 WebSocketCloseInfo(),
428 binary));
429
430 WebSocket::invokeTrafficTrackerCallback(wireSize, true);
431 });
432 }
433 }
434
435 void WebSocket::setOnMessageCallback(const OnMessageCallback& callback)
436 {
437 _onMessageCallback = callback;
438 }
439
440 bool WebSocket::isOnMessageCallbackRegistered() const
441 {
442 return _onMessageCallback != nullptr;
443 }
444
445 void WebSocket::setTrafficTrackerCallback(const OnTrafficTrackerCallback& callback)
446 {
447 _onTrafficTrackerCallback = callback;
448 }
449
450 void WebSocket::resetTrafficTrackerCallback()
451 {
452 setTrafficTrackerCallback(nullptr);
453 }
454
455 void WebSocket::invokeTrafficTrackerCallback(size_t size, bool incoming)
456 {
457 if (_onTrafficTrackerCallback)
458 {
459 _onTrafficTrackerCallback(size, incoming);
460 }
461 }
462
463 WebSocketSendInfo WebSocket::send(const std::string& data,
464 bool binary,
465 const OnProgressCallback& onProgressCallback)
466 {
467 return (binary) ? sendBinary(data, onProgressCallback) : sendText(data, onProgressCallback);
468 }
469
470 WebSocketSendInfo WebSocket::sendBinary(const std::string& data,
471 const OnProgressCallback& onProgressCallback)
472 {
473 return sendMessage(data, SendMessageKind::Binary, onProgressCallback);
474 }
475
476 WebSocketSendInfo WebSocket::sendBinary(const IXWebSocketSendData& data,
477 const OnProgressCallback& onProgressCallback)
478 {
479 return sendMessage(data, SendMessageKind::Binary, onProgressCallback);
480 }
481
482 WebSocketSendInfo WebSocket::sendUtf8Text(const std::string& text,
483 const OnProgressCallback& onProgressCallback)
484 {
485 return sendMessage(text, SendMessageKind::Text, onProgressCallback);
486 }
487
488 WebSocketSendInfo WebSocket::sendUtf8Text(const IXWebSocketSendData& text,
489 const OnProgressCallback& onProgressCallback)
490 {
491 return sendMessage(text, SendMessageKind::Text, onProgressCallback);
492 }
493
494 WebSocketSendInfo WebSocket::sendText(const std::string& text,
495 const OnProgressCallback& onProgressCallback)
496 {
497 if (!validateUtf8(text))
498 {
499 close(WebSocketCloseConstants::kInvalidFramePayloadData,
500 WebSocketCloseConstants::kInvalidFramePayloadDataMessage);
501 return false;
502 }
503 return sendMessage(text, SendMessageKind::Text, onProgressCallback);
504 }
505
506 WebSocketSendInfo WebSocket::ping(const std::string& text)
507 {
508 // Standard limit ping message size
509 constexpr size_t pingMaxPayloadSize = 125;
510 if (text.size() > pingMaxPayloadSize) return WebSocketSendInfo(false);
511
512 return sendMessage(text, SendMessageKind::Ping);
513 }
514
515 WebSocketSendInfo WebSocket::sendMessage(const IXWebSocketSendData& message,
516 SendMessageKind sendMessageKind,
517 const OnProgressCallback& onProgressCallback)
518 {
519 if (!isConnected()) return WebSocketSendInfo(false);
520
521 //
522 // It is OK to read and write on the same socket in 2 different threads.
523 // https://stackoverflow.com/questions/1981372/are-parallel-calls-to-send-recv-on-the-same-socket-valid
524 //
525 // This makes it so that messages are sent right away, and we dont need
526 // a timeout while we poll to keep wake ups to a minimum (which helps
527 // with battery life), and use the system select call to notify us when
528 // incoming messages are arriving / there's data to be received.
529 //
530 std::lock_guard<std::mutex> lock(_writeMutex);
531 WebSocketSendInfo webSocketSendInfo;
532
533 switch (sendMessageKind)
534 {
535 case SendMessageKind::Text:
536 {
537 webSocketSendInfo = _ws.sendText(message, onProgressCallback);
538 }
539 break;
540
541 case SendMessageKind::Binary:
542 {
543 webSocketSendInfo = _ws.sendBinary(message, onProgressCallback);
544 }
545 break;
546
547 case SendMessageKind::Ping:
548 {
549 webSocketSendInfo = _ws.sendPing(message);
550 }
551 break;
552 }
553
554 WebSocket::invokeTrafficTrackerCallback(webSocketSendInfo.wireSize, false);
555
556 return webSocketSendInfo;
557 }
558
559 ReadyState WebSocket::getReadyState() const
560 {
561 switch (_ws.getReadyState())
562 {
563 case ix::WebSocketTransport::ReadyState::OPEN: return ReadyState::Open;
564 case ix::WebSocketTransport::ReadyState::CONNECTING: return ReadyState::Connecting;
565 case ix::WebSocketTransport::ReadyState::CLOSING: return ReadyState::Closing;
566 case ix::WebSocketTransport::ReadyState::CLOSED: return ReadyState::Closed;
567 default: return ReadyState::Closed;
568 }
569 }
570
571 std::string WebSocket::readyStateToString(ReadyState readyState)
572 {
573 switch (readyState)
574 {
575 case ReadyState::Open: return "OPEN";
576 case ReadyState::Connecting: return "CONNECTING";
577 case ReadyState::Closing: return "CLOSING";
578 case ReadyState::Closed: return "CLOSED";
579 default: return "UNKNOWN";
580 }
581 }
582
583 void WebSocket::enableAutomaticReconnection()
584 {
585 _automaticReconnection = true;
586 }
587
588 void WebSocket::disableAutomaticReconnection()
589 {
590 _automaticReconnection = false;
591 }
592
593 bool WebSocket::isAutomaticReconnectionEnabled() const
594 {
595 return _automaticReconnection;
596 }
597
598 size_t WebSocket::bufferedAmount() const
599 {
600 return _ws.bufferedAmount();
601 }
602
603 void WebSocket::addSubProtocol(const std::string& subProtocol)
604 {
605 std::lock_guard<std::mutex> lock(_configMutex);
606 _subProtocols.push_back(subProtocol);
607 }
608
609 const std::vector<std::string>& WebSocket::getSubProtocols()
610 {
611 std::lock_guard<std::mutex> lock(_configMutex);
612 return _subProtocols;
613 }
614} // namespace ix
615