1 | /* |
2 | * The MIT License (MIT) |
3 | * |
4 | * Copyright (c) 2012, 2013 <dhbaird@gmail.com> |
5 | * |
6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
7 | * of this software and associated documentation files (the "Software"), to deal |
8 | * in the Software without restriction, including without limitation the rights |
9 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
10 | * copies of the Software, and to permit persons to whom the Software is |
11 | * furnished to do so, subject to the following conditions: |
12 | * |
13 | * The above copyright notice and this permission notice shall be included in |
14 | * all copies or substantial portions of the Software. |
15 | * |
16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
22 | * THE SOFTWARE. |
23 | */ |
24 | |
25 | /* |
26 | * IXWebSocketTransport.cpp |
27 | * Author: Benjamin Sergeant |
28 | * Copyright (c) 2017-2019 Machine Zone, Inc. All rights reserved. |
29 | */ |
30 | |
31 | // |
32 | // Adapted from https://github.com/dhbaird/easywsclient |
33 | // |
34 | |
#include "IXWebSocketTransport.h"

#include "IXSocketFactory.h"
#include "IXSocketTLSOptions.h"
#include "IXUniquePtr.h"
#include "IXUrlParser.h"
#include "IXUtf8Validator.h"
#include "IXWebSocketHandshake.h"
#include "IXWebSocketHttpHeaders.h"
#include <chrono>
#include <cstdarg>
#include <cstdlib>
#include <random>
#include <sstream>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <thread>
#include <vector>
53 | |
54 | |
55 | namespace ix |
56 | { |
57 | const std::string WebSocketTransport::kPingMessage("ixwebsocket::heartbeat" ); |
58 | const int WebSocketTransport::kDefaultPingIntervalSecs(-1); |
59 | const bool WebSocketTransport::kDefaultEnablePong(true); |
60 | const int WebSocketTransport::kClosingMaximumWaitingDelayInMs(300); |
61 | constexpr size_t WebSocketTransport::kChunkSize; |
62 | |
63 | WebSocketTransport::WebSocketTransport() |
64 | : _useMask(true) |
65 | , _blockingSend(false) |
66 | , _receivedMessageCompressed(false) |
67 | , _readyState(ReadyState::CLOSED) |
68 | , _closeCode(WebSocketCloseConstants::kInternalErrorCode) |
69 | , _closeWireSize(0) |
70 | , _closeRemote(false) |
71 | , _enablePerMessageDeflate(false) |
72 | , _requestInitCancellation(false) |
73 | , _closingTimePoint(std::chrono::steady_clock::now()) |
74 | , _enablePong(kDefaultEnablePong) |
75 | , _pingIntervalSecs(kDefaultPingIntervalSecs) |
76 | , _pongReceived(false) |
77 | , _pingCount(0) |
78 | , _lastSendPingTimePoint(std::chrono::steady_clock::now()) |
79 | { |
80 | setCloseReason(WebSocketCloseConstants::kInternalErrorMessage); |
81 | _readbuf.resize(kChunkSize); |
82 | } |
83 | |
84 | WebSocketTransport::~WebSocketTransport() |
85 | { |
86 | ; |
87 | } |
88 | |
89 | void WebSocketTransport::configure( |
90 | const WebSocketPerMessageDeflateOptions& perMessageDeflateOptions, |
91 | const SocketTLSOptions& socketTLSOptions, |
92 | bool enablePong, |
93 | int pingIntervalSecs) |
94 | { |
95 | _perMessageDeflateOptions = perMessageDeflateOptions; |
96 | _enablePerMessageDeflate = _perMessageDeflateOptions.enabled(); |
97 | _socketTLSOptions = socketTLSOptions; |
98 | _enablePong = enablePong; |
99 | _pingIntervalSecs = pingIntervalSecs; |
100 | } |
101 | |
102 | // Client |
103 | WebSocketInitResult WebSocketTransport::connectToUrl(const std::string& url, |
104 | const WebSocketHttpHeaders& , |
105 | int timeoutSecs) |
106 | { |
107 | std::lock_guard<std::mutex> lock(_socketMutex); |
108 | |
109 | std::string protocol, host, path, query; |
110 | int port; |
111 | std::string remoteUrl(url); |
112 | |
113 | WebSocketInitResult result; |
114 | const int maxRedirections = 10; |
115 | |
116 | for (int i = 0; i < maxRedirections; ++i) |
117 | { |
118 | if (!UrlParser::parse(remoteUrl, protocol, host, path, query, port)) |
119 | { |
120 | std::stringstream ss; |
121 | ss << "Could not parse url: '" << url << "'" ; |
122 | return WebSocketInitResult(false, 0, ss.str()); |
123 | } |
124 | |
125 | std::string errorMsg; |
126 | bool tls = protocol == "wss" ; |
127 | _socket = createSocket(tls, -1, errorMsg, _socketTLSOptions); |
128 | _perMessageDeflate = ix::make_unique<WebSocketPerMessageDeflate>(); |
129 | |
130 | if (!_socket) |
131 | { |
132 | return WebSocketInitResult(false, 0, errorMsg); |
133 | } |
134 | |
135 | WebSocketHandshake webSocketHandshake(_requestInitCancellation, |
136 | _socket, |
137 | _perMessageDeflate, |
138 | _perMessageDeflateOptions, |
139 | _enablePerMessageDeflate); |
140 | |
141 | result = webSocketHandshake.clientHandshake( |
142 | remoteUrl, headers, host, path, port, timeoutSecs); |
143 | |
144 | if (result.http_status >= 300 && result.http_status < 400) |
145 | { |
146 | auto it = result.headers.find("Location" ); |
147 | if (it == result.headers.end()) |
148 | { |
149 | std::stringstream ss; |
150 | ss << "Missing Location Header for HTTP Redirect response. " |
151 | << "Rejecting connection to " << url << ", status: " << result.http_status; |
152 | result.errorStr = ss.str(); |
153 | break; |
154 | } |
155 | |
156 | remoteUrl = it->second; |
157 | continue; |
158 | } |
159 | |
160 | if (result.success) |
161 | { |
162 | setReadyState(ReadyState::OPEN); |
163 | } |
164 | return result; |
165 | } |
166 | |
167 | return result; |
168 | } |
169 | |
170 | // Server |
171 | WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr<Socket> socket, |
172 | int timeoutSecs, |
173 | bool enablePerMessageDeflate) |
174 | { |
175 | std::lock_guard<std::mutex> lock(_socketMutex); |
176 | |
177 | // Server should not mask the data it sends to the client |
178 | _useMask = false; |
179 | _blockingSend = true; |
180 | |
181 | _socket = std::move(socket); |
182 | _perMessageDeflate = ix::make_unique<WebSocketPerMessageDeflate>(); |
183 | |
184 | WebSocketHandshake webSocketHandshake(_requestInitCancellation, |
185 | _socket, |
186 | _perMessageDeflate, |
187 | _perMessageDeflateOptions, |
188 | _enablePerMessageDeflate); |
189 | |
190 | auto result = webSocketHandshake.serverHandshake(timeoutSecs, enablePerMessageDeflate); |
191 | if (result.success) |
192 | { |
193 | setReadyState(ReadyState::OPEN); |
194 | } |
195 | return result; |
196 | } |
197 | |
198 | WebSocketTransport::ReadyState WebSocketTransport::getReadyState() const |
199 | { |
200 | return _readyState; |
201 | } |
202 | |
203 | void WebSocketTransport::setReadyState(ReadyState readyState) |
204 | { |
205 | // No state change, return |
206 | if (_readyState == readyState) return; |
207 | |
208 | if (readyState == ReadyState::CLOSED) |
209 | { |
210 | if (_onCloseCallback) |
211 | { |
212 | _onCloseCallback(_closeCode, getCloseReason(), _closeWireSize, _closeRemote); |
213 | } |
214 | setCloseReason(WebSocketCloseConstants::kInternalErrorMessage); |
215 | _closeCode = WebSocketCloseConstants::kInternalErrorCode; |
216 | _closeWireSize = 0; |
217 | _closeRemote = false; |
218 | } |
219 | else if (readyState == ReadyState::OPEN) |
220 | { |
221 | initTimePointsAfterConnect(); |
222 | _pongReceived = false; |
223 | } |
224 | |
225 | _readyState = readyState; |
226 | } |
227 | |
228 | void WebSocketTransport::setOnCloseCallback(const OnCloseCallback& onCloseCallback) |
229 | { |
230 | _onCloseCallback = onCloseCallback; |
231 | } |
232 | |
233 | void WebSocketTransport::initTimePointsAfterConnect() |
234 | { |
235 | { |
236 | std::lock_guard<std::mutex> lock(_lastSendPingTimePointMutex); |
237 | _lastSendPingTimePoint = std::chrono::steady_clock::now(); |
238 | } |
239 | } |
240 | |
241 | // Only consider send PING time points for that computation. |
242 | bool WebSocketTransport::pingIntervalExceeded() |
243 | { |
244 | if (_pingIntervalSecs <= 0) return false; |
245 | |
246 | std::lock_guard<std::mutex> lock(_lastSendPingTimePointMutex); |
247 | auto now = std::chrono::steady_clock::now(); |
248 | return now - _lastSendPingTimePoint > std::chrono::seconds(_pingIntervalSecs); |
249 | } |
250 | |
251 | WebSocketSendInfo WebSocketTransport::sendHeartBeat() |
252 | { |
253 | _pongReceived = false; |
254 | std::stringstream ss; |
255 | ss << kPingMessage << "::" << _pingIntervalSecs << "s" |
256 | << "::" << _pingCount++; |
257 | return sendPing(ss.str()); |
258 | } |
259 | |
260 | bool WebSocketTransport::closingDelayExceeded() |
261 | { |
262 | std::lock_guard<std::mutex> lock(_closingTimePointMutex); |
263 | auto now = std::chrono::steady_clock::now(); |
264 | return now - _closingTimePoint > std::chrono::milliseconds(kClosingMaximumWaitingDelayInMs); |
265 | } |
266 | |
267 | WebSocketTransport::PollResult WebSocketTransport::poll() |
268 | { |
269 | if (_readyState == ReadyState::OPEN) |
270 | { |
271 | if (pingIntervalExceeded()) |
272 | { |
273 | if (!_pongReceived) |
274 | { |
275 | // ping response (PONG) exceeds the maximum delay, close the connection |
276 | close(WebSocketCloseConstants::kInternalErrorCode, |
277 | WebSocketCloseConstants::kPingTimeoutMessage); |
278 | } |
279 | else |
280 | { |
281 | sendHeartBeat(); |
282 | } |
283 | } |
284 | } |
285 | |
286 | // No timeout if state is not OPEN, otherwise computed |
287 | // pingIntervalOrTimeoutGCD (equals to -1 if no ping and no ping timeout are set) |
288 | int lastingTimeoutDelayInMs = (_readyState != ReadyState::OPEN) ? 0 : _pingIntervalSecs; |
289 | |
290 | if (_pingIntervalSecs > 0) |
291 | { |
292 | // compute lasting delay to wait for next ping / timeout, if at least one set |
293 | auto now = std::chrono::steady_clock::now(); |
294 | int timeSinceLastPingMs = (int) std::chrono::duration_cast<std::chrono::milliseconds>( |
295 | now - _lastSendPingTimePoint) |
296 | .count(); |
297 | lastingTimeoutDelayInMs = (1000 * _pingIntervalSecs) - timeSinceLastPingMs; |
298 | } |
299 | |
300 | // The platform may not have select interrupt capabilities, so wait with a small timeout |
301 | if (lastingTimeoutDelayInMs <= 0 && !_socket->isWakeUpFromPollSupported()) |
302 | { |
303 | lastingTimeoutDelayInMs = 20; |
304 | } |
305 | |
306 | // If we are requesting a cancellation, pass in a positive and small timeout |
307 | // to never poll forever without a timeout. |
308 | if (_requestInitCancellation) |
309 | { |
310 | lastingTimeoutDelayInMs = 100; |
311 | } |
312 | |
313 | // poll the socket |
314 | PollResultType pollResult = _socket->isReadyToRead(lastingTimeoutDelayInMs); |
315 | |
316 | // Make sure we send all the buffered data |
317 | // there can be a lot of it for large messages. |
318 | if (pollResult == PollResultType::SendRequest) |
319 | { |
320 | if (!flushSendBuffer()) |
321 | { |
322 | return PollResult::CannotFlushSendBuffer; |
323 | } |
324 | } |
325 | else if (pollResult == PollResultType::ReadyForRead) |
326 | { |
327 | if (!receiveFromSocket()) |
328 | { |
329 | return PollResult::AbnormalClose; |
330 | } |
331 | } |
332 | else if (pollResult == PollResultType::Error) |
333 | { |
334 | closeSocket(); |
335 | } |
336 | else if (pollResult == PollResultType::CloseRequest) |
337 | { |
338 | closeSocket(); |
339 | } |
340 | |
341 | if (_readyState == ReadyState::CLOSING && closingDelayExceeded()) |
342 | { |
343 | _rxbuf.clear(); |
344 | // close code and reason were set when calling close() |
345 | closeSocket(); |
346 | setReadyState(ReadyState::CLOSED); |
347 | } |
348 | |
349 | return PollResult::Succeeded; |
350 | } |
351 | |
352 | bool WebSocketTransport::isSendBufferEmpty() const |
353 | { |
354 | std::lock_guard<std::mutex> lock(_txbufMutex); |
355 | return _txbuf.empty(); |
356 | } |
357 | |
358 | template<class Iterator> |
359 | void WebSocketTransport::appendToSendBuffer(const std::vector<uint8_t>& , |
360 | Iterator begin, |
361 | Iterator end, |
362 | uint64_t message_size, |
363 | uint8_t masking_key[4]) |
364 | { |
365 | std::lock_guard<std::mutex> lock(_txbufMutex); |
366 | |
367 | _txbuf.insert(_txbuf.end(), header.begin(), header.end()); |
368 | _txbuf.insert(_txbuf.end(), begin, end); |
369 | |
370 | if (_useMask) |
371 | { |
372 | for (size_t i = 0; i != (size_t) message_size; ++i) |
373 | { |
374 | *(_txbuf.end() - (size_t) message_size + i) ^= masking_key[i & 0x3]; |
375 | } |
376 | } |
377 | } |
378 | |
379 | void WebSocketTransport::(const wsheader_type& ws) |
380 | { |
381 | if (ws.mask) |
382 | { |
383 | for (size_t j = 0; j != ws.N; ++j) |
384 | { |
385 | _rxbuf[j + ws.header_size] ^= ws.masking_key[j & 0x3]; |
386 | } |
387 | } |
388 | } |
389 | |
390 | // |
391 | // http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol |
392 | // |
393 | // 0 1 2 3 |
394 | // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 |
395 | // +-+-+-+-+-------+-+-------------+-------------------------------+ |
396 | // |F|R|R|R| opcode|M| Payload len | Extended payload length | |
397 | // |I|S|S|S| (4) |A| (7) | (16/64) | |
398 | // |N|V|V|V| |S| | (if payload len==126/127) | |
399 | // | |1|2|3| |K| | | |
400 | // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + |
401 | // | Extended payload length continued, if payload len == 127 | |
402 | // + - - - - - - - - - - - - - - - +-------------------------------+ |
403 | // | |Masking-key, if MASK set to 1 | |
404 | // +-------------------------------+-------------------------------+ |
405 | // | Masking-key (continued) | Payload Data | |
406 | // +-------------------------------- - - - - - - - - - - - - - - - + |
407 | // : Payload Data continued ... : |
408 | // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + |
409 | // | Payload Data continued ... | |
410 | // +---------------------------------------------------------------+ |
411 | // |
412 | void WebSocketTransport::dispatch(WebSocketTransport::PollResult pollResult, |
413 | const OnMessageCallback& onMessageCallback) |
414 | { |
415 | while (true) |
416 | { |
417 | wsheader_type ws; |
418 | if (_rxbuf.size() < 2) break; /* Need at least 2 */ |
419 | const uint8_t* data = (uint8_t*) &_rxbuf[0]; // peek, but don't consume |
420 | ws.fin = (data[0] & 0x80) == 0x80; |
421 | ws.rsv1 = (data[0] & 0x40) == 0x40; |
422 | ws.rsv2 = (data[0] & 0x20) == 0x20; |
423 | ws.rsv3 = (data[0] & 0x10) == 0x10; |
424 | ws.opcode = (wsheader_type::opcode_type)(data[0] & 0x0f); |
425 | ws.mask = (data[1] & 0x80) == 0x80; |
426 | ws.N0 = (data[1] & 0x7f); |
427 | ws.header_size = |
428 | 2 + (ws.N0 == 126 ? 2 : 0) + (ws.N0 == 127 ? 8 : 0) + (ws.mask ? 4 : 0); |
429 | if (_rxbuf.size() < ws.header_size) break; /* Need: ws.header_size - _rxbuf.size() */ |
430 | |
431 | if ((ws.rsv1 && !_enablePerMessageDeflate) || ws.rsv2 || ws.rsv3) |
432 | { |
433 | close(WebSocketCloseConstants::kProtocolErrorCode, |
434 | WebSocketCloseConstants::kProtocolErrorReservedBitUsed, |
435 | _rxbuf.size()); |
436 | return; |
437 | } |
438 | |
439 | // |
440 | // Calculate payload length: |
441 | // 0-125 mean the payload is that long. |
442 | // 126 means that the following two bytes indicate the length, |
443 | // 127 means the next 8 bytes indicate the length. |
444 | // |
445 | int i = 0; |
446 | if (ws.N0 < 126) |
447 | { |
448 | ws.N = ws.N0; |
449 | i = 2; |
450 | } |
451 | else if (ws.N0 == 126) |
452 | { |
453 | ws.N = 0; |
454 | ws.N |= ((uint64_t) data[2]) << 8; |
455 | ws.N |= ((uint64_t) data[3]) << 0; |
456 | i = 4; |
457 | } |
458 | else if (ws.N0 == 127) |
459 | { |
460 | ws.N = 0; |
461 | ws.N |= ((uint64_t) data[2]) << 56; |
462 | ws.N |= ((uint64_t) data[3]) << 48; |
463 | ws.N |= ((uint64_t) data[4]) << 40; |
464 | ws.N |= ((uint64_t) data[5]) << 32; |
465 | ws.N |= ((uint64_t) data[6]) << 24; |
466 | ws.N |= ((uint64_t) data[7]) << 16; |
467 | ws.N |= ((uint64_t) data[8]) << 8; |
468 | ws.N |= ((uint64_t) data[9]) << 0; |
469 | i = 10; |
470 | } |
471 | else |
472 | { |
473 | // invalid payload length according to the spec. bail out |
474 | return; |
475 | } |
476 | |
477 | if (ws.mask) |
478 | { |
479 | ws.masking_key[0] = ((uint8_t) data[i + 0]) << 0; |
480 | ws.masking_key[1] = ((uint8_t) data[i + 1]) << 0; |
481 | ws.masking_key[2] = ((uint8_t) data[i + 2]) << 0; |
482 | ws.masking_key[3] = ((uint8_t) data[i + 3]) << 0; |
483 | } |
484 | else |
485 | { |
486 | ws.masking_key[0] = 0; |
487 | ws.masking_key[1] = 0; |
488 | ws.masking_key[2] = 0; |
489 | ws.masking_key[3] = 0; |
490 | } |
491 | |
492 | // Prevent integer overflow in the next conditional |
493 | const uint64_t maxFrameSize(1ULL << 63); |
494 | if (ws.N > maxFrameSize) |
495 | { |
496 | return; |
497 | } |
498 | |
499 | if (_rxbuf.size() < ws.header_size + ws.N) |
500 | { |
501 | return; /* Need: ws.header_size+ws.N - _rxbuf.size() */ |
502 | } |
503 | |
504 | if (!ws.fin && (ws.opcode == wsheader_type::PING || ws.opcode == wsheader_type::PONG || |
505 | ws.opcode == wsheader_type::CLOSE)) |
506 | { |
507 | // Control messages should not be fragmented |
508 | close(WebSocketCloseConstants::kProtocolErrorCode, |
509 | WebSocketCloseConstants::kProtocolErrorCodeControlMessageFragmented); |
510 | return; |
511 | } |
512 | |
513 | unmaskReceiveBuffer(ws); |
514 | std::string frameData(_rxbuf.begin() + ws.header_size, |
515 | _rxbuf.begin() + ws.header_size + (size_t) ws.N); |
516 | |
517 | // We got a whole message, now do something with it: |
518 | if (ws.opcode == wsheader_type::TEXT_FRAME || |
519 | ws.opcode == wsheader_type::BINARY_FRAME || |
520 | ws.opcode == wsheader_type::CONTINUATION) |
521 | { |
522 | if (ws.opcode != wsheader_type::CONTINUATION) |
523 | { |
524 | _fragmentedMessageKind = (ws.opcode == wsheader_type::TEXT_FRAME) |
525 | ? MessageKind::MSG_TEXT |
526 | : MessageKind::MSG_BINARY; |
527 | |
528 | _receivedMessageCompressed = _enablePerMessageDeflate && ws.rsv1; |
529 | |
530 | // Continuation message needs to follow a non-fin TEXT or BINARY message |
531 | if (!_chunks.empty()) |
532 | { |
533 | close(WebSocketCloseConstants::kProtocolErrorCode, |
534 | WebSocketCloseConstants::kProtocolErrorCodeDataOpcodeOutOfSequence); |
535 | } |
536 | } |
537 | else if (_chunks.empty()) |
538 | { |
539 | // Continuation message need to follow a non-fin TEXT or BINARY message |
540 | close( |
541 | WebSocketCloseConstants::kProtocolErrorCode, |
542 | WebSocketCloseConstants::kProtocolErrorCodeContinuationOpCodeOutOfSequence); |
543 | } |
544 | |
545 | // |
546 | // Usual case. Small unfragmented messages |
547 | // |
548 | if (ws.fin && _chunks.empty()) |
549 | { |
550 | emitMessage(_fragmentedMessageKind, |
551 | frameData, |
552 | _receivedMessageCompressed, |
553 | onMessageCallback); |
554 | |
555 | _receivedMessageCompressed = false; |
556 | } |
557 | else |
558 | { |
559 | // |
560 | // Add intermediary message to our chunk list. |
561 | // We use a chunk list instead of a big buffer because resizing |
562 | // large buffer can be very costly when we need to re-allocate |
563 | // the internal buffer which is slow and can let the internal OS |
564 | // receive buffer fill out. |
565 | // |
566 | _chunks.emplace_back(frameData); |
567 | |
568 | if (ws.fin) |
569 | { |
570 | emitMessage(_fragmentedMessageKind, |
571 | getMergedChunks(), |
572 | _receivedMessageCompressed, |
573 | onMessageCallback); |
574 | |
575 | _chunks.clear(); |
576 | _receivedMessageCompressed = false; |
577 | } |
578 | else |
579 | { |
580 | emitMessage(MessageKind::FRAGMENT, std::string(), false, onMessageCallback); |
581 | } |
582 | } |
583 | } |
584 | else if (ws.opcode == wsheader_type::PING) |
585 | { |
586 | // too large |
587 | if (frameData.size() > 125) |
588 | { |
589 | // Unexpected frame type |
590 | close(WebSocketCloseConstants::kProtocolErrorCode, |
591 | WebSocketCloseConstants::kProtocolErrorPingPayloadOversized); |
592 | return; |
593 | } |
594 | |
595 | if (_enablePong) |
596 | { |
597 | // Reply back right away |
598 | bool compress = false; |
599 | sendData(wsheader_type::PONG, frameData, compress); |
600 | } |
601 | |
602 | emitMessage(MessageKind::PING, frameData, false, onMessageCallback); |
603 | } |
604 | else if (ws.opcode == wsheader_type::PONG) |
605 | { |
606 | _pongReceived = true; |
607 | emitMessage(MessageKind::PONG, frameData, false, onMessageCallback); |
608 | } |
609 | else if (ws.opcode == wsheader_type::CLOSE) |
610 | { |
611 | std::string reason; |
612 | uint16_t code = 0; |
613 | |
614 | if (ws.N >= 2) |
615 | { |
616 | // Extract the close code first, available as the first 2 bytes |
617 | code |= ((uint64_t) _rxbuf[ws.header_size]) << 8; |
618 | code |= ((uint64_t) _rxbuf[ws.header_size + 1]) << 0; |
619 | |
620 | // Get the reason. |
621 | if (ws.N > 2) |
622 | { |
623 | reason = frameData.substr(2, frameData.size()); |
624 | } |
625 | |
626 | // Validate that the reason is proper utf-8. Autobahn 7.5.1 |
627 | if (!validateUtf8(reason)) |
628 | { |
629 | code = WebSocketCloseConstants::kInvalidFramePayloadData; |
630 | reason = WebSocketCloseConstants::kInvalidFramePayloadDataMessage; |
631 | } |
632 | |
633 | // |
634 | // Validate close codes. Autobahn 7.9.* |
635 | // 1014, 1015 are debattable. The firefox MSDN has a description for them. |
636 | // Full list of status code and status range is defined in the dedicated |
637 | // RFC section at https://tools.ietf.org/html/rfc6455#page-45 |
638 | // |
639 | if (code < 1000 || code == 1004 || code == 1006 || (code > 1013 && code < 3000)) |
640 | { |
641 | // build up an error message containing the bad error code |
642 | std::stringstream ss; |
643 | ss << WebSocketCloseConstants::kInvalidCloseCodeMessage << ": " << code; |
644 | reason = ss.str(); |
645 | |
646 | code = WebSocketCloseConstants::kProtocolErrorCode; |
647 | } |
648 | } |
649 | else |
650 | { |
651 | // no close code received |
652 | code = WebSocketCloseConstants::kNoStatusCodeErrorCode; |
653 | reason = WebSocketCloseConstants::kNoStatusCodeErrorMessage; |
654 | } |
655 | |
656 | // We receive a CLOSE frame from remote and are NOT the ones who triggered the close |
657 | if (_readyState != ReadyState::CLOSING) |
658 | { |
659 | // send back the CLOSE frame |
660 | sendCloseFrame(code, reason); |
661 | |
662 | wakeUpFromPoll(SelectInterrupt::kCloseRequest); |
663 | |
664 | bool remote = true; |
665 | closeSocketAndSwitchToClosedState(code, reason, _rxbuf.size(), remote); |
666 | } |
667 | else |
668 | { |
669 | // we got the CLOSE frame answer from our close, so we can close the connection |
670 | // if the code/reason are the same |
671 | bool identicalReason = _closeCode == code && getCloseReason() == reason; |
672 | |
673 | if (identicalReason) |
674 | { |
675 | bool remote = false; |
676 | closeSocketAndSwitchToClosedState(code, reason, _rxbuf.size(), remote); |
677 | } |
678 | } |
679 | } |
680 | else |
681 | { |
682 | // Unexpected frame type |
683 | close(WebSocketCloseConstants::kProtocolErrorCode, |
684 | WebSocketCloseConstants::kProtocolErrorMessage, |
685 | _rxbuf.size()); |
686 | } |
687 | |
688 | // Erase the message that has been processed from the input/read buffer |
689 | _rxbuf.erase(_rxbuf.begin(), _rxbuf.begin() + ws.header_size + (size_t) ws.N); |
690 | } |
691 | |
692 | // if an abnormal closure was raised in poll, and nothing else triggered a CLOSED state in |
693 | // the received and processed data then close the connection |
694 | if (pollResult != PollResult::Succeeded) |
695 | { |
696 | _rxbuf.clear(); |
697 | |
698 | // if we previously closed the connection (CLOSING state), then set state to CLOSED |
699 | // (code/reason were set before) |
700 | if (_readyState == ReadyState::CLOSING) |
701 | { |
702 | closeSocket(); |
703 | setReadyState(ReadyState::CLOSED); |
704 | } |
705 | // if we weren't closing, then close using abnormal close code and message |
706 | else if (_readyState != ReadyState::CLOSED) |
707 | { |
708 | closeSocketAndSwitchToClosedState(WebSocketCloseConstants::kAbnormalCloseCode, |
709 | WebSocketCloseConstants::kAbnormalCloseMessage, |
710 | 0, |
711 | false); |
712 | } |
713 | } |
714 | } |
715 | |
716 | std::string WebSocketTransport::getMergedChunks() const |
717 | { |
718 | size_t length = 0; |
719 | for (auto&& chunk : _chunks) |
720 | { |
721 | length += chunk.size(); |
722 | } |
723 | |
724 | std::string msg; |
725 | msg.reserve(length); |
726 | |
727 | for (auto&& chunk : _chunks) |
728 | { |
729 | msg += chunk; |
730 | } |
731 | |
732 | return msg; |
733 | } |
734 | |
735 | void WebSocketTransport::emitMessage(MessageKind messageKind, |
736 | const std::string& message, |
737 | bool compressedMessage, |
738 | const OnMessageCallback& onMessageCallback) |
739 | { |
740 | size_t wireSize = message.size(); |
741 | |
742 | // When the RSV1 bit is 1 it means the message is compressed |
743 | if (compressedMessage && messageKind != MessageKind::FRAGMENT) |
744 | { |
745 | bool success = _perMessageDeflate->decompress(message, _decompressedMessage); |
746 | |
747 | if (messageKind == MessageKind::MSG_TEXT && !validateUtf8(_decompressedMessage)) |
748 | { |
749 | close(WebSocketCloseConstants::kInvalidFramePayloadData, |
750 | WebSocketCloseConstants::kInvalidFramePayloadDataMessage); |
751 | } |
752 | else |
753 | { |
754 | onMessageCallback(_decompressedMessage, wireSize, !success, messageKind); |
755 | } |
756 | } |
757 | else |
758 | { |
759 | if (messageKind == MessageKind::MSG_TEXT && !validateUtf8(message)) |
760 | { |
761 | close(WebSocketCloseConstants::kInvalidFramePayloadData, |
762 | WebSocketCloseConstants::kInvalidFramePayloadDataMessage); |
763 | } |
764 | else |
765 | { |
766 | onMessageCallback(message, wireSize, false, messageKind); |
767 | } |
768 | } |
769 | } |
770 | |
771 | unsigned WebSocketTransport::getRandomUnsigned() |
772 | { |
773 | auto now = std::chrono::system_clock::now(); |
774 | auto seconds = |
775 | std::chrono::duration_cast<std::chrono::seconds>(now.time_since_epoch()).count(); |
776 | return static_cast<unsigned>(seconds); |
777 | } |
778 | |
779 | WebSocketSendInfo WebSocketTransport::(wsheader_type::opcode_type type, |
780 | const IXWebSocketSendData& message, |
781 | bool compress, |
782 | const OnProgressCallback& onProgressCallback) |
783 | { |
784 | if (_readyState != ReadyState::OPEN && _readyState != ReadyState::CLOSING) |
785 | { |
786 | return WebSocketSendInfo(false); |
787 | } |
788 | |
789 | size_t payloadSize = message.size(); |
790 | size_t wireSize = message.size(); |
791 | bool compressionError = false; |
792 | |
793 | auto message_begin = message.cbegin(); |
794 | auto message_end = message.cend(); |
795 | |
796 | if (compress) |
797 | { |
798 | if (!_perMessageDeflate->compress(message, _compressedMessage)) |
799 | { |
800 | bool success = false; |
801 | compressionError = true; |
802 | payloadSize = 0; |
803 | wireSize = 0; |
804 | return WebSocketSendInfo(success, compressionError, payloadSize, wireSize); |
805 | } |
806 | compressionError = false; |
807 | wireSize = _compressedMessage.size(); |
808 | |
809 | IXWebSocketSendData compressedSendData(_compressedMessage); |
810 | message_begin = compressedSendData.cbegin(); |
811 | message_end = compressedSendData.cend(); |
812 | } |
813 | |
814 | { |
815 | std::lock_guard<std::mutex> lock(_txbufMutex); |
816 | _txbuf.reserve(wireSize); |
817 | } |
818 | |
819 | bool success = true; |
820 | |
821 | // Common case for most message. No fragmentation required. |
822 | if (wireSize < kChunkSize) |
823 | { |
824 | success = sendFragment(type, true, message_begin, message_end, compress); |
825 | |
826 | if (onProgressCallback) |
827 | { |
828 | onProgressCallback(0, 1); |
829 | } |
830 | } |
831 | else |
832 | { |
833 | // |
834 | // Large messages need to be fragmented |
835 | // |
836 | // Rules: |
837 | // First message needs to specify a proper type (BINARY or TEXT) |
838 | // Intermediary and last messages need to be of type CONTINUATION |
839 | // Last message must set the fin byte. |
840 | // |
841 | auto steps = wireSize / kChunkSize; |
842 | |
843 | auto begin = message_begin; |
844 | auto end = message_end; |
845 | |
846 | for (uint64_t i = 0; i < steps; ++i) |
847 | { |
848 | bool firstStep = i == 0; |
849 | bool lastStep = (i + 1) == steps; |
850 | bool fin = lastStep; |
851 | |
852 | end = begin + kChunkSize; |
853 | if (lastStep) |
854 | { |
855 | end = message_end; |
856 | } |
857 | |
858 | auto opcodeType = type; |
859 | if (!firstStep) |
860 | { |
861 | opcodeType = wsheader_type::CONTINUATION; |
862 | } |
863 | |
864 | // Send message |
865 | if (!sendFragment(opcodeType, fin, begin, end, compress)) |
866 | { |
867 | return WebSocketSendInfo(false); |
868 | } |
869 | |
870 | if (onProgressCallback && !onProgressCallback((int) i, (int) steps)) |
871 | { |
872 | break; |
873 | } |
874 | |
875 | begin += kChunkSize; |
876 | } |
877 | } |
878 | |
879 | // Request to flush the send buffer on the background thread if it isn't empty |
880 | if (!isSendBufferEmpty()) |
881 | { |
882 | wakeUpFromPoll(SelectInterrupt::kSendRequest); |
883 | |
884 | // FIXME: we should have a timeout when sending large messages: see #131 |
885 | if (_blockingSend && !flushSendBuffer()) |
886 | { |
887 | success = false; |
888 | } |
889 | } |
890 | |
891 | return WebSocketSendInfo(success, compressionError, payloadSize, wireSize); |
892 | } |
893 | |
894 | template<class Iterator> |
895 | bool WebSocketTransport::(wsheader_type::opcode_type type, |
896 | bool fin, |
897 | Iterator message_begin, |
898 | Iterator message_end, |
899 | bool compress) |
900 | { |
901 | uint64_t message_size = static_cast<uint64_t>(message_end - message_begin); |
902 | |
903 | unsigned x = getRandomUnsigned(); |
904 | uint8_t masking_key[4] = {}; |
905 | masking_key[0] = (x >> 24); |
906 | masking_key[1] = (x >> 16) & 0xff; |
907 | masking_key[2] = (x >> 8) & 0xff; |
908 | masking_key[3] = (x) &0xff; |
909 | |
910 | std::vector<uint8_t> ; |
911 | header.assign(2 + (message_size >= 126 ? 2 : 0) + (message_size >= 65536 ? 6 : 0) + |
912 | (_useMask ? 4 : 0), |
913 | 0); |
914 | header[0] = type; |
915 | |
916 | // The fin bit indicate that this is the last fragment. Fin is French for end. |
917 | if (fin) |
918 | { |
919 | header[0] |= 0x80; |
920 | } |
921 | |
922 | // The rsv1 bit indicate that the frame is compressed |
923 | // continuation opcodes should not set it. Autobahn 12.2.10 and others 12.X |
924 | if (compress && type != wsheader_type::CONTINUATION) |
925 | { |
926 | header[0] |= 0x40; |
927 | } |
928 | |
929 | if (message_size < 126) |
930 | { |
931 | header[1] = (message_size & 0xff) | (_useMask ? 0x80 : 0); |
932 | |
933 | if (_useMask) |
934 | { |
935 | header[2] = masking_key[0]; |
936 | header[3] = masking_key[1]; |
937 | header[4] = masking_key[2]; |
938 | header[5] = masking_key[3]; |
939 | } |
940 | } |
941 | else if (message_size < 65536) |
942 | { |
943 | header[1] = 126 | (_useMask ? 0x80 : 0); |
944 | header[2] = (message_size >> 8) & 0xff; |
945 | header[3] = (message_size >> 0) & 0xff; |
946 | |
947 | if (_useMask) |
948 | { |
949 | header[4] = masking_key[0]; |
950 | header[5] = masking_key[1]; |
951 | header[6] = masking_key[2]; |
952 | header[7] = masking_key[3]; |
953 | } |
954 | } |
955 | else |
956 | { // TODO: run coverage testing here |
957 | header[1] = 127 | (_useMask ? 0x80 : 0); |
958 | header[2] = (message_size >> 56) & 0xff; |
959 | header[3] = (message_size >> 48) & 0xff; |
960 | header[4] = (message_size >> 40) & 0xff; |
961 | header[5] = (message_size >> 32) & 0xff; |
962 | header[6] = (message_size >> 24) & 0xff; |
963 | header[7] = (message_size >> 16) & 0xff; |
964 | header[8] = (message_size >> 8) & 0xff; |
965 | header[9] = (message_size >> 0) & 0xff; |
966 | |
967 | if (_useMask) |
968 | { |
969 | header[10] = masking_key[0]; |
970 | header[11] = masking_key[1]; |
971 | header[12] = masking_key[2]; |
972 | header[13] = masking_key[3]; |
973 | } |
974 | } |
975 | |
976 | // _txbuf will keep growing until it can be transmitted over the socket: |
977 | appendToSendBuffer(header, message_begin, message_end, message_size, masking_key); |
978 | |
979 | // Now actually send this data |
980 | return sendOnSocket(); |
981 | } |
982 | |
983 | WebSocketSendInfo WebSocketTransport::sendPing(const IXWebSocketSendData& message) |
984 | { |
985 | bool compress = false; |
986 | WebSocketSendInfo info = sendData(wsheader_type::PING, message, compress); |
987 | |
988 | if (info.success) |
989 | { |
990 | std::lock_guard<std::mutex> lck(_lastSendPingTimePointMutex); |
991 | _lastSendPingTimePoint = std::chrono::steady_clock::now(); |
992 | } |
993 | |
994 | return info; |
995 | } |
996 | |
997 | WebSocketSendInfo WebSocketTransport::sendBinary(const IXWebSocketSendData& message, |
998 | const OnProgressCallback& onProgressCallback) |
999 | |
1000 | { |
1001 | return sendData( |
1002 | wsheader_type::BINARY_FRAME, message, _enablePerMessageDeflate, onProgressCallback); |
1003 | } |
1004 | |
1005 | WebSocketSendInfo WebSocketTransport::sendText(const IXWebSocketSendData& message, |
1006 | const OnProgressCallback& onProgressCallback) |
1007 | |
1008 | { |
1009 | return sendData( |
1010 | wsheader_type::TEXT_FRAME, message, _enablePerMessageDeflate, onProgressCallback); |
1011 | } |
1012 | |
1013 | bool WebSocketTransport::sendOnSocket() |
1014 | { |
1015 | std::lock_guard<std::mutex> lock(_txbufMutex); |
1016 | |
1017 | while (_txbuf.size()) |
1018 | { |
1019 | ssize_t ret = 0; |
1020 | { |
1021 | std::lock_guard<std::mutex> lock(_socketMutex); |
1022 | ret = _socket->send((char*) &_txbuf[0], _txbuf.size()); |
1023 | } |
1024 | |
1025 | if (ret < 0 && Socket::isWaitNeeded()) |
1026 | { |
1027 | break; |
1028 | } |
1029 | else if (ret <= 0) |
1030 | { |
1031 | closeSocket(); |
1032 | setReadyState(ReadyState::CLOSED); |
1033 | return false; |
1034 | } |
1035 | else |
1036 | { |
1037 | _txbuf.erase(_txbuf.begin(), _txbuf.begin() + ret); |
1038 | } |
1039 | } |
1040 | |
1041 | return true; |
1042 | } |
1043 | |
1044 | bool WebSocketTransport::receiveFromSocket() |
1045 | { |
1046 | while (true) |
1047 | { |
1048 | ssize_t ret = _socket->recv((char*) &_readbuf[0], _readbuf.size()); |
1049 | |
1050 | if (ret < 0 && Socket::isWaitNeeded()) |
1051 | { |
1052 | break; |
1053 | } |
1054 | else if (ret <= 0) |
1055 | { |
1056 | // if there are received data pending to be processed, then delay the abnormal |
1057 | // closure to after dispatch (other close code/reason could be read from the |
1058 | // buffer) |
1059 | |
1060 | closeSocket(); |
1061 | return false; |
1062 | } |
1063 | else |
1064 | { |
1065 | _rxbuf.insert(_rxbuf.end(), _readbuf.begin(), _readbuf.begin() + ret); |
1066 | } |
1067 | } |
1068 | |
1069 | return true; |
1070 | } |
1071 | |
1072 | void WebSocketTransport::sendCloseFrame(uint16_t code, const std::string& reason) |
1073 | { |
1074 | bool compress = false; |
1075 | |
1076 | // if a status is set/was read |
1077 | if (code != WebSocketCloseConstants::kNoStatusCodeErrorCode) |
1078 | { |
1079 | // See list of close events here: |
1080 | // https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent |
1081 | std::string closure {(char) (code >> 8), (char) (code & 0xff)}; |
1082 | |
1083 | // copy reason after code |
1084 | closure.append(reason); |
1085 | |
1086 | sendData(wsheader_type::CLOSE, closure, compress); |
1087 | } |
1088 | else |
1089 | { |
1090 | // no close code/reason set |
1091 | sendData(wsheader_type::CLOSE, std::string("" ), compress); |
1092 | } |
1093 | } |
1094 | |
1095 | void WebSocketTransport::closeSocket() |
1096 | { |
1097 | std::lock_guard<std::mutex> lock(_socketMutex); |
1098 | _socket->close(); |
1099 | } |
1100 | |
1101 | bool WebSocketTransport::wakeUpFromPoll(uint64_t wakeUpCode) |
1102 | { |
1103 | std::lock_guard<std::mutex> lock(_socketMutex); |
1104 | return _socket->wakeUpFromPoll(wakeUpCode); |
1105 | } |
1106 | |
// Closes the socket right away and moves straight to CLOSED.
// No CLOSE frame is sent here, unlike close(), which sends one and enters
// CLOSING.
//
// code          - close status code to record in _closeCode.
// reason        - close reason to record via setCloseReason().
// closeWireSize - stored as-is in _closeWireSize.
// remote        - stored in _closeRemote (presumably true when the peer
//                 initiated the close; confirm against callers).
void WebSocketTransport::closeSocketAndSwitchToClosedState(uint16_t code,
                                                           const std::string& reason,
                                                           size_t closeWireSize,
                                                           bool remote)
{
    closeSocket();

    // Record the close details before publishing the CLOSED state.
    setCloseReason(reason);
    _closeCode = code;
    _closeWireSize = closeWireSize;
    _closeRemote = remote;

    setReadyState(ReadyState::CLOSED);
    // Clear the flag that close() sets. NOTE(review): presumably so a later
    // connection attempt is not cancelled; confirm.
    _requestInitCancellation = false;
}
1122 | |
1123 | void WebSocketTransport::close(uint16_t code, |
1124 | const std::string& reason, |
1125 | size_t closeWireSize, |
1126 | bool remote) |
1127 | { |
1128 | _requestInitCancellation = true; |
1129 | |
1130 | if (_readyState == ReadyState::CLOSING || _readyState == ReadyState::CLOSED) return; |
1131 | |
1132 | if (closeWireSize == 0) |
1133 | { |
1134 | closeWireSize = reason.size(); |
1135 | } |
1136 | |
1137 | setCloseReason(reason); |
1138 | _closeCode = code; |
1139 | _closeWireSize = closeWireSize; |
1140 | _closeRemote = remote; |
1141 | |
1142 | { |
1143 | std::lock_guard<std::mutex> lock(_closingTimePointMutex); |
1144 | _closingTimePoint = std::chrono::steady_clock::now(); |
1145 | } |
1146 | setReadyState(ReadyState::CLOSING); |
1147 | |
1148 | sendCloseFrame(code, reason); |
1149 | |
1150 | // wake up the poll, but do not close yet |
1151 | wakeUpFromPoll(SelectInterrupt::kSendRequest); |
1152 | } |
1153 | |
1154 | size_t WebSocketTransport::bufferedAmount() const |
1155 | { |
1156 | std::lock_guard<std::mutex> lock(_txbufMutex); |
1157 | return _txbuf.size(); |
1158 | } |
1159 | |
1160 | bool WebSocketTransport::flushSendBuffer() |
1161 | { |
1162 | while (!isSendBufferEmpty() && !_requestInitCancellation) |
1163 | { |
1164 | // Wait with a 10ms timeout until the socket is ready to write. |
1165 | // This way we are not busy looping |
1166 | PollResultType result = _socket->isReadyToWrite(10); |
1167 | |
1168 | if (result == PollResultType::Error) |
1169 | { |
1170 | closeSocket(); |
1171 | setReadyState(ReadyState::CLOSED); |
1172 | return false; |
1173 | } |
1174 | else if (result == PollResultType::ReadyForWrite) |
1175 | { |
1176 | if (!sendOnSocket()) |
1177 | { |
1178 | return false; |
1179 | } |
1180 | } |
1181 | } |
1182 | |
1183 | return true; |
1184 | } |
1185 | |
1186 | void WebSocketTransport::setCloseReason(const std::string& reason) |
1187 | { |
1188 | std::lock_guard<std::mutex> lock(_closeReasonMutex); |
1189 | _closeReason = reason; |
1190 | } |
1191 | |
1192 | const std::string& WebSocketTransport::getCloseReason() const |
1193 | { |
1194 | std::lock_guard<std::mutex> lock(_closeReasonMutex); |
1195 | return _closeReason; |
1196 | } |
1197 | } // namespace ix |
1198 | |