1/*
2 * The MIT License (MIT)
3 *
4 * Copyright (c) 2012, 2013 <dhbaird@gmail.com>
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to deal
8 * in the Software without restriction, including without limitation the rights
9 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 * copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 * THE SOFTWARE.
23 */
24
25/*
26 * IXWebSocketTransport.cpp
27 * Author: Benjamin Sergeant
28 * Copyright (c) 2017-2019 Machine Zone, Inc. All rights reserved.
29 */
30
31//
32// Adapted from https://github.com/dhbaird/easywsclient
33//
34
#include "IXWebSocketTransport.h"

#include "IXSocketFactory.h"
#include "IXSocketTLSOptions.h"
#include "IXUniquePtr.h"
#include "IXUrlParser.h"
#include "IXUtf8Validator.h"
#include "IXWebSocketHandshake.h"
#include "IXWebSocketHttpHeaders.h"
#include <chrono>
#include <cstdarg>
#include <cstdlib>
#include <random>
#include <sstream>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <thread>
#include <vector>
53
54
55namespace ix
56{
57 const std::string WebSocketTransport::kPingMessage("ixwebsocket::heartbeat");
58 const int WebSocketTransport::kDefaultPingIntervalSecs(-1);
59 const bool WebSocketTransport::kDefaultEnablePong(true);
60 const int WebSocketTransport::kClosingMaximumWaitingDelayInMs(300);
61 constexpr size_t WebSocketTransport::kChunkSize;
62
63 WebSocketTransport::WebSocketTransport()
64 : _useMask(true)
65 , _blockingSend(false)
66 , _receivedMessageCompressed(false)
67 , _readyState(ReadyState::CLOSED)
68 , _closeCode(WebSocketCloseConstants::kInternalErrorCode)
69 , _closeWireSize(0)
70 , _closeRemote(false)
71 , _enablePerMessageDeflate(false)
72 , _requestInitCancellation(false)
73 , _closingTimePoint(std::chrono::steady_clock::now())
74 , _enablePong(kDefaultEnablePong)
75 , _pingIntervalSecs(kDefaultPingIntervalSecs)
76 , _pongReceived(false)
77 , _pingCount(0)
78 , _lastSendPingTimePoint(std::chrono::steady_clock::now())
79 {
80 setCloseReason(WebSocketCloseConstants::kInternalErrorMessage);
81 _readbuf.resize(kChunkSize);
82 }
83
84 WebSocketTransport::~WebSocketTransport()
85 {
86 ;
87 }
88
89 void WebSocketTransport::configure(
90 const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions,
91 const SocketTLSOptions& socketTLSOptions,
92 bool enablePong,
93 int pingIntervalSecs)
94 {
95 _perMessageDeflateOptions = perMessageDeflateOptions;
96 _enablePerMessageDeflate = _perMessageDeflateOptions.enabled();
97 _socketTLSOptions = socketTLSOptions;
98 _enablePong = enablePong;
99 _pingIntervalSecs = pingIntervalSecs;
100 }
101
102 // Client
103 WebSocketInitResult WebSocketTransport::connectToUrl(const std::string& url,
104 const WebSocketHttpHeaders& headers,
105 int timeoutSecs)
106 {
107 std::lock_guard<std::mutex> lock(_socketMutex);
108
109 std::string protocol, host, path, query;
110 int port;
111 std::string remoteUrl(url);
112
113 WebSocketInitResult result;
114 const int maxRedirections = 10;
115
116 for (int i = 0; i < maxRedirections; ++i)
117 {
118 if (!UrlParser::parse(remoteUrl, protocol, host, path, query, port))
119 {
120 std::stringstream ss;
121 ss << "Could not parse url: '" << url << "'";
122 return WebSocketInitResult(false, 0, ss.str());
123 }
124
125 std::string errorMsg;
126 bool tls = protocol == "wss";
127 _socket = createSocket(tls, -1, errorMsg, _socketTLSOptions);
128 _perMessageDeflate = ix::make_unique<WebSocketPerMessageDeflate>();
129
130 if (!_socket)
131 {
132 return WebSocketInitResult(false, 0, errorMsg);
133 }
134
135 WebSocketHandshake webSocketHandshake(_requestInitCancellation,
136 _socket,
137 _perMessageDeflate,
138 _perMessageDeflateOptions,
139 _enablePerMessageDeflate);
140
141 result = webSocketHandshake.clientHandshake(
142 remoteUrl, headers, host, path, port, timeoutSecs);
143
144 if (result.http_status >= 300 && result.http_status < 400)
145 {
146 auto it = result.headers.find("Location");
147 if (it == result.headers.end())
148 {
149 std::stringstream ss;
150 ss << "Missing Location Header for HTTP Redirect response. "
151 << "Rejecting connection to " << url << ", status: " << result.http_status;
152 result.errorStr = ss.str();
153 break;
154 }
155
156 remoteUrl = it->second;
157 continue;
158 }
159
160 if (result.success)
161 {
162 setReadyState(ReadyState::OPEN);
163 }
164 return result;
165 }
166
167 return result;
168 }
169
170 // Server
171 WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket,
172 int timeoutSecs,
173 bool enablePerMessageDeflate)
174 {
175 std::lock_guard<std::mutex> lock(_socketMutex);
176
177 // Server should not mask the data it sends to the client
178 _useMask = false;
179 _blockingSend = true;
180
181 _socket = std::move(socket);
182 _perMessageDeflate = ix::make_unique<WebSocketPerMessageDeflate>();
183
184 WebSocketHandshake webSocketHandshake(_requestInitCancellation,
185 _socket,
186 _perMessageDeflate,
187 _perMessageDeflateOptions,
188 _enablePerMessageDeflate);
189
190 auto result = webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate);
191 if (result.success)
192 {
193 setReadyState(ReadyState::OPEN);
194 }
195 return result;
196 }
197
198 WebSocketTransport::ReadyState WebSocketTransport::getReadyState() const
199 {
200 return _readyState;
201 }
202
203 void WebSocketTransport::setReadyState(ReadyState readyState)
204 {
205 // No state change, return
206 if (_readyState == readyState) return;
207
208 if (readyState == ReadyState::CLOSED)
209 {
210 if (_onCloseCallback)
211 {
212 _onCloseCallback(_closeCode, getCloseReason(), _closeWireSize, _closeRemote);
213 }
214 setCloseReason(WebSocketCloseConstants::kInternalErrorMessage);
215 _closeCode = WebSocketCloseConstants::kInternalErrorCode;
216 _closeWireSize = 0;
217 _closeRemote = false;
218 }
219 else if (readyState == ReadyState::OPEN)
220 {
221 initTimePointsAfterConnect();
222 _pongReceived = false;
223 }
224
225 _readyState = readyState;
226 }
227
228 void WebSocketTransport::setOnCloseCallback(const OnCloseCallback& onCloseCallback)
229 {
230 _onCloseCallback = onCloseCallback;
231 }
232
233 void WebSocketTransport::initTimePointsAfterConnect()
234 {
235 {
236 std::lock_guard<std::mutex> lock(_lastSendPingTimePointMutex);
237 _lastSendPingTimePoint = std::chrono::steady_clock::now();
238 }
239 }
240
241 // Only consider send PING time points for that computation.
242 bool WebSocketTransport::pingIntervalExceeded()
243 {
244 if (_pingIntervalSecs <= 0) return false;
245
246 std::lock_guard<std::mutex> lock(_lastSendPingTimePointMutex);
247 auto now = std::chrono::steady_clock::now();
248 return now - _lastSendPingTimePoint > std::chrono::seconds(_pingIntervalSecs);
249 }
250
251 WebSocketSendInfo WebSocketTransport::sendHeartBeat()
252 {
253 _pongReceived = false;
254 std::stringstream ss;
255 ss << kPingMessage << "::" << _pingIntervalSecs << "s"
256 << "::" << _pingCount++;
257 return sendPing(ss.str());
258 }
259
260 bool WebSocketTransport::closingDelayExceeded()
261 {
262 std::lock_guard<std::mutex> lock(_closingTimePointMutex);
263 auto now = std::chrono::steady_clock::now();
264 return now - _closingTimePoint > std::chrono::milliseconds(kClosingMaximumWaitingDelayInMs);
265 }
266
267 WebSocketTransport::PollResult WebSocketTransport::poll()
268 {
269 if (_readyState == ReadyState::OPEN)
270 {
271 if (pingIntervalExceeded())
272 {
273 if (!_pongReceived)
274 {
275 // ping response (PONG) exceeds the maximum delay, close the connection
276 close(WebSocketCloseConstants::kInternalErrorCode,
277 WebSocketCloseConstants::kPingTimeoutMessage);
278 }
279 else
280 {
281 sendHeartBeat();
282 }
283 }
284 }
285
286 // No timeout if state is not OPEN, otherwise computed
287 // pingIntervalOrTimeoutGCD (equals to -1 if no ping and no ping timeout are set)
288 int lastingTimeoutDelayInMs = (_readyState != ReadyState::OPEN) ? 0 : _pingIntervalSecs;
289
290 if (_pingIntervalSecs > 0)
291 {
292 // compute lasting delay to wait for next ping / timeout, if at least one set
293 auto now = std::chrono::steady_clock::now();
294 int timeSinceLastPingMs = (int) std::chrono::duration_cast<std::chrono::milliseconds>(
295 now - _lastSendPingTimePoint)
296 .count();
297 lastingTimeoutDelayInMs = (1000 * _pingIntervalSecs) - timeSinceLastPingMs;
298 }
299
300 // The platform may not have select interrupt capabilities, so wait with a small timeout
301 if (lastingTimeoutDelayInMs <= 0 && !_socket->isWakeUpFromPollSupported())
302 {
303 lastingTimeoutDelayInMs = 20;
304 }
305
306 // If we are requesting a cancellation, pass in a positive and small timeout
307 // to never poll forever without a timeout.
308 if (_requestInitCancellation)
309 {
310 lastingTimeoutDelayInMs = 100;
311 }
312
313 // poll the socket
314 PollResultType pollResult = _socket->isReadyToRead(lastingTimeoutDelayInMs);
315
316 // Make sure we send all the buffered data
317 // there can be a lot of it for large messages.
318 if (pollResult == PollResultType::SendRequest)
319 {
320 if (!flushSendBuffer())
321 {
322 return PollResult::CannotFlushSendBuffer;
323 }
324 }
325 else if (pollResult == PollResultType::ReadyForRead)
326 {
327 if (!receiveFromSocket())
328 {
329 return PollResult::AbnormalClose;
330 }
331 }
332 else if (pollResult == PollResultType::Error)
333 {
334 closeSocket();
335 }
336 else if (pollResult == PollResultType::CloseRequest)
337 {
338 closeSocket();
339 }
340
341 if (_readyState == ReadyState::CLOSING && closingDelayExceeded())
342 {
343 _rxbuf.clear();
344 // close code and reason were set when calling close()
345 closeSocket();
346 setReadyState(ReadyState::CLOSED);
347 }
348
349 return PollResult::Succeeded;
350 }
351
352 bool WebSocketTransport::isSendBufferEmpty() const
353 {
354 std::lock_guard<std::mutex> lock(_txbufMutex);
355 return _txbuf.empty();
356 }
357
358 template<class Iterator>
359 void WebSocketTransport::appendToSendBuffer(const std::vector<uint8_t>& header,
360 Iterator begin,
361 Iterator end,
362 uint64_t message_size,
363 uint8_t masking_key[4])
364 {
365 std::lock_guard<std::mutex> lock(_txbufMutex);
366
367 _txbuf.insert(_txbuf.end(), header.begin(), header.end());
368 _txbuf.insert(_txbuf.end(), begin, end);
369
370 if (_useMask)
371 {
372 for (size_t i = 0; i != (size_t) message_size; ++i)
373 {
374 *(_txbuf.end() - (size_t) message_size + i) ^= masking_key[i & 0x3];
375 }
376 }
377 }
378
379 void WebSocketTransport::unmaskReceiveBuffer(const wsheader_type& ws)
380 {
381 if (ws.mask)
382 {
383 for (size_t j = 0; j != ws.N; ++j)
384 {
385 _rxbuf[j + ws.header_size] ^= ws.masking_key[j & 0x3];
386 }
387 }
388 }
389
390 //
391 // http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol
392 //
393 // 0 1 2 3
394 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
395 // +-+-+-+-+-------+-+-------------+-------------------------------+
396 // |F|R|R|R| opcode|M| Payload len | Extended payload length |
397 // |I|S|S|S| (4) |A| (7) | (16/64) |
398 // |N|V|V|V| |S| | (if payload len==126/127) |
399 // | |1|2|3| |K| | |
400 // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
401 // | Extended payload length continued, if payload len == 127 |
402 // + - - - - - - - - - - - - - - - +-------------------------------+
403 // | |Masking-key, if MASK set to 1 |
404 // +-------------------------------+-------------------------------+
405 // | Masking-key (continued) | Payload Data |
406 // +-------------------------------- - - - - - - - - - - - - - - - +
407 // : Payload Data continued ... :
408 // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
409 // | Payload Data continued ... |
410 // +---------------------------------------------------------------+
411 //
412 void WebSocketTransport::dispatch(WebSocketTransport::PollResult pollResult,
413 const OnMessageCallback& onMessageCallback)
414 {
415 while (true)
416 {
417 wsheader_type ws;
418 if (_rxbuf.size() < 2) break; /* Need at least 2 */
419 const uint8_t* data = (uint8_t*) &_rxbuf[0]; // peek, but don't consume
420 ws.fin = (data[0] & 0x80) == 0x80;
421 ws.rsv1 = (data[0] & 0x40) == 0x40;
422 ws.rsv2 = (data[0] & 0x20) == 0x20;
423 ws.rsv3 = (data[0] & 0x10) == 0x10;
424 ws.opcode = (wsheader_type::opcode_type)(data[0] & 0x0f);
425 ws.mask = (data[1] & 0x80) == 0x80;
426 ws.N0 = (data[1] & 0x7f);
427 ws.header_size =
428 2 + (ws.N0 == 126 ? 2 : 0) + (ws.N0 == 127 ? 8 : 0) + (ws.mask ? 4 : 0);
429 if (_rxbuf.size() < ws.header_size) break; /* Need: ws.header_size - _rxbuf.size() */
430
431 if ((ws.rsv1 && !_enablePerMessageDeflate) || ws.rsv2 || ws.rsv3)
432 {
433 close(WebSocketCloseConstants::kProtocolErrorCode,
434 WebSocketCloseConstants::kProtocolErrorReservedBitUsed,
435 _rxbuf.size());
436 return;
437 }
438
439 //
440 // Calculate payload length:
441 // 0-125 mean the payload is that long.
442 // 126 means that the following two bytes indicate the length,
443 // 127 means the next 8 bytes indicate the length.
444 //
445 int i = 0;
446 if (ws.N0 < 126)
447 {
448 ws.N = ws.N0;
449 i = 2;
450 }
451 else if (ws.N0 == 126)
452 {
453 ws.N = 0;
454 ws.N |= ((uint64_t) data[2]) << 8;
455 ws.N |= ((uint64_t) data[3]) << 0;
456 i = 4;
457 }
458 else if (ws.N0 == 127)
459 {
460 ws.N = 0;
461 ws.N |= ((uint64_t) data[2]) << 56;
462 ws.N |= ((uint64_t) data[3]) << 48;
463 ws.N |= ((uint64_t) data[4]) << 40;
464 ws.N |= ((uint64_t) data[5]) << 32;
465 ws.N |= ((uint64_t) data[6]) << 24;
466 ws.N |= ((uint64_t) data[7]) << 16;
467 ws.N |= ((uint64_t) data[8]) << 8;
468 ws.N |= ((uint64_t) data[9]) << 0;
469 i = 10;
470 }
471 else
472 {
473 // invalid payload length according to the spec. bail out
474 return;
475 }
476
477 if (ws.mask)
478 {
479 ws.masking_key[0] = ((uint8_t) data[i + 0]) << 0;
480 ws.masking_key[1] = ((uint8_t) data[i + 1]) << 0;
481 ws.masking_key[2] = ((uint8_t) data[i + 2]) << 0;
482 ws.masking_key[3] = ((uint8_t) data[i + 3]) << 0;
483 }
484 else
485 {
486 ws.masking_key[0] = 0;
487 ws.masking_key[1] = 0;
488 ws.masking_key[2] = 0;
489 ws.masking_key[3] = 0;
490 }
491
492 // Prevent integer overflow in the next conditional
493 const uint64_t maxFrameSize(1ULL << 63);
494 if (ws.N > maxFrameSize)
495 {
496 return;
497 }
498
499 if (_rxbuf.size() < ws.header_size + ws.N)
500 {
501 return; /* Need: ws.header_size+ws.N - _rxbuf.size() */
502 }
503
504 if (!ws.fin && (ws.opcode == wsheader_type::PING || ws.opcode == wsheader_type::PONG ||
505 ws.opcode == wsheader_type::CLOSE))
506 {
507 // Control messages should not be fragmented
508 close(WebSocketCloseConstants::kProtocolErrorCode,
509 WebSocketCloseConstants::kProtocolErrorCodeControlMessageFragmented);
510 return;
511 }
512
513 unmaskReceiveBuffer(ws);
514 std::string frameData(_rxbuf.begin() + ws.header_size,
515 _rxbuf.begin() + ws.header_size + (size_t) ws.N);
516
517 // We got a whole message, now do something with it:
518 if (ws.opcode == wsheader_type::TEXT_FRAME ||
519 ws.opcode == wsheader_type::BINARY_FRAME ||
520 ws.opcode == wsheader_type::CONTINUATION)
521 {
522 if (ws.opcode != wsheader_type::CONTINUATION)
523 {
524 _fragmentedMessageKind = (ws.opcode == wsheader_type::TEXT_FRAME)
525 ? MessageKind::MSG_TEXT
526 : MessageKind::MSG_BINARY;
527
528 _receivedMessageCompressed = _enablePerMessageDeflate && ws.rsv1;
529
530 // Continuation message needs to follow a non-fin TEXT or BINARY message
531 if (!_chunks.empty())
532 {
533 close(WebSocketCloseConstants::kProtocolErrorCode,
534 WebSocketCloseConstants::kProtocolErrorCodeDataOpcodeOutOfSequence);
535 }
536 }
537 else if (_chunks.empty())
538 {
539 // Continuation message need to follow a non-fin TEXT or BINARY message
540 close(
541 WebSocketCloseConstants::kProtocolErrorCode,
542 WebSocketCloseConstants::kProtocolErrorCodeContinuationOpCodeOutOfSequence);
543 }
544
545 //
546 // Usual case. Small unfragmented messages
547 //
548 if (ws.fin && _chunks.empty())
549 {
550 emitMessage(_fragmentedMessageKind,
551 frameData,
552 _receivedMessageCompressed,
553 onMessageCallback);
554
555 _receivedMessageCompressed = false;
556 }
557 else
558 {
559 //
560 // Add intermediary message to our chunk list.
561 // We use a chunk list instead of a big buffer because resizing
562 // large buffer can be very costly when we need to re-allocate
563 // the internal buffer which is slow and can let the internal OS
564 // receive buffer fill out.
565 //
566 _chunks.emplace_back(frameData);
567
568 if (ws.fin)
569 {
570 emitMessage(_fragmentedMessageKind,
571 getMergedChunks(),
572 _receivedMessageCompressed,
573 onMessageCallback);
574
575 _chunks.clear();
576 _receivedMessageCompressed = false;
577 }
578 else
579 {
580 emitMessage(MessageKind::FRAGMENT, std::string(), false, onMessageCallback);
581 }
582 }
583 }
584 else if (ws.opcode == wsheader_type::PING)
585 {
586 // too large
587 if (frameData.size() > 125)
588 {
589 // Unexpected frame type
590 close(WebSocketCloseConstants::kProtocolErrorCode,
591 WebSocketCloseConstants::kProtocolErrorPingPayloadOversized);
592 return;
593 }
594
595 if (_enablePong)
596 {
597 // Reply back right away
598 bool compress = false;
599 sendData(wsheader_type::PONG, frameData, compress);
600 }
601
602 emitMessage(MessageKind::PING, frameData, false, onMessageCallback);
603 }
604 else if (ws.opcode == wsheader_type::PONG)
605 {
606 _pongReceived = true;
607 emitMessage(MessageKind::PONG, frameData, false, onMessageCallback);
608 }
609 else if (ws.opcode == wsheader_type::CLOSE)
610 {
611 std::string reason;
612 uint16_t code = 0;
613
614 if (ws.N >= 2)
615 {
616 // Extract the close code first, available as the first 2 bytes
617 code |= ((uint64_t) _rxbuf[ws.header_size]) << 8;
618 code |= ((uint64_t) _rxbuf[ws.header_size + 1]) << 0;
619
620 // Get the reason.
621 if (ws.N > 2)
622 {
623 reason = frameData.substr(2, frameData.size());
624 }
625
626 // Validate that the reason is proper utf-8. Autobahn 7.5.1
627 if (!validateUtf8(reason))
628 {
629 code = WebSocketCloseConstants::kInvalidFramePayloadData;
630 reason = WebSocketCloseConstants::kInvalidFramePayloadDataMessage;
631 }
632
633 //
634 // Validate close codes. Autobahn 7.9.*
635 // 1014, 1015 are debattable. The firefox MSDN has a description for them.
636 // Full list of status code and status range is defined in the dedicated
637 // RFC section at https://tools.ietf.org/html/rfc6455#page-45
638 //
639 if (code < 1000 || code == 1004 || code == 1006 || (code > 1013 && code < 3000))
640 {
641 // build up an error message containing the bad error code
642 std::stringstream ss;
643 ss << WebSocketCloseConstants::kInvalidCloseCodeMessage << ": " << code;
644 reason = ss.str();
645
646 code = WebSocketCloseConstants::kProtocolErrorCode;
647 }
648 }
649 else
650 {
651 // no close code received
652 code = WebSocketCloseConstants::kNoStatusCodeErrorCode;
653 reason = WebSocketCloseConstants::kNoStatusCodeErrorMessage;
654 }
655
656 // We receive a CLOSE frame from remote and are NOT the ones who triggered the close
657 if (_readyState != ReadyState::CLOSING)
658 {
659 // send back the CLOSE frame
660 sendCloseFrame(code, reason);
661
662 wakeUpFromPoll(SelectInterrupt::kCloseRequest);
663
664 bool remote = true;
665 closeSocketAndSwitchToClosedState(code, reason, _rxbuf.size(), remote);
666 }
667 else
668 {
669 // we got the CLOSE frame answer from our close, so we can close the connection
670 // if the code/reason are the same
671 bool identicalReason = _closeCode == code && getCloseReason() == reason;
672
673 if (identicalReason)
674 {
675 bool remote = false;
676 closeSocketAndSwitchToClosedState(code, reason, _rxbuf.size(), remote);
677 }
678 }
679 }
680 else
681 {
682 // Unexpected frame type
683 close(WebSocketCloseConstants::kProtocolErrorCode,
684 WebSocketCloseConstants::kProtocolErrorMessage,
685 _rxbuf.size());
686 }
687
688 // Erase the message that has been processed from the input/read buffer
689 _rxbuf.erase(_rxbuf.begin(), _rxbuf.begin() + ws.header_size + (size_t) ws.N);
690 }
691
692 // if an abnormal closure was raised in poll, and nothing else triggered a CLOSED state in
693 // the received and processed data then close the connection
694 if (pollResult != PollResult::Succeeded)
695 {
696 _rxbuf.clear();
697
698 // if we previously closed the connection (CLOSING state), then set state to CLOSED
699 // (code/reason were set before)
700 if (_readyState == ReadyState::CLOSING)
701 {
702 closeSocket();
703 setReadyState(ReadyState::CLOSED);
704 }
705 // if we weren't closing, then close using abnormal close code and message
706 else if (_readyState != ReadyState::CLOSED)
707 {
708 closeSocketAndSwitchToClosedState(WebSocketCloseConstants::kAbnormalCloseCode,
709 WebSocketCloseConstants::kAbnormalCloseMessage,
710 0,
711 false);
712 }
713 }
714 }
715
716 std::string WebSocketTransport::getMergedChunks() const
717 {
718 size_t length = 0;
719 for (auto&& chunk : _chunks)
720 {
721 length += chunk.size();
722 }
723
724 std::string msg;
725 msg.reserve(length);
726
727 for (auto&& chunk : _chunks)
728 {
729 msg += chunk;
730 }
731
732 return msg;
733 }
734
735 void WebSocketTransport::emitMessage(MessageKind messageKind,
736 const std::string& message,
737 bool compressedMessage,
738 const OnMessageCallback& onMessageCallback)
739 {
740 size_t wireSize = message.size();
741
742 // When the RSV1 bit is 1 it means the message is compressed
743 if (compressedMessage && messageKind != MessageKind::FRAGMENT)
744 {
745 bool success = _perMessageDeflate->decompress(message, _decompressedMessage);
746
747 if (messageKind == MessageKind::MSG_TEXT && !validateUtf8(_decompressedMessage))
748 {
749 close(WebSocketCloseConstants::kInvalidFramePayloadData,
750 WebSocketCloseConstants::kInvalidFramePayloadDataMessage);
751 }
752 else
753 {
754 onMessageCallback(_decompressedMessage, wireSize, !success, messageKind);
755 }
756 }
757 else
758 {
759 if (messageKind == MessageKind::MSG_TEXT && !validateUtf8(message))
760 {
761 close(WebSocketCloseConstants::kInvalidFramePayloadData,
762 WebSocketCloseConstants::kInvalidFramePayloadDataMessage);
763 }
764 else
765 {
766 onMessageCallback(message, wireSize, false, messageKind);
767 }
768 }
769 }
770
771 unsigned WebSocketTransport::getRandomUnsigned()
772 {
773 auto now = std::chrono::system_clock::now();
774 auto seconds =
775 std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()).count();
776 return static_cast<unsigned>(seconds);
777 }
778
779 WebSocketSendInfo WebSocketTransport::sendData(wsheader_type::opcode_type type,
780 const IXWebSocketSendData& message,
781 bool compress,
782 const OnProgressCallback& onProgressCallback)
783 {
784 if (_readyState != ReadyState::OPEN && _readyState != ReadyState::CLOSING)
785 {
786 return WebSocketSendInfo(false);
787 }
788
789 size_t payloadSize = message.size();
790 size_t wireSize = message.size();
791 bool compressionError = false;
792
793 auto message_begin = message.cbegin();
794 auto message_end = message.cend();
795
796 if (compress)
797 {
798 if (!_perMessageDeflate->compress(message, _compressedMessage))
799 {
800 bool success = false;
801 compressionError = true;
802 payloadSize = 0;
803 wireSize = 0;
804 return WebSocketSendInfo(success, compressionError, payloadSize, wireSize);
805 }
806 compressionError = false;
807 wireSize = _compressedMessage.size();
808
809 IXWebSocketSendData compressedSendData(_compressedMessage);
810 message_begin = compressedSendData.cbegin();
811 message_end = compressedSendData.cend();
812 }
813
814 {
815 std::lock_guard<std::mutex> lock(_txbufMutex);
816 _txbuf.reserve(wireSize);
817 }
818
819 bool success = true;
820
821 // Common case for most message. No fragmentation required.
822 if (wireSize < kChunkSize)
823 {
824 success = sendFragment(type, true, message_begin, message_end, compress);
825
826 if (onProgressCallback)
827 {
828 onProgressCallback(0, 1);
829 }
830 }
831 else
832 {
833 //
834 // Large messages need to be fragmented
835 //
836 // Rules:
837 // First message needs to specify a proper type (BINARY or TEXT)
838 // Intermediary and last messages need to be of type CONTINUATION
839 // Last message must set the fin byte.
840 //
841 auto steps = wireSize / kChunkSize;
842
843 auto begin = message_begin;
844 auto end = message_end;
845
846 for (uint64_t i = 0; i < steps; ++i)
847 {
848 bool firstStep = i == 0;
849 bool lastStep = (i + 1) == steps;
850 bool fin = lastStep;
851
852 end = begin + kChunkSize;
853 if (lastStep)
854 {
855 end = message_end;
856 }
857
858 auto opcodeType = type;
859 if (!firstStep)
860 {
861 opcodeType = wsheader_type::CONTINUATION;
862 }
863
864 // Send message
865 if (!sendFragment(opcodeType, fin, begin, end, compress))
866 {
867 return WebSocketSendInfo(false);
868 }
869
870 if (onProgressCallback && !onProgressCallback((int) i, (int) steps))
871 {
872 break;
873 }
874
875 begin += kChunkSize;
876 }
877 }
878
879 // Request to flush the send buffer on the background thread if it isn't empty
880 if (!isSendBufferEmpty())
881 {
882 wakeUpFromPoll(SelectInterrupt::kSendRequest);
883
884 // FIXME: we should have a timeout when sending large messages: see #131
885 if (_blockingSend && !flushSendBuffer())
886 {
887 success = false;
888 }
889 }
890
891 return WebSocketSendInfo(success, compressionError, payloadSize, wireSize);
892 }
893
894 template<class Iterator>
895 bool WebSocketTransport::sendFragment(wsheader_type::opcode_type type,
896 bool fin,
897 Iterator message_begin,
898 Iterator message_end,
899 bool compress)
900 {
901 uint64_t message_size = static_cast<uint64_t>(message_end - message_begin);
902
903 unsigned x = getRandomUnsigned();
904 uint8_t masking_key[4] = {};
905 masking_key[0] = (x >> 24);
906 masking_key[1] = (x >> 16) & 0xff;
907 masking_key[2] = (x >> 8) & 0xff;
908 masking_key[3] = (x) &0xff;
909
910 std::vector<uint8_t> header;
911 header.assign(2 + (message_size >= 126 ? 2 : 0) + (message_size >= 65536 ? 6 : 0) +
912 (_useMask ? 4 : 0),
913 0);
914 header[0] = type;
915
916 // The fin bit indicate that this is the last fragment. Fin is French for end.
917 if (fin)
918 {
919 header[0] |= 0x80;
920 }
921
922 // The rsv1 bit indicate that the frame is compressed
923 // continuation opcodes should not set it. Autobahn 12.2.10 and others 12.X
924 if (compress && type != wsheader_type::CONTINUATION)
925 {
926 header[0] |= 0x40;
927 }
928
929 if (message_size < 126)
930 {
931 header[1] = (message_size & 0xff) | (_useMask ? 0x80 : 0);
932
933 if (_useMask)
934 {
935 header[2] = masking_key[0];
936 header[3] = masking_key[1];
937 header[4] = masking_key[2];
938 header[5] = masking_key[3];
939 }
940 }
941 else if (message_size < 65536)
942 {
943 header[1] = 126 | (_useMask ? 0x80 : 0);
944 header[2] = (message_size >> 8) & 0xff;
945 header[3] = (message_size >> 0) & 0xff;
946
947 if (_useMask)
948 {
949 header[4] = masking_key[0];
950 header[5] = masking_key[1];
951 header[6] = masking_key[2];
952 header[7] = masking_key[3];
953 }
954 }
955 else
956 { // TODO: run coverage testing here
957 header[1] = 127 | (_useMask ? 0x80 : 0);
958 header[2] = (message_size >> 56) & 0xff;
959 header[3] = (message_size >> 48) & 0xff;
960 header[4] = (message_size >> 40) & 0xff;
961 header[5] = (message_size >> 32) & 0xff;
962 header[6] = (message_size >> 24) & 0xff;
963 header[7] = (message_size >> 16) & 0xff;
964 header[8] = (message_size >> 8) & 0xff;
965 header[9] = (message_size >> 0) & 0xff;
966
967 if (_useMask)
968 {
969 header[10] = masking_key[0];
970 header[11] = masking_key[1];
971 header[12] = masking_key[2];
972 header[13] = masking_key[3];
973 }
974 }
975
976 // _txbuf will keep growing until it can be transmitted over the socket:
977 appendToSendBuffer(header, message_begin, message_end, message_size, masking_key);
978
979 // Now actually send this data
980 return sendOnSocket();
981 }
982
983 WebSocketSendInfo WebSocketTransport::sendPing(const IXWebSocketSendData& message)
984 {
985 bool compress = false;
986 WebSocketSendInfo info = sendData(wsheader_type::PING, message, compress);
987
988 if (info.success)
989 {
990 std::lock_guard<std::mutex> lck(_lastSendPingTimePointMutex);
991 _lastSendPingTimePoint = std::chrono::steady_clock::now();
992 }
993
994 return info;
995 }
996
997 WebSocketSendInfo WebSocketTransport::sendBinary(const IXWebSocketSendData& message,
998 const OnProgressCallback& onProgressCallback)
999
1000 {
1001 return sendData(
1002 wsheader_type::BINARY_FRAME, message, _enablePerMessageDeflate, onProgressCallback);
1003 }
1004
1005 WebSocketSendInfo WebSocketTransport::sendText(const IXWebSocketSendData& message,
1006 const OnProgressCallback& onProgressCallback)
1007
1008 {
1009 return sendData(
1010 wsheader_type::TEXT_FRAME, message, _enablePerMessageDeflate, onProgressCallback);
1011 }
1012
1013 bool WebSocketTransport::sendOnSocket()
1014 {
1015 std::lock_guard<std::mutex> lock(_txbufMutex);
1016
1017 while (_txbuf.size())
1018 {
1019 ssize_t ret = 0;
1020 {
1021 std::lock_guard<std::mutex> lock(_socketMutex);
1022 ret = _socket->send((char*) &_txbuf[0], _txbuf.size());
1023 }
1024
1025 if (ret < 0 && Socket::isWaitNeeded())
1026 {
1027 break;
1028 }
1029 else if (ret <= 0)
1030 {
1031 closeSocket();
1032 setReadyState(ReadyState::CLOSED);
1033 return false;
1034 }
1035 else
1036 {
1037 _txbuf.erase(_txbuf.begin(), _txbuf.begin() + ret);
1038 }
1039 }
1040
1041 return true;
1042 }
1043
1044 bool WebSocketTransport::receiveFromSocket()
1045 {
1046 while (true)
1047 {
1048 ssize_t ret = _socket->recv((char*) &_readbuf[0], _readbuf.size());
1049
1050 if (ret < 0 && Socket::isWaitNeeded())
1051 {
1052 break;
1053 }
1054 else if (ret <= 0)
1055 {
1056 // if there are received data pending to be processed, then delay the abnormal
1057 // closure to after dispatch (other close code/reason could be read from the
1058 // buffer)
1059
1060 closeSocket();
1061 return false;
1062 }
1063 else
1064 {
1065 _rxbuf.insert(_rxbuf.end(), _readbuf.begin(), _readbuf.begin() + ret);
1066 }
1067 }
1068
1069 return true;
1070 }
1071
1072 void WebSocketTransport::sendCloseFrame(uint16_t code, const std::string& reason)
1073 {
1074 bool compress = false;
1075
1076 // if a status is set/was read
1077 if (code != WebSocketCloseConstants::kNoStatusCodeErrorCode)
1078 {
1079 // See list of close events here:
1080 // https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent
1081 std::string closure {(char) (code >> 8), (char) (code & 0xff)};
1082
1083 // copy reason after code
1084 closure.append(reason);
1085
1086 sendData(wsheader_type::CLOSE, closure, compress);
1087 }
1088 else
1089 {
1090 // no close code/reason set
1091 sendData(wsheader_type::CLOSE, std::string(""), compress);
1092 }
1093 }
1094
1095 void WebSocketTransport::closeSocket()
1096 {
1097 std::lock_guard<std::mutex> lock(_socketMutex);
1098 _socket->close();
1099 }
1100
1101 bool WebSocketTransport::wakeUpFromPoll(uint64_t wakeUpCode)
1102 {
1103 std::lock_guard<std::mutex> lock(_socketMutex);
1104 return _socket->wakeUpFromPoll(wakeUpCode);
1105 }
1106
1107 void WebSocketTransport::closeSocketAndSwitchToClosedState(uint16_t code,
1108 const std::string& reason,
1109 size_t closeWireSize,
1110 bool remote)
1111 {
1112 closeSocket();
1113
1114 setCloseReason(reason);
1115 _closeCode = code;
1116 _closeWireSize = closeWireSize;
1117 _closeRemote = remote;
1118
1119 setReadyState(ReadyState::CLOSED);
1120 _requestInitCancellation = false;
1121 }
1122
1123 void WebSocketTransport::close(uint16_t code,
1124 const std::string& reason,
1125 size_t closeWireSize,
1126 bool remote)
1127 {
1128 _requestInitCancellation = true;
1129
1130 if (_readyState == ReadyState::CLOSING || _readyState == ReadyState::CLOSED) return;
1131
1132 if (closeWireSize == 0)
1133 {
1134 closeWireSize = reason.size();
1135 }
1136
1137 setCloseReason(reason);
1138 _closeCode = code;
1139 _closeWireSize = closeWireSize;
1140 _closeRemote = remote;
1141
1142 {
1143 std::lock_guard<std::mutex> lock(_closingTimePointMutex);
1144 _closingTimePoint = std::chrono::steady_clock::now();
1145 }
1146 setReadyState(ReadyState::CLOSING);
1147
1148 sendCloseFrame(code, reason);
1149
1150 // wake up the poll, but do not close yet
1151 wakeUpFromPoll(SelectInterrupt::kSendRequest);
1152 }
1153
1154 size_t WebSocketTransport::bufferedAmount() const
1155 {
1156 std::lock_guard<std::mutex> lock(_txbufMutex);
1157 return _txbuf.size();
1158 }
1159
1160 bool WebSocketTransport::flushSendBuffer()
1161 {
1162 while (!isSendBufferEmpty() && !_requestInitCancellation)
1163 {
1164 // Wait with a 10ms timeout until the socket is ready to write.
1165 // This way we are not busy looping
1166 PollResultType result = _socket->isReadyToWrite(10);
1167
1168 if (result == PollResultType::Error)
1169 {
1170 closeSocket();
1171 setReadyState(ReadyState::CLOSED);
1172 return false;
1173 }
1174 else if (result == PollResultType::ReadyForWrite)
1175 {
1176 if (!sendOnSocket())
1177 {
1178 return false;
1179 }
1180 }
1181 }
1182
1183 return true;
1184 }
1185
1186 void WebSocketTransport::setCloseReason(const std::string& reason)
1187 {
1188 std::lock_guard<std::mutex> lock(_closeReasonMutex);
1189 _closeReason = reason;
1190 }
1191
1192 const std::string& WebSocketTransport::getCloseReason() const
1193 {
1194 std::lock_guard<std::mutex> lock(_closeReasonMutex);
1195 return _closeReason;
1196 }
1197} // namespace ix
1198