1/*
2Copyright (c) 2007-2019 Contributors as noted in the AUTHORS file
3
4This file is part of libzmq, the ZeroMQ core engine in C++.
5
6libzmq is free software; you can redistribute it and/or modify it under
7the terms of the GNU Lesser General Public License (LGPL) as published
8by the Free Software Foundation; either version 3 of the License, or
9(at your option) any later version.
10
11As a special exception, the Contributors give you permission to link
12this library with independent modules to produce an executable,
13regardless of the license terms of these independent modules, and to
14copy and distribute the resulting executable under terms of your choice,
15provided that you also meet, for each linked independent module, the
16terms and conditions of the license of that module. An independent
17module is a module which is not derived from or based on this library.
18If you modify this library, you must extend this exception to your
19version of the library.
20
21libzmq is distributed in the hope that it will be useful, but WITHOUT
22ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
24License for more details.
25
26You should have received a copy of the GNU Lesser General Public License
27along with this program. If not, see <http://www.gnu.org/licenses/>.
28*/
29
30#include "precompiled.hpp"
31
32#ifdef ZMQ_USE_NSS
33#include <secoid.h>
34#include <sechash.h>
35#define SHA_DIGEST_LENGTH 20
36#elif defined ZMQ_USE_BUILTIN_SHA1
37#include "../external/sha1/sha1.h"
38#elif defined ZMQ_USE_GNUTLS
39#define SHA_DIGEST_LENGTH 20
40#include <gnutls/gnutls.h>
41#include <gnutls/crypto.h>
42#endif
43
44#if !defined ZMQ_HAVE_WINDOWS
45#include <sys/types.h>
46#include <unistd.h>
47#include <sys/socket.h>
48#include <netinet/in.h>
49#include <arpa/inet.h>
50#ifdef ZMQ_HAVE_VXWORKS
51#include <sockLib.h>
52#endif
53#endif
54
55#include "tcp.hpp"
56#include "ws_engine.hpp"
57#include "session_base.hpp"
58#include "err.hpp"
59#include "ip.hpp"
60#include "random.hpp"
61#include "ws_decoder.hpp"
62#include "ws_encoder.hpp"
63#include "null_mechanism.hpp"
64#include "plain_server.hpp"
65#include "plain_client.hpp"
66
67#ifdef ZMQ_HAVE_CURVE
68#include "curve_client.hpp"
69#include "curve_server.hpp"
70#endif
71
72#ifdef ZMQ_HAVE_WINDOWS
73#define strcasecmp _stricmp
74#endif
75
76// OSX uses a different name for this socket option
77#ifndef IPV6_ADD_MEMBERSHIP
78#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP
79#endif
80
81#ifdef __APPLE__
82#include <TargetConditionals.h>
83#endif
84
//  Base64-encodes in_len_ bytes starting at in_ into out_, which has room
//  for out_len_ chars including the terminating NUL. Returns the number of
//  characters written (terminator excluded) or -1 when out_ is too small.
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_);

