1/*
 * IXWebSocketHandshake.cpp
3 * Author: Benjamin Sergeant
4 * Copyright (c) 2019 Machine Zone, Inc. All rights reserved.
5 */
6
7#include "IXWebSocketHandshake.h"
8
9#include "IXHttp.h"
10#include "IXSocketConnect.h"
11#include "IXStrCaseCompare.h"
12#include "IXUrlParser.h"
13#include "IXUserAgent.h"
14#include "IXWebSocketHandshakeKeyGen.h"
15#include <algorithm>
16#include <iostream>
17#include <random>
18#include <sstream>
19
20
21namespace ix
22{
23 WebSocketHandshake::WebSocketHandshake(
24 std::atomic<bool>& requestInitCancellation,
25 std::unique_ptr<Socket>& socket,
26 WebSocketPerMessageDeflatePtr& perMessageDeflate,
27 WebSocketPerMessageDeflateOptions& perMessageDeflateOptions,
28 std::atomic<bool>& enablePerMessageDeflate)
29 : _requestInitCancellation(requestInitCancellation)
30 , _socket(socket)
31 , _perMessageDeflate(perMessageDeflate)
32 , _perMessageDeflateOptions(perMessageDeflateOptions)
33 , _enablePerMessageDeflate(enablePerMessageDeflate)
34 {
35 }
36
37 bool WebSocketHandshake::insensitiveStringCompare(const std::string& a, const std::string& b)
38 {
39 return CaseInsensitiveLess::cmp(a, b) == 0;
40 }
41
42 std::string WebSocketHandshake::genRandomString(const int len)
43 {
44 std::string alphanum = "0123456789"
45 "ABCDEFGH"
46 "abcdefgh";
47
48 std::random_device r;
49 std::default_random_engine e1(r());
50 std::uniform_int_distribution<int> dist(0, (int) alphanum.size() - 1);
51
52 std::string s;
53 s.resize(len);
54
55 for (int i = 0; i < len; ++i)
56 {
57 int x = dist(e1);
58 s[i] = alphanum[x];
59 }
60
61 return s;
62 }
63
64 WebSocketInitResult WebSocketHandshake::sendErrorResponse(int code, const std::string& reason)
65 {
66 std::stringstream ss;
67 ss << "HTTP/1.1 ";
68 ss << code;
69 ss << " ";
70 ss << reason;
71 ss << "\r\n";
72 ss << "Server: " << userAgent() << "\r\n";
73
74 // Socket write can only be cancelled through a timeout here, not manually.
75 static std::atomic<bool> requestInitCancellation(false);
76 auto isCancellationRequested =
77 makeCancellationRequestWithTimeout(1, requestInitCancellation);
78
79 if (!_socket->writeBytes(ss.str(), isCancellationRequested))
80 {
81 return WebSocketInitResult(false, 500, "Timed out while sending error response");
82 }
83
84 return WebSocketInitResult(false, code, reason);
85 }
86
87 WebSocketInitResult WebSocketHandshake::clientHandshake(
88 const std::string& url,
89 const WebSocketHttpHeaders& extraHeaders,
90 const std::string& host,
91 const std::string& path,
92 int port,
93 int timeoutSecs)
94 {
95 _requestInitCancellation = false;
96
97 auto isCancellationRequested =
98 makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
99
100 std::string errMsg;
101 bool success = _socket->connect(host, port, errMsg, isCancellationRequested);
102 if (!success)
103 {
104 std::stringstream ss;
105 ss << "Unable to connect to " << host << " on port " << port << ", error: " << errMsg;
106 return WebSocketInitResult(false, 0, ss.str());
107 }
108
109 //
110 // Generate a random 24 bytes string which looks like it is base64 encoded
111 // y3JJHMbDL1EzLkh9GBhXDw==
112 // 0cb3Vd9HkbpVVumoS3Noka==
113 //
114 // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for
115 //
116 std::string secWebSocketKey = genRandomString(22);
117 secWebSocketKey += "==";
118
119 std::stringstream ss;
120 ss << "GET " << path << " HTTP/1.1\r\n";
121 ss << "Host: " << host << ":" << port << "\r\n";
122 ss << "Upgrade: websocket\r\n";
123 ss << "Connection: Upgrade\r\n";
124 ss << "Sec-WebSocket-Version: 13\r\n";
125 ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n";
126
127 // User-Agent can be customized by users
128 if (extraHeaders.find("User-Agent") == extraHeaders.end())
129 {
130 ss << "User-Agent: " << userAgent() << "\r\n";
131 }
132
133 for (auto& it : extraHeaders)
134 {
135 ss << it.first << ": " << it.second << "\r\n";
136 }
137
138 if (_enablePerMessageDeflate)
139 {
140 ss << _perMessageDeflateOptions.generateHeader();
141 }
142
143 ss << "\r\n";
144
145 if (!_socket->writeBytes(ss.str(), isCancellationRequested))
146 {
147 return WebSocketInitResult(
148 false, 0, std::string("Failed sending GET request to ") + url);
149 }
150
151 // Read HTTP status line
152 auto lineResult = _socket->readLine(isCancellationRequested);
153 auto lineValid = lineResult.first;
154 auto line = lineResult.second;
155
156 if (!lineValid)
157 {
158 return WebSocketInitResult(
159 false, 0, std::string("Failed reading HTTP status line from ") + url);
160 }
161
162 // Validate status
163 auto statusLine = Http::parseStatusLine(line);
164 std::string httpVersion = statusLine.first;
165 int status = statusLine.second;
166
167 // HTTP/1.0 is too old.
168 if (httpVersion != "HTTP/1.1")
169 {
170 std::stringstream ss;
171 ss << "Expecting HTTP/1.1, got " << httpVersion << ". "
172 << "Rejecting connection to " << url << ", status: " << status
173 << ", HTTP Status line: " << line;
174 return WebSocketInitResult(false, status, ss.str());
175 }
176
177 auto result = parseHttpHeaders(_socket, isCancellationRequested);
178 auto headersValid = result.first;
179 auto headers = result.second;
180
181 if (!headersValid)
182 {
183 return WebSocketInitResult(false, status, "Error parsing HTTP headers");
184 }
185
186 // We want an 101 HTTP status for websocket, otherwise it could be
187 // a redirection (like 301)
188 if (status != 101)
189 {
190 std::stringstream ss;
191 ss << "Expecting status 101 (Switching Protocol), got " << status
192 << " status connecting to " << url << ", HTTP Status line: " << line;
193
194 return WebSocketInitResult(false, status, ss.str(), headers, path);
195 }
196
197 // Check the presence of the connection field
198 if (headers.find("connection") == headers.end())
199 {
200 std::string errorMsg("Missing connection value");
201 return WebSocketInitResult(false, status, errorMsg);
202 }
203
204 // Check the value of the connection field
205 // Some websocket servers (Go/Gorilla?) send lowercase values for the
206 // connection header, so do a case insensitive comparison
207 //
208 // See https://github.com/apache/thrift/commit/7c4bdf9914fcba6c89e0f69ae48b9675578f084a
209 //
210 if (!insensitiveStringCompare(headers["connection"], "Upgrade"))
211 {
212 std::stringstream ss;
213 ss << "Invalid connection value: " << headers["connection"];
214 return WebSocketInitResult(false, status, ss.str());
215 }
216
217 char output[29] = {};
218 WebSocketHandshakeKeyGen::generate(secWebSocketKey, output);
219 if (std::string(output) != headers["sec-websocket-accept"])
220 {
221 std::string errorMsg("Invalid Sec-WebSocket-Accept value");
222 return WebSocketInitResult(false, status, errorMsg);
223 }
224
225 if (_enablePerMessageDeflate)
226 {
227 // Parse the server response. Does it support deflate ?
228 std::string header = headers["sec-websocket-extensions"];
229 WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);
230
231 // If the server does not support that extension, disable it.
232 if (!webSocketPerMessageDeflateOptions.enabled())
233 {
234 _enablePerMessageDeflate = false;
235 }
236 // Otherwise try to initialize the deflate engine (zlib)
237 else if (!_perMessageDeflate->init(webSocketPerMessageDeflateOptions))
238 {
239 return WebSocketInitResult(
240 false, 0, "Failed to initialize per message deflate engine");
241 }
242 }
243
244 return WebSocketInitResult(true, status, "", headers, path);
245 }
246
247 WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs,
248 bool enablePerMessageDeflate)
249 {
250 _requestInitCancellation = false;
251
252 auto isCancellationRequested =
253 makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation);
254
255 // Read first line
256 auto lineResult = _socket->readLine(isCancellationRequested);
257 auto lineValid = lineResult.first;
258 auto line = lineResult.second;
259
260 if (!lineValid)
261 {
262 return sendErrorResponse(400, "Error reading HTTP request line");
263 }
264
265 // Validate request line (GET /foo HTTP/1.1\r\n)
266 auto requestLine = Http::parseRequestLine(line);
267 auto method = std::get<0>(requestLine);
268 auto uri = std::get<1>(requestLine);
269 auto httpVersion = std::get<2>(requestLine);
270
271 if (method != "GET")
272 {
273 return sendErrorResponse(400, "Invalid HTTP method, need GET, got " + method);
274 }
275
276 if (httpVersion != "HTTP/1.1")
277 {
278 return sendErrorResponse(400,
279 "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion);
280 }
281
282 // Retrieve and validate HTTP headers
283 auto result = parseHttpHeaders(_socket, isCancellationRequested);
284 auto headersValid = result.first;
285 auto headers = result.second;
286
287 if (!headersValid)
288 {
289 return sendErrorResponse(400, "Error parsing HTTP headers");
290 }
291
292 if (headers.find("sec-websocket-key") == headers.end())
293 {
294 return sendErrorResponse(400, "Missing Sec-WebSocket-Key value");
295 }
296
297 if (headers.find("upgrade") == headers.end())
298 {
299 return sendErrorResponse(400, "Missing Upgrade header");
300 }
301
302 if (!insensitiveStringCompare(headers["upgrade"], "WebSocket") &&
303 headers["Upgrade"] != "keep-alive, Upgrade") // special case for firefox
304 {
305 return sendErrorResponse(400,
306 "Invalid Upgrade header, "
307 "need WebSocket, got " +
308 headers["upgrade"]);
309 }
310
311 if (headers.find("sec-websocket-version") == headers.end())
312 {
313 return sendErrorResponse(400, "Missing Sec-WebSocket-Version value");
314 }
315
316 {
317 std::stringstream ss;
318 ss << headers["sec-websocket-version"];
319 int version;
320 ss >> version;
321
322 if (version != 13)
323 {
324 return sendErrorResponse(400,
325 "Invalid Sec-WebSocket-Version, "
326 "need 13, got " +
327 ss.str());
328 }
329 }
330
331 char output[29] = {};
332 WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key"], output);
333
334 std::stringstream ss;
335 ss << "HTTP/1.1 101 Switching Protocols\r\n";
336 ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n";
337 ss << "Upgrade: websocket\r\n";
338 ss << "Connection: Upgrade\r\n";
339 ss << "Server: " << userAgent() << "\r\n";
340
341 // Parse the client headers. Does it support deflate ?
342 std::string header = headers["sec-websocket-extensions"];
343 WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header);
344
345 // If the client has requested that extension,
346 if (webSocketPerMessageDeflateOptions.enabled() && enablePerMessageDeflate)
347 {
348 _enablePerMessageDeflate = true;
349
350 if (!_perMessageDeflate->init(webSocketPerMessageDeflateOptions))
351 {
352 return WebSocketInitResult(
353 false, 0, "Failed to initialize per message deflate engine");
354 }
355 ss << webSocketPerMessageDeflateOptions.generateHeader();
356 }
357
358 ss << "\r\n";
359
360 if (!_socket->writeBytes(ss.str(), isCancellationRequested))
361 {
362 return WebSocketInitResult(
363 false, 0, std::string("Failed sending response to remote end"));
364 }
365
366 return WebSocketInitResult(true, 200, "", headers, uri);
367 }
368} // namespace ix
369