//  Hashes key_ followed by the WebSocket handshake GUID with SHA-1 and
//  stores the digest in hash_, which must hold SHA_DIGEST_LENGTH bytes.
static void compute_accept_key (char *key_, unsigned char *hash_);
90
91zmq::ws_engine_t::ws_engine_t (fd_t fd_,
92 const options_t &options_,
93 const endpoint_uri_pair_t &endpoint_uri_pair_,
94 ws_address_t &address_,
95 bool client_) :
96 stream_engine_base_t (fd_, options_, endpoint_uri_pair_),
97 _client (client_),
98 _address (address_),
99 _client_handshake_state (client_handshake_initial),
100 _server_handshake_state (handshake_initial),
101 _header_name_position (0),
102 _header_value_position (0),
103 _header_upgrade_websocket (false),
104 _header_connection_upgrade (false)
105{
106 memset (_websocket_key, 0, MAX_HEADER_VALUE_LENGTH + 1);
107 memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
108 memset (_websocket_protocol, 0, MAX_HEADER_VALUE_LENGTH + 1);
109
110 _next_msg = &ws_engine_t::next_handshake_command;
111 _process_msg = &ws_engine_t::process_handshake_command;
112}
113
114zmq::ws_engine_t::~ws_engine_t ()
115{
116}
117
118void zmq::ws_engine_t::start_ws_handshake ()
119{
120 if (_client) {
121 char protocol[21];
122 if (_options.mechanism == ZMQ_NULL)
123 strcpy (protocol, "ZWS2.0/NULL,ZWS2.0");
124 else if (_options.mechanism == ZMQ_PLAIN)
125 strcpy (protocol, "ZWS2.0/PLAIN");
126#ifdef ZMQ_HAVE_CURVE
127 else if (_options.mechanism == ZMQ_CURVE)
128 strcpy (protocol, "ZWS2.0/CURVE");
129#endif
130 else
131 assert (false);
132
133 unsigned char nonce[16];
134 int *p = reinterpret_cast<int *> (nonce);
135
136 // The nonce doesn't have to be secure one, it is just use to avoid proxy cache
137 *p = zmq::generate_random ();
138 *(p + 1) = zmq::generate_random ();
139 *(p + 2) = zmq::generate_random ();
140 *(p + 3) = zmq::generate_random ();
141
142 int size =
143 encode_base64 (nonce, 16, _websocket_key, MAX_HEADER_VALUE_LENGTH);
144 assert (size > 0);
145
146 size = snprintf (
147 reinterpret_cast<char *> (_write_buffer), WS_BUFFER_SIZE,
148 "GET %s HTTP/1.1\r\n"
149 "Host: %s\r\n"
150 "Upgrade: websocket\r\n"
151 "Connection: Upgrade\r\n"
152 "Sec-WebSocket-Key: %s\r\n"
153 "Sec-WebSocket-Protocol: %s\r\n"
154 "Sec-WebSocket-Version: 13\r\n\r\n",
155 _address.path (), _address.host (), _websocket_key, protocol);
156 assert (size > 0 && size < WS_BUFFER_SIZE);
157 _outpos = _write_buffer;
158 _outsize = size;
159 set_pollout ();
160 }
161}
162
163void zmq::ws_engine_t::plug_internal ()
164{
165 start_ws_handshake ();
166 set_pollin ();
167 in_event ();
168}
169
170int zmq::ws_engine_t::routing_id_msg (msg_t *msg_)
171{
172 int rc = msg_->init_size (_options.routing_id_size);
173 errno_assert (rc == 0);
174 if (_options.routing_id_size > 0)
175 memcpy (msg_->data (), _options.routing_id, _options.routing_id_size);
176 _next_msg = &ws_engine_t::pull_msg_from_session;
177
178 return 0;
179}
180
181int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
182{
183 if (_options.recv_routing_id) {
184 msg_->set_flags (msg_t::routing_id);
185 int rc = session ()->push_msg (msg_);
186 errno_assert (rc == 0);
187 } else {
188 int rc = msg_->close ();
189 errno_assert (rc == 0);
190 rc = msg_->init ();
191 errno_assert (rc == 0);
192 }
193
194 _process_msg = &ws_engine_t::push_msg_to_session;
195
196 return 0;
197}
198
199bool zmq::ws_engine_t::select_protocol (char *protocol_)
200{
201 if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol_) == 0)) {
202 _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
203 &ws_engine_t::routing_id_msg);
204 _process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
205 &ws_engine_t::process_routing_id_msg);
206 return true;
207 }
208 if (_options.mechanism == ZMQ_NULL
209 && strcmp ("ZWS2.0/NULL", protocol_) == 0) {
210 _mechanism = new (std::nothrow)
211 null_mechanism_t (session (), _peer_address, _options);
212 alloc_assert (_mechanism);
213 return true;
214 } else if (_options.mechanism == ZMQ_PLAIN
215 && strcmp ("ZWS2.0/PLAIN", protocol_) == 0) {
216 if (_options.as_server)
217 _mechanism = new (std::nothrow)
218 plain_server_t (session (), _peer_address, _options);
219 else
220 _mechanism =
221 new (std::nothrow) plain_client_t (session (), _options);
222 alloc_assert (_mechanism);
223 return true;
224 }
225#ifdef ZMQ_HAVE_CURVE
226 else if (_options.mechanism == ZMQ_CURVE
227 && strcmp ("ZWS2.0/CURVE", protocol_) == 0) {
228 if (_options.as_server)
229 _mechanism = new (std::nothrow)
230 curve_server_t (session (), _peer_address, _options);
231 else
232 _mechanism =
233 new (std::nothrow) curve_client_t (session (), _options);
234 alloc_assert (_mechanism);
235 return true;
236 }
237#endif
238
239 return false;
240}
241
242bool zmq::ws_engine_t::handshake ()
243{
244 bool complete;
245
246 if (_client)
247 complete = client_handshake ();
248 else
249 complete = server_handshake ();
250
251 if (complete) {
252 _encoder =
253 new (std::nothrow) ws_encoder_t (_options.out_batch_size, _client);
254 alloc_assert (_encoder);
255
256 _decoder = new (std::nothrow)
257 ws_decoder_t (_options.in_batch_size, _options.maxmsgsize,
258 _options.zero_copy, !_client);
259 alloc_assert (_decoder);
260
261 socket ()->event_handshake_succeeded (_endpoint_uri_pair, 0);
262
263 set_pollout ();
264 }
265
266 return complete;
267}
268
269bool zmq::ws_engine_t::server_handshake ()
270{
271 int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
272 if (nbytes == -1) {
273 if (errno != EAGAIN)
274 error (zmq::i_engine::connection_error);
275 return false;
276 }
277
278 _inpos = _read_buffer;
279 _insize = nbytes;
280
281 while (_insize > 0) {
282 char c = static_cast<char> (*_inpos);
283
284 switch (_server_handshake_state) {
285 case handshake_initial:
286 if (c == 'G')
287 _server_handshake_state = request_line_G;
288 else
289 _server_handshake_state = handshake_error;
290 break;
291 case request_line_G:
292 if (c == 'E')
293 _server_handshake_state = request_line_GE;
294 else
295 _server_handshake_state = handshake_error;
296 break;
297 case request_line_GE:
298 if (c == 'T')
299 _server_handshake_state = request_line_GET;
300 else
301 _server_handshake_state = handshake_error;
302 break;
303 case request_line_GET:
304 if (c == ' ')
305 _server_handshake_state = request_line_GET_space;
306 else
307 _server_handshake_state = handshake_error;
308 break;
309 case request_line_GET_space:
310 if (c == '\r' || c == '\n')
311 _server_handshake_state = handshake_error;
312 // TODO: instead of check what is not allowed check what is allowed
313 if (c != ' ')
314 _server_handshake_state = request_line_resource;
315 else
316 _server_handshake_state = request_line_GET_space;
317 break;
318 case request_line_resource:
319 if (c == '\r' || c == '\n')
320 _server_handshake_state = handshake_error;
321 else if (c == ' ')
322 _server_handshake_state = request_line_resource_space;
323 else
324 _server_handshake_state = request_line_resource;
325 break;
326 case request_line_resource_space:
327 if (c == 'H')
328 _server_handshake_state = request_line_H;
329 else
330 _server_handshake_state = handshake_error;
331 break;
332 case request_line_H:
333 if (c == 'T')
334 _server_handshake_state = request_line_HT;
335 else
336 _server_handshake_state = handshake_error;
337 break;
338 case request_line_HT:
339 if (c == 'T')
340 _server_handshake_state = request_line_HTT;
341 else
342 _server_handshake_state = handshake_error;
343 break;
344 case request_line_HTT:
345 if (c == 'P')
346 _server_handshake_state = request_line_HTTP;
347 else
348 _server_handshake_state = handshake_error;
349 break;
350 case request_line_HTTP:
351 if (c == '/')
352 _server_handshake_state = request_line_HTTP_slash;
353 else
354 _server_handshake_state = handshake_error;
355 break;
356 case request_line_HTTP_slash:
357 if (c == '1')
358 _server_handshake_state = request_line_HTTP_slash_1;
359 else
360 _server_handshake_state = handshake_error;
361 break;
362 case request_line_HTTP_slash_1:
363 if (c == '.')
364 _server_handshake_state = request_line_HTTP_slash_1_dot;
365 else
366 _server_handshake_state = handshake_error;
367 break;
368 case request_line_HTTP_slash_1_dot:
369 if (c == '1')
370 _server_handshake_state = request_line_HTTP_slash_1_dot_1;
371 else
372 _server_handshake_state = handshake_error;
373 break;
374 case request_line_HTTP_slash_1_dot_1:
375 if (c == '\r')
376 _server_handshake_state = request_line_cr;
377 else
378 _server_handshake_state = handshake_error;
379 break;
380 case request_line_cr:
381 if (c == '\n')
382 _server_handshake_state = header_field_begin_name;
383 else
384 _server_handshake_state = handshake_error;
385 break;
386 case header_field_begin_name:
387 switch (c) {
388 case '\r':
389 _server_handshake_state = handshake_end_line_cr;
390 break;
391 case '\n':
392 _server_handshake_state = handshake_error;
393 break;
394 default:
395 _header_name[0] = c;
396 _header_name_position = 1;
397 _server_handshake_state = header_field_name;
398 break;
399 }
400 break;
401 case header_field_name:
402 if (c == '\r' || c == '\n')
403 _server_handshake_state = handshake_error;
404 else if (c == ':') {
405 _header_name[_header_name_position] = '\0';
406 _server_handshake_state = header_field_colon;
407 } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
408 _server_handshake_state = handshake_error;
409 else {
410 _header_name[_header_name_position] = c;
411 _header_name_position++;
412 _server_handshake_state = header_field_name;
413 }
414 break;
415 case header_field_colon:
416 case header_field_value_trailing_space:
417 if (c == '\n')
418 _server_handshake_state = handshake_error;
419 else if (c == '\r')
420 _server_handshake_state = header_field_cr;
421 else if (c == ' ')
422 _server_handshake_state = header_field_value_trailing_space;
423 else {
424 _header_value[0] = c;
425 _header_value_position = 1;
426 _server_handshake_state = header_field_value;
427 }
428 break;
429 case header_field_value:
430 if (c == '\n')
431 _server_handshake_state = handshake_error;
432 else if (c == '\r') {
433 _header_value[_header_value_position] = '\0';
434
435 if (strcasecmp ("upgrade", _header_name) == 0)
436 _header_upgrade_websocket =
437 strcasecmp ("websocket", _header_value) == 0;
438 else if (strcasecmp ("connection", _header_name) == 0)
439 _header_connection_upgrade =
440 strcasecmp ("upgrade", _header_value) == 0;
441 else if (strcasecmp ("Sec-WebSocket-Key", _header_name)
442 == 0)
443 strcpy (_websocket_key, _header_value);
444 else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
445 == 0) {
446 // Currently only the ZWS2.0 is supported
447 // Sec-WebSocket-Protocol can appear multiple times or be a comma separated list
448 // if _websocket_protocol is already set we skip the check
449 if (_websocket_protocol[0] == '\0') {
450 char *p = strtok (_header_value, ",");
451 while (p != NULL) {
452 if (*p == ' ')
453 p++;
454
455 if (select_protocol (p)) {
456 strcpy (_websocket_protocol, p);
457 break;
458 }
459
460 p = strtok (NULL, ",");
461 }
462 }
463 }
464
465 _server_handshake_state = header_field_cr;
466 } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
467 _server_handshake_state = handshake_error;
468 else {
469 _header_value[_header_value_position] = c;
470 _header_value_position++;
471 _server_handshake_state = header_field_value;
472 }
473 break;
474 case header_field_cr:
475 if (c == '\n')
476 _server_handshake_state = header_field_begin_name;
477 else
478 _server_handshake_state = handshake_error;
479 break;
480 case handshake_end_line_cr:
481 if (c == '\n') {
482 if (_header_connection_upgrade && _header_upgrade_websocket
483 && _websocket_protocol[0] != '\0'
484 && _websocket_key[0] != '\0') {
485 _server_handshake_state = handshake_complete;
486
487 unsigned char hash[SHA_DIGEST_LENGTH];
488 compute_accept_key (_websocket_key, hash);
489
490 int accept_key_len = encode_base64 (
491 hash, SHA_DIGEST_LENGTH, _websocket_accept,
492 MAX_HEADER_VALUE_LENGTH);
493 assert (accept_key_len > 0);
494 _websocket_accept[accept_key_len] = '\0';
495
496 int written =
497 snprintf (reinterpret_cast<char *> (_write_buffer),
498 WS_BUFFER_SIZE,
499 "HTTP/1.1 101 Switching Protocols\r\n"
500 "Upgrade: websocket\r\n"
501 "Connection: Upgrade\r\n"
502 "Sec-WebSocket-Accept: %s\r\n"
503 "Sec-WebSocket-Protocol: %s\r\n"
504 "\r\n",
505 _websocket_accept, _websocket_protocol);
506 assert (written >= 0 && written < WS_BUFFER_SIZE);
507 _outpos = _write_buffer;
508 _outsize = written;
509
510 _inpos++;
511 _insize--;
512
513 return true;
514 }
515 _server_handshake_state = handshake_error;
516 } else
517 _server_handshake_state = handshake_error;
518 break;
519 default:
520 assert (false);
521 }
522
523 _inpos++;
524 _insize--;
525
526 if (_server_handshake_state == handshake_error) {
527 // TODO: send bad request
528
529 socket ()->event_handshake_failed_protocol (
530 _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
531
532 error (zmq::i_engine::protocol_error);
533 return false;
534 }
535 }
536 return false;
537}
538
539bool zmq::ws_engine_t::client_handshake ()
540{
541 int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
542 if (nbytes == -1) {
543 if (errno != EAGAIN)
544 error (zmq::i_engine::connection_error);
545 return false;
546 }
547
548 _inpos = _read_buffer;
549 _insize = nbytes;
550
551 while (_insize > 0) {
552 char c = static_cast<char> (*_inpos);
553
554 switch (_client_handshake_state) {
555 case client_handshake_initial:
556 if (c == 'H')
557 _client_handshake_state = response_line_H;
558 else
559 _client_handshake_state = client_handshake_error;
560 break;
561 case response_line_H:
562 if (c == 'T')
563 _client_handshake_state = response_line_HT;
564 else
565 _client_handshake_state = client_handshake_error;
566 break;
567 case response_line_HT:
568 if (c == 'T')
569 _client_handshake_state = response_line_HTT;
570 else
571 _client_handshake_state = client_handshake_error;
572 break;
573 case response_line_HTT:
574 if (c == 'P')
575 _client_handshake_state = response_line_HTTP;
576 else
577 _client_handshake_state = client_handshake_error;
578 break;
579 case response_line_HTTP:
580 if (c == '/')
581 _client_handshake_state = response_line_HTTP_slash;
582 else
583 _client_handshake_state = client_handshake_error;
584 break;
585 case response_line_HTTP_slash:
586 if (c == '1')
587 _client_handshake_state = response_line_HTTP_slash_1;
588 else
589 _client_handshake_state = client_handshake_error;
590 break;
591 case response_line_HTTP_slash_1:
592 if (c == '.')
593 _client_handshake_state = response_line_HTTP_slash_1_dot;
594 else
595 _client_handshake_state = client_handshake_error;
596 break;
597 case response_line_HTTP_slash_1_dot:
598 if (c == '1')
599 _client_handshake_state = response_line_HTTP_slash_1_dot_1;
600 else
601 _client_handshake_state = client_handshake_error;
602 break;
603 case response_line_HTTP_slash_1_dot_1:
604 if (c == ' ')
605 _client_handshake_state =
606 response_line_HTTP_slash_1_dot_1_space;
607 else
608 _client_handshake_state = client_handshake_error;
609 break;
610 case response_line_HTTP_slash_1_dot_1_space:
611 if (c == ' ')
612 _client_handshake_state =
613 response_line_HTTP_slash_1_dot_1_space;
614 else if (c == '1')
615 _client_handshake_state = response_line_status_1;
616 else
617 _client_handshake_state = client_handshake_error;
618 break;
619 case response_line_status_1:
620 if (c == '0')
621 _client_handshake_state = response_line_status_10;
622 else
623 _client_handshake_state = client_handshake_error;
624 break;
625 case response_line_status_10:
626 if (c == '1')
627 _client_handshake_state = response_line_status_101;
628 else
629 _client_handshake_state = client_handshake_error;
630 break;
631 case response_line_status_101:
632 if (c == ' ')
633 _client_handshake_state = response_line_status_101_space;
634 else
635 _client_handshake_state = client_handshake_error;
636 break;
637 case response_line_status_101_space:
638 if (c == ' ')
639 _client_handshake_state = response_line_status_101_space;
640 else if (c == 'S')
641 _client_handshake_state = response_line_s;
642 else
643 _client_handshake_state = client_handshake_error;
644 break;
645 case response_line_s:
646 if (c == 'w')
647 _client_handshake_state = response_line_sw;
648 else
649 _client_handshake_state = client_handshake_error;
650 break;
651 case response_line_sw:
652 if (c == 'i')
653 _client_handshake_state = response_line_swi;
654 else
655 _client_handshake_state = client_handshake_error;
656 break;
657 case response_line_swi:
658 if (c == 't')
659 _client_handshake_state = response_line_swit;
660 else
661 _client_handshake_state = client_handshake_error;
662 break;
663 case response_line_swit:
664 if (c == 'c')
665 _client_handshake_state = response_line_switc;
666 else
667 _client_handshake_state = client_handshake_error;
668 break;
669 case response_line_switc:
670 if (c == 'h')
671 _client_handshake_state = response_line_switch;
672 else
673 _client_handshake_state = client_handshake_error;
674 break;
675 case response_line_switch:
676 if (c == 'i')
677 _client_handshake_state = response_line_switchi;
678 else
679 _client_handshake_state = client_handshake_error;
680 break;
681 case response_line_switchi:
682 if (c == 'n')
683 _client_handshake_state = response_line_switchin;
684 else
685 _client_handshake_state = client_handshake_error;
686 break;
687 case response_line_switchin:
688 if (c == 'g')
689 _client_handshake_state = response_line_switching;
690 else
691 _client_handshake_state = client_handshake_error;
692 break;
693 case response_line_switching:
694 if (c == ' ')
695 _client_handshake_state = response_line_switching_space;
696 else
697 _client_handshake_state = client_handshake_error;
698 break;
699 case response_line_switching_space:
700 if (c == 'P')
701 _client_handshake_state = response_line_p;
702 else
703 _client_handshake_state = client_handshake_error;
704 break;
705 case response_line_p:
706 if (c == 'r')
707 _client_handshake_state = response_line_pr;
708 else
709 _client_handshake_state = client_handshake_error;
710 break;
711 case response_line_pr:
712 if (c == 'o')
713 _client_handshake_state = response_line_pro;
714 else
715 _client_handshake_state = client_handshake_error;
716 break;
717 case response_line_pro:
718 if (c == 't')
719 _client_handshake_state = response_line_prot;
720 else
721 _client_handshake_state = client_handshake_error;
722 break;
723 case response_line_prot:
724 if (c == 'o')
725 _client_handshake_state = response_line_proto;
726 else
727 _client_handshake_state = client_handshake_error;
728 break;
729 case response_line_proto:
730 if (c == 'c')
731 _client_handshake_state = response_line_protoc;
732 else
733 _client_handshake_state = client_handshake_error;
734 break;
735 case response_line_protoc:
736 if (c == 'o')
737 _client_handshake_state = response_line_protoco;
738 else
739 _client_handshake_state = client_handshake_error;
740 break;
741 case response_line_protoco:
742 if (c == 'l')
743 _client_handshake_state = response_line_protocol;
744 else
745 _client_handshake_state = client_handshake_error;
746 break;
747 case response_line_protocol:
748 if (c == 's')
749 _client_handshake_state = response_line_protocols;
750 else
751 _client_handshake_state = client_handshake_error;
752 break;
753 case response_line_protocols:
754 if (c == '\r')
755 _client_handshake_state = response_line_cr;
756 else
757 _client_handshake_state = client_handshake_error;
758 break;
759 case response_line_cr:
760 if (c == '\n')
761 _client_handshake_state = client_header_field_begin_name;
762 else
763 _client_handshake_state = client_handshake_error;
764 break;
765 case client_header_field_begin_name:
766 switch (c) {
767 case '\r':
768 _client_handshake_state = client_handshake_end_line_cr;
769 break;
770 case '\n':
771 _client_handshake_state = client_handshake_error;
772 break;
773 default:
774 _header_name[0] = c;
775 _header_name_position = 1;
776 _client_handshake_state = client_header_field_name;
777 break;
778 }
779 break;
780 case client_header_field_name:
781 if (c == '\r' || c == '\n')
782 _client_handshake_state = client_handshake_error;
783 else if (c == ':') {
784 _header_name[_header_name_position] = '\0';
785 _client_handshake_state = client_header_field_colon;
786 } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
787 _client_handshake_state = client_handshake_error;
788 else {
789 _header_name[_header_name_position] = c;
790 _header_name_position++;
791 _client_handshake_state = client_header_field_name;
792 }
793 break;
794 case client_header_field_colon:
795 case client_header_field_value_trailing_space:
796 if (c == '\n')
797 _client_handshake_state = client_handshake_error;
798 else if (c == '\r')
799 _client_handshake_state = client_header_field_cr;
800 else if (c == ' ')
801 _client_handshake_state =
802 client_header_field_value_trailing_space;
803 else {
804 _header_value[0] = c;
805 _header_value_position = 1;
806 _client_handshake_state = client_header_field_value;
807 }
808 break;
809 case client_header_field_value:
810 if (c == '\n')
811 _client_handshake_state = client_handshake_error;
812 else if (c == '\r') {
813 _header_value[_header_value_position] = '\0';
814
815 if (strcasecmp ("upgrade", _header_name) == 0)
816 _header_upgrade_websocket =
817 strcasecmp ("websocket", _header_value) == 0;
818 else if (strcasecmp ("connection", _header_name) == 0)
819 _header_connection_upgrade =
820 strcasecmp ("upgrade", _header_value) == 0;
821 else if (strcasecmp ("Sec-WebSocket-Accept", _header_name)
822 == 0)
823 strcpy (_websocket_accept, _header_value);
824 else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
825 == 0) {
826 if (select_protocol (_header_value))
827 strcpy (_websocket_protocol, _header_value);
828 }
829 _client_handshake_state = client_header_field_cr;
830 } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
831 _client_handshake_state = client_handshake_error;
832 else {
833 _header_value[_header_value_position] = c;
834 _header_value_position++;
835 _client_handshake_state = client_header_field_value;
836 }
837 break;
838 case client_header_field_cr:
839 if (c == '\n')
840 _client_handshake_state = client_header_field_begin_name;
841 else
842 _client_handshake_state = client_handshake_error;
843 break;
844 case client_handshake_end_line_cr:
845 if (c == '\n') {
846 if (_header_connection_upgrade && _header_upgrade_websocket
847 && _websocket_protocol[0] != '\0'
848 && _websocket_accept[0] != '\0') {
849 _client_handshake_state = client_handshake_complete;
850
851 // TODO: validate accept key
852
853 _inpos++;
854 _insize--;
855
856 return true;
857 }
858 _client_handshake_state = client_handshake_error;
859 } else
860 _client_handshake_state = client_handshake_error;
861 break;
862 default:
863 assert (false);
864 }
865
866 _inpos++;
867 _insize--;
868
869 if (_client_handshake_state == client_handshake_error) {
870 socket ()->event_handshake_failed_protocol (
871 _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
872
873 error (zmq::i_engine::protocol_error);
874 return false;
875 }
876 }
877
878 return false;
879}
880
//  Base64-encodes in_len_ bytes from in_ into out_ using the standard
//  alphabet with '=' padding, and NUL-terminates the result.
//
//  out_len_ is the capacity of out_ in chars and must include room for the
//  terminator. Returns the number of characters written, not counting the
//  terminator, or -1 if out_ is too small (out_ may then hold a partial,
//  unterminated result).
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_)
{
    static const unsigned char alphabet[65] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    uint32_t bits = 0; //  Input bytes are shifted in at the low end.
    int pending = 0;   //  Number of low bits of 'bits' not yet emitted.
    int written = 0;   //  Number of characters stored in out_ so far.

    for (int pos = 0; pos < in_len_; ++pos) {
        bits = (bits << 8) | in_[pos];
        pending += 8;

        //  Emit every complete 6-bit group, most significant first.
        while (pending >= 6) {
            pending -= 6;
            if (written >= out_len_)
                return -1; /* truncation is failure */
            out_[written++] = alphabet[(bits >> pending) & 63];
        }
    }

    //  Left-align the remaining 2 or 4 bits in a final group.
    if (pending != 0) {
        bits <<= (6 - pending);
        if (written >= out_len_)
            return -1; /* truncation is failure */
        out_[written++] = alphabet[bits & 63];
    }

    //  Pad the output to a multiple of four characters.
    while ((written & 3) != 0) {
        if (written >= out_len_)
            return -1; /* truncation is failure */
        out_[written++] = '=';
    }

    if (written >= out_len_)
        return -1; /* no room for null terminator */
    out_[written] = 0;

    return written;
}
919
920static void compute_accept_key (char *key_, unsigned char *hash_)
921{
922 const char *magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
923#ifdef ZMQ_USE_NSS
924 unsigned int len;
925 HASH_HashType type = HASH_GetHashTypeByOidTag (SEC_OID_SHA1);
926 HASHContext *ctx = HASH_Create (type);
927 assert (ctx);
928
929 HASH_Begin (ctx);
930 HASH_Update (ctx, (unsigned char *) key_, (unsigned int) strlen (key_));
931 HASH_Update (ctx, (unsigned char *) magic_string,
932 (unsigned int) strlen (magic_string));
933 HASH_End (ctx, hash_, &len, SHA_DIGEST_LENGTH);
934 HASH_Destroy (ctx);
935#elif defined ZMQ_USE_BUILTIN_SHA1
936 sha1_ctxt ctx;
937 SHA1_Init (&ctx);
938 SHA1_Update (&ctx, (unsigned char *) key_, strlen (key_));
939 SHA1_Update (&ctx, (unsigned char *) magic_string, strlen (magic_string));
940
941 SHA1_Final (hash_, &ctx);
942#elif defined ZMQ_USE_GNUTLS
943 gnutls_hash_hd_t hd;
944 gnutls_hash_init (&hd, GNUTLS_DIG_SHA1);
945 gnutls_hash (hd, key_, strlen (key_));
946 gnutls_hash (hd, magic_string, strlen (magic_string));
947 gnutls_hash_deinit (hd, hash_);
948#else
949#error "No sha1 implementation set"
950#endif
951}
